core.models.gemnet_oc.layers.atom_update_block#

Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes#

AtomUpdateBlock

Aggregates the edge message embeddings onto the atoms

OutputBlock

Combines the atom update block with a subsequent final dense layer.

Module Contents#

class core.models.gemnet_oc.layers.atom_update_block.AtomUpdateBlock(emb_size_atom: int, emb_size_edge: int, emb_size_rbf: int, nHidden: int, activation=None)#

Bases: torch.nn.Module

Aggregates the edge message embeddings onto the atoms

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_rbf (int) – Embedding size of the radial basis.

  • nHidden (int) – Number of residual blocks.

  • activation (callable/str) – Name of the activation function to use in the dense layers.

dense_rbf#
scale_sum#
layers#
get_mlp(units_in: int, units: int, nHidden: int, activation)#
forward(h: torch.Tensor, m, basis_rad, idx_atom)#
Returns:

h – Updated atom embeddings.

Return type:

torch.Tensor, shape=(nAtoms, emb_size_atom)
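The forward pass gates each edge message with a projection of the radial basis, sums the gated messages onto their target atoms, and refines the result with an MLP. Below is a minimal, self-contained sketch of that aggregation pattern, not the module itself: the shapes assumed for m, basis_rad, and idx_atom (edge-indexed) and the exact gating/MLP details are assumptions, and the real block additionally applies its learned scale_sum factor and nHidden residual blocks.

```python
# Sketch of the per-atom aggregation performed by AtomUpdateBlock.forward.
# Shapes of m, basis_rad, and idx_atom are assumptions (edge-indexed inputs).
import torch

nAtoms, nEdges = 4, 10
emb_size_atom, emb_size_edge, emb_size_rbf = 8, 8, 6

h = torch.randn(nAtoms, emb_size_atom)          # atom embeddings
m = torch.randn(nEdges, emb_size_edge)          # edge message embeddings
basis_rad = torch.randn(nEdges, emb_size_rbf)   # radial basis per edge
idx_atom = torch.randint(0, nAtoms, (nEdges,))  # target atom of each edge

# dense_rbf-like step: project the radial basis and gate the edge messages
dense_rbf = torch.nn.Linear(emb_size_rbf, emb_size_edge, bias=False)
m_gated = m * dense_rbf(basis_rad)

# Sum the gated messages onto their target atoms
h_agg = torch.zeros(nAtoms, emb_size_edge).index_add_(0, idx_atom, m_gated)

# layers-like step: an MLP maps the aggregated messages to the atom size
# (the real module uses nHidden residual blocks built by get_mlp)
mlp = torch.nn.Sequential(
    torch.nn.Linear(emb_size_edge, emb_size_atom, bias=False),
    torch.nn.SiLU(),
)
h_out = mlp(h_agg)
print(h_out.shape)  # torch.Size([4, 8]) == (nAtoms, emb_size_atom)
```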

class core.models.gemnet_oc.layers.atom_update_block.OutputBlock(emb_size_atom: int, emb_size_edge: int, emb_size_rbf: int, nHidden: int, nHidden_afteratom: int, activation: str | None = None, direct_forces: bool = True)#

Bases: AtomUpdateBlock

Combines the atom update block with a subsequent final dense layer.

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_rbf (int) – Embedding size of the radial basis.

  • nHidden (int) – Number of residual blocks before adding the atom embedding.

  • nHidden_afteratom (int) – Number of residual blocks after adding the atom embedding.

  • activation (str) – Name of the activation function to use in the dense layers.

  • direct_forces (bool) – If True, predict forces directly, i.e., without taking the gradient of the energy potential.

direct_forces#
seq_energy_pre#
forward(h: torch.Tensor, m: torch.Tensor, basis_rad, idx_atom)#
Returns:

  • torch.Tensor, shape=(nAtoms, emb_size_atom) – Output atom embeddings.

  • torch.Tensor, shape=(nEdges, emb_size_edge) – Output edge embeddings.
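
A hedged usage sketch for OutputBlock is shown below. It assumes the fairchem package is importable under the path given on this page, that "silu" is an accepted activation name, and that m, basis_rad, and idx_atom are edge-indexed with the shapes shown; these shapes are inferred from the forward signature and the documented return shapes, not stated explicitly here.

```python
# Hedged usage sketch for OutputBlock.forward; input shapes for m, basis_rad,
# and idx_atom are assumptions. Consult the source for the exact contract.
import torch
from core.models.gemnet_oc.layers.atom_update_block import OutputBlock

nAtoms, nEdges = 4, 10
emb_size_atom, emb_size_edge, emb_size_rbf = 8, 16, 6

block = OutputBlock(
    emb_size_atom=emb_size_atom,
    emb_size_edge=emb_size_edge,
    emb_size_rbf=emb_size_rbf,
    nHidden=2,
    nHidden_afteratom=2,
    activation="silu",       # assumed activation name
    direct_forces=True,
)

h = torch.randn(nAtoms, emb_size_atom)          # atom embeddings
m = torch.randn(nEdges, emb_size_edge)          # edge embeddings
basis_rad = torch.randn(nEdges, emb_size_rbf)   # radial basis per edge
idx_atom = torch.randint(0, nAtoms, (nEdges,))  # target atom of each edge

h_out, m_out = block(h, m, basis_rad, idx_atom)
# h_out: (nAtoms, emb_size_atom) output atom embeddings (energy head)
# m_out: (nEdges, emb_size_edge) output edge embeddings (direct force head)
```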