core.models.gemnet.layers.atom_update_block
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
AtomUpdateBlock | Aggregate the message embeddings of the atoms.
OutputBlock | Combines the atom update block and subsequent final dense layer.
Module Contents
- class core.models.gemnet.layers.atom_update_block.AtomUpdateBlock(emb_size_atom: int, emb_size_edge: int, emb_size_rbf: int, nHidden: int, activation=None, name: str = 'atom_update')
Bases: torch.nn.Module
Aggregate the message embeddings of the atoms
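The aggregation can be pictured as a scatter-sum: each edge message in `m` is gated by a projection of its radial basis features `rbf` and added to the atom whose index is given by `id_j`. The sketch below is a minimal illustration under assumptions, not the module's code: `aggregate_messages` is a hypothetical helper, `dense_rbf` is assumed to act like a bias-free linear layer, `torch.Tensor.index_add_` stands in for whatever scatter routine the library uses, and the learned rescaling attributed to `scale_sum` is left out.

```python
import torch


def aggregate_messages(m, rbf, id_j, n_atoms, dense_rbf):
    """Sum rbf-gated edge messages onto their target atoms (hypothetical helper)."""
    # Project rbf features into edge space and gate the messages elementwise:
    # (nEdges, emb_size_rbf) -> (nEdges, emb_size_edge)
    weighted = m * dense_rbf(rbf)
    # Scatter-sum over edges: row e of `weighted` is added to row id_j[e]
    out = torch.zeros(n_atoms, m.shape[1], dtype=m.dtype, device=m.device)
    out.index_add_(0, id_j, weighted)
    return out  # (nAtoms, emb_size_edge)


torch.manual_seed(0)
n_atoms, n_edges, emb_size_edge, emb_size_rbf = 3, 5, 4, 6
# Assumed stand-in for the module's dense_rbf projection
dense_rbf = torch.nn.Linear(emb_size_rbf, emb_size_edge, bias=False)
m = torch.randn(n_edges, emb_size_edge)    # edge messages
rbf = torch.randn(n_edges, emb_size_rbf)   # radial basis features per edge
id_j = torch.tensor([0, 0, 1, 2, 2])       # target atom of each edge
x = aggregate_messages(m, rbf, id_j, n_atoms, dense_rbf)
print(x.shape)  # torch.Size([3, 4])
```

Atom 0 receives the sum of edges 0 and 1, atom 1 only edge 2, and atom 2 edges 3 and 4, so atoms with more incoming edges accumulate more terms.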
- Parameters:
emb_size_atom (int) – Embedding size of the atoms.
emb_size_edge (int) – Embedding size of the edges.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
nHidden (int) – Number of residual blocks.
activation (callable/str) – Name of the activation function to use in the dense layers.
- name
- dense_rbf
- scale_sum
- layers
- get_mlp(units_in, units, nHidden, activation)
- forward(h, m, rbf, id_j)
- Returns:
h – Atom embedding.
- Return type:
torch.Tensor, shape=(nAtoms, emb_size_atom)
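`get_mlp` builds the stack that turns the aggregated edge features into atom embeddings of width `emb_size_atom`. A minimal sketch, assuming one bias-free dense layer from `units_in` to `units` followed by `nHidden` residual blocks that each add two activated dense layers to their input and rescale by 1/√2; `ToyResidual`, `get_mlp_sketch`, and the hard-coded SiLU activation are illustrative stand-ins, not the library's classes.

```python
import math

import torch


class ToyResidual(torch.nn.Module):
    """Hypothetical residual block: (x + f(x)) / sqrt(2), width preserved."""

    def __init__(self, units):
        super().__init__()
        self.f = torch.nn.Sequential(
            torch.nn.Linear(units, units, bias=False), torch.nn.SiLU(),
            torch.nn.Linear(units, units, bias=False), torch.nn.SiLU(),
        )

    def forward(self, x):
        return (x + self.f(x)) / math.sqrt(2.0)


def get_mlp_sketch(units_in, units, nHidden):
    # The first layer changes the width; the residual blocks keep it fixed.
    first = torch.nn.Sequential(torch.nn.Linear(units_in, units, bias=False), torch.nn.SiLU())
    return torch.nn.ModuleList([first] + [ToyResidual(units) for _ in range(nHidden)])


emb_size_edge, emb_size_atom, nHidden = 4, 8, 2
layers = get_mlp_sketch(emb_size_edge, emb_size_atom, nHidden)
x = torch.randn(3, emb_size_edge)  # aggregated features, (nAtoms, emb_size_edge)
for layer in layers:
    x = layer(x)
print(x.shape)  # torch.Size([3, 8])
```

Because the residual blocks preserve the width, the output shape matches the documented return type of `forward`, (nAtoms, emb_size_atom), for any value of `nHidden`.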
- class core.models.gemnet.layers.atom_update_block.OutputBlock(emb_size_atom: int, emb_size_edge: int, emb_size_rbf: int, nHidden: int, num_targets: int, activation=None, direct_forces: bool = True, output_init: str = 'HeOrthogonal', name: str = 'output', **kwargs)
Bases: AtomUpdateBlock
Combines the atom update block and subsequent final dense layer.
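One way to read "atom update block plus final dense layer": the energy path reuses the inherited MLP and appends a single linear map to `num_targets` outputs per atom. A minimal sketch, assuming `seq_energy` refers to the same module list as the inherited `layers` and `out_energy` behaves like a bias-free linear layer; the one-layer MLP and the random input standing in for aggregated edge features are illustrative only.

```python
import torch

emb_size_edge, emb_size_atom, num_targets, n_atoms = 4, 8, 1, 3

# Stand-in for the inherited atom-update MLP (AtomUpdateBlock.layers).
layers = torch.nn.ModuleList([torch.nn.Linear(emb_size_edge, emb_size_atom), torch.nn.SiLU()])
# Assumed aliasing: same module object, so no extra parameters are created.
seq_energy = layers
# Final dense layer mapping atom embeddings to per-atom targets.
out_energy = torch.nn.Linear(emb_size_atom, num_targets, bias=False)

x_E = torch.randn(n_atoms, emb_size_edge)  # aggregated features, (nAtoms, emb_size_edge)
for layer in seq_energy:
    x_E = layer(x_E)                       # (nAtoms, emb_size_atom)
E = out_energy(x_E)                        # (nAtoms, num_targets)
print(E.shape)  # torch.Size([3, 1])
```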
- Parameters:
emb_size_atom (int) – Embedding size of the atoms.
emb_size_edge (int) – Embedding size of the edges.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
nHidden (int) – Number of residual blocks.
num_targets (int) – Number of targets.
activation (str) – Name of the activation function to use in the dense layers except for the final dense layer.
direct_forces (bool) – If True, directly predict forces instead of taking the gradient of the energy potential.
output_init (str) – Kernel initializer of the final dense layer.
- output_init
- direct_forces
- seq_energy
- out_energy
- reset_parameters() → None
- forward(h, m, rbf, id_j)
- Returns:
(E, F) (tuple) – Energy and force prediction.
- E (torch.Tensor, shape=(nAtoms, num_targets))
- F (torch.Tensor, shape=(nEdges, num_targets))
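The `direct_forces` flag selects between two ways of obtaining forces. With the gradient route, forces are the negative derivative of a differentiable energy with respect to atomic positions; with the direct route, a separate head maps each edge embedding to `num_targets` values, matching the (nEdges, num_targets) shape of `F`, and no derivative of the energy is taken. The sketch below contrasts the two under stated assumptions: the harmonic pair energy, the edge indices `idx_s`/`idx_t`, and the linear `force_head` are hypothetical stand-ins, and the conversion of per-edge outputs into per-atom force vectors happens outside this block and is not shown.

```python
import torch

torch.manual_seed(0)
n_atoms, n_edges, emb_size_edge, num_targets = 3, 4, 5, 1
idx_s = torch.tensor([0, 1, 1, 2])  # hypothetical edge endpoints
idx_t = torch.tensor([1, 0, 2, 1])

# Gradient route (direct_forces=False): F = -dE/dpos for a differentiable energy.
# A toy harmonic pair energy stands in for the network's energy prediction.
pos = torch.randn(n_atoms, 3, requires_grad=True)
energy = ((pos[idx_s] - pos[idx_t]) ** 2).sum()
(grad,) = torch.autograd.grad(energy, pos)
forces_grad = -grad  # (nAtoms, 3)

# Direct route (direct_forces=True): a separate head maps each edge embedding
# to num_targets values without differentiating the energy.
m = torch.randn(n_edges, emb_size_edge)  # edge embeddings, (nEdges, emb_size_edge)
force_head = torch.nn.Linear(emb_size_edge, num_targets, bias=False)
F = force_head(m)  # (nEdges, num_targets)

print(forces_grad.shape, F.shape)  # torch.Size([3, 3]) torch.Size([4, 1])
```

The gradient route guarantees forces consistent with the energy (here the net force is zero because the toy energy depends only on relative positions), at the cost of a backward pass; the direct route skips that pass but learns forces as a separate output.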