core.models.gemnet_gp.layers.atom_update_block

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

AtomUpdateBlock

Aggregate the message embeddings of the atoms.

OutputBlock

Combines the atom update block with a subsequent final dense layer.

Functions

scatter_sum(src, index, dim, out, dim_size) → torch.Tensor

Clone of torch_scatter.scatter_sum, but without in-place operations.

Module Contents

core.models.gemnet_gp.layers.atom_update_block.scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: torch.Tensor | None = None, dim_size: int | None = None) → torch.Tensor

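Clone of torch_scatter.scatter_sum, but without in-place operations. The values of src are summed into the positions of the output given by index along dimension dim. If out is None, a zero tensor is allocated whose size along dim is dim_size, or index.max() + 1 when dim_size is not given; the result is always returned as a new tensor rather than written into an existing one.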
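The following is a dependency-free sketch of what scatter_sum computes for the common case dim=0 on a 2-D input. It uses plain Python lists so it runs without PyTorch. The function name scatter_sum_rows is hypothetical; it illustrates the semantics only and is not the torch implementation. Like the documented function, it allocates a fresh result and leaves both src and out untouched.

```python
from typing import List, Optional

Row = List[float]


def scatter_sum_rows(
    src: List[Row],
    index: List[int],
    dim_size: Optional[int] = None,
    out: Optional[List[Row]] = None,
) -> List[Row]:
    """Sum the rows of `src` into the output rows selected by `index` (dim=0 only).

    Hypothetical pure-Python illustration of scatter-sum semantics.
    """
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if out is not None:
        # Copy `out` so the caller's list is never modified.
        result = [list(row) for row in out]
    else:
        # Output size: dim_size if given, else max(index) + 1, else 0 for empty input.
        if dim_size is None:
            dim_size = max(index) + 1 if index else 0
        width = len(src[0]) if src else 0
        result = [[0.0] * width for _ in range(dim_size)]
    for row, target in zip(src, index):
        # `result` is freshly allocated, so accumulating into it is safe.
        for k, value in enumerate(row):
            result[target][k] += value
    return result


src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
index = [0, 2, 0]
print(scatter_sum_rows(src, index, dim_size=3))  # [[6.0, 8.0], [0.0, 0.0], [3.0, 4.0]]
```

With index = [0, 2, 0], the first and third rows of src are summed into output row 0, the second row lands in row 2, and row 1 stays zero because no entry points to it.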

class core.models.gemnet_gp.layers.atom_update_block.AtomUpdateBlock(emb_size_atom: int, emb_size_edge: int, emb_size_rbf: int, nHidden: int, activation: str | None = None, name: str = 'atom_update')

Bases: torch.nn.Module

Aggregate the message embeddings of the atoms.

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • nHidden (int) – Number of residual blocks.

  • activation (str | None) – Name of the activation function to use in the dense layers.

name
dense_rbf
scale_sum
layers
get_mlp(units_in: int, units: int, nHidden: int, activation: str | None)
forward(nAtoms: int, m: int, rbf, id_j)
Returns:

h – Atom embedding.

Return type:

torch.Tensor, shape=(nAtoms, emb_size_atom)
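The aggregation that AtomUpdateBlock performs can be sketched in plain Python under a few stated assumptions. Each edge message in m is assumed to be gated elementwise by a projection of its RBF features, precomputed here as rbf_proj as a stand-in for dense_rbf. The gated messages are summed onto the atom given by id_j, and the result passes through a stack of residual blocks computing x + f(x), one per nHidden. The helpers aggregate_messages, residual, and atom_update are hypothetical. The sketch omits the scaling step (scale_sum) and the dense projection from emb_size_edge to emb_size_atom that the real block needs to produce its documented output shape.

```python
from typing import Callable, List, Sequence

Vec = List[float]


def aggregate_messages(
    n_atoms: int, m: List[Vec], rbf_proj: List[Vec], id_j: List[int]
) -> List[Vec]:
    """Gate each edge message by its RBF projection and sum it onto atom id_j."""
    width = len(m[0]) if m else 0
    h = [[0.0] * width for _ in range(n_atoms)]  # atoms without incoming edges stay zero
    for msg, gate, atom in zip(m, rbf_proj, id_j):
        for k in range(width):
            h[atom][k] += msg[k] * gate[k]
    return h


def residual(f: Callable[[Vec], Vec]) -> Callable[[Vec], Vec]:
    """Wrap f as a residual block computing x + f(x)."""
    return lambda x: [a + b for a, b in zip(x, f(x))]


def atom_update(
    n_atoms: int,
    m: List[Vec],
    rbf_proj: List[Vec],
    id_j: List[int],
    blocks: Sequence[Callable[[Vec], Vec]],
) -> List[Vec]:
    """Aggregate edge messages per atom, then apply the residual blocks row by row."""
    h = aggregate_messages(n_atoms, m, rbf_proj, id_j)
    for block in blocks:
        h = [block(row) for row in h]
    return h


m = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one message per edge
rbf_proj = [[1.0, 1.0], [0.5, 0.5], [2.0, 0.0]]  # one gate per edge
id_j = [0, 1, 0]  # target atom of each edge
doubling_block = residual(lambda x: x)  # x + x, a toy stand-in for a learned layer
print(atom_update(3, m, rbf_proj, id_j, [doubling_block]))  # [[22.0, 4.0], [3.0, 4.0], [0.0, 0.0]]
```

Summing per target atom is what turns an (nEdges, …) input into an (nAtoms, …) output; atom 2 has no incoming edges, so its row stays zero through the residual blocks.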

class core.models.gemnet_gp.layers.atom_update_block.OutputBlock(emb_size_atom: int, emb_size_edge: int, emb_size_rbf: int, nHidden: int, num_targets: int, activation: str | None = None, direct_forces: bool = True, output_init: str = 'HeOrthogonal', name: str = 'output', **kwargs)

Bases: AtomUpdateBlock

Combines the atom update block with a subsequent final dense layer.

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • nHidden (int) – Number of residual blocks.

  • num_targets (int) – Number of targets.

  • activation (str | None) – Name of the activation function to use in all dense layers except the final one.

  • direct_forces (bool) – If True, predict forces directly instead of taking the gradient of the energy potential.

  • output_init (str) – Kernel initializer of the final dense layer.

dense_rbf_F: core.models.gemnet_gp.layers.base_layers.Dense
out_forces: core.models.gemnet_gp.layers.base_layers.Dense
out_energy: core.models.gemnet_gp.layers.base_layers.Dense
output_init
direct_forces
seq_energy
reset_parameters() → None
forward(nAtoms: int, m, rbf, id_j: torch.Tensor)
Returns:

(E, F) – Energy and force prediction, where

  • E (torch.Tensor, shape=(nAtoms, num_targets))

  • F (torch.Tensor, shape=(nEdges, num_targets))

Return type:

tuple
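The return shapes of OutputBlock.forward follow from where each branch lives. The energy branch is summed from edges onto atoms before the final dense layer, which gives one row per atom. The force branch stays on the edges, which gives one row per edge. The plain-Python sketch below shows only this shape contract. It assumes bias-free final layers with made-up weights, drops the RBF gating and hidden layers, and returns None for F when direct_forces is False as a hypothetical stand-in for the case where forces come from the energy gradient instead. The names output_block and linear are hypothetical and not part of the module.

```python
from typing import List, Optional, Tuple

Vec = List[float]
Matrix = List[Vec]  # indexed as [input_dim][output_dim]


def linear(x: Vec, w: Matrix) -> Vec:
    """Bias-free dense layer mapping len(x) inputs to len(w[0]) outputs."""
    return [sum(x[i] * w[i][o] for i in range(len(x))) for o in range(len(w[0]))]


def output_block(
    n_atoms: int,
    m: List[Vec],
    id_j: List[int],
    w_energy: Matrix,
    w_forces: Matrix,
    direct_forces: bool = True,
) -> Tuple[List[Vec], Optional[List[Vec]]]:
    """Return (E, F) with E shaped (n_atoms, num_targets) and F shaped (n_edges, num_targets)."""
    width = len(m[0]) if m else 0
    # Energy branch: sum edge embeddings onto their target atoms first.
    per_atom = [[0.0] * width for _ in range(n_atoms)]
    for msg, atom in zip(m, id_j):
        for k in range(width):
            per_atom[atom][k] += msg[k]
    E = [linear(row, w_energy) for row in per_atom]
    # Force branch: no aggregation, one prediction per edge.
    F = [linear(msg, w_forces) for msg in m] if direct_forces else None
    return E, F


m = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 edges, edge embedding size 2
id_j = [0, 0, 1]  # 2 atoms
w_energy = [[1.0], [2.0]]  # num_targets = 1
w_forces = [[1.0], [-1.0]]
E, F = output_block(2, m, id_j, w_energy, w_forces)
print(len(E), len(F))  # 2 3
```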