core.models.gemnet_oc.layers.embedding_block

Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

AtomEmbedding

Initial atom embeddings based on the atom type

EdgeEmbedding

Edge embedding based on the concatenation of atom embeddings and a subsequent dense layer

Module Contents

class core.models.gemnet_oc.layers.embedding_block.AtomEmbedding(emb_size: int, num_elements: int)

Bases: torch.nn.Module

Initial atom embeddings based on the atom type

Parameters:
  • emb_size (int) – Size of the atom embeddings.

  • num_elements (int) – Number of atom types (elements) that have an embedding.

emb_size
embeddings
forward(Z) → torch.Tensor
Returns:

h – Atom embeddings.

Return type:

torch.Tensor, shape=(nAtoms, emb_size)
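A minimal sketch of what an atom-type lookup of this kind amounts to, assuming `embeddings` is a `torch.nn.Embedding` with `num_elements` rows and that `Z` holds atomic numbers starting at 1 (hydrogen), so row `Z - 1` is selected. The class name `AtomEmbeddingSketch` and the value `num_elements=83` are hypothetical choices for illustration, not the library's code.

```python
import torch


class AtomEmbeddingSketch(torch.nn.Module):
    """Hypothetical stand-in for AtomEmbedding: one learnable vector per element."""

    def __init__(self, emb_size: int, num_elements: int) -> None:
        super().__init__()
        self.emb_size = emb_size
        # Assumption: a plain lookup table with one row per supported element.
        self.embeddings = torch.nn.Embedding(num_elements, emb_size)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Assumption: Z holds atomic numbers starting at 1 (H), so shift to 0-based rows.
        return self.embeddings(Z - 1)  # shape (nAtoms, emb_size)


# Usage: embed a water molecule (O, H, H).
atom_emb = AtomEmbeddingSketch(emb_size=16, num_elements=83)
Z = torch.tensor([8, 1, 1])
h = atom_emb(Z)
print(h.shape)  # torch.Size([3, 16])
```

Atoms of the same element share one learnable row, so the two hydrogens above receive identical initial embeddings.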

class core.models.gemnet_oc.layers.embedding_block.EdgeEmbedding(atom_features: int, edge_features: int, out_features: int, activation: str | None = None)

Bases: torch.nn.Module

Edge embedding based on the concatenation of atom embeddings and a subsequent dense layer.

Parameters:
  • atom_features (int) – Embedding size of the atom embedding.

  • edge_features (int) – Embedding size of the input edge embedding.

  • out_features (int) – Embedding size after the dense layer.

  • activation (str or None) – Activation function used in the dense layer. Defaults to None.

dense
forward(h: torch.Tensor, m: torch.Tensor, edge_index) → torch.Tensor
Parameters:
  • h (torch.Tensor, shape (num_atoms, atom_features)) – Atom embeddings.

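  • m (torch.Tensor, shape (num_edges, edge_features)) – Radial basis features when used in the embedding block; edge embeddings when used in the interaction block.

  • edge_index (torch.Tensor, shape (2, num_edges)) – Indices of the two atoms connected by each edge, used to gather their embeddings from h.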

Returns:

m_st – Edge embeddings.

Return type:

torch.Tensor, shape=(nEdges, out_features)
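The forward pass described above can be sketched as: gather the embeddings of both atoms of every edge, concatenate them with `m`, and apply a dense layer, so the dense layer's input width is `2 * atom_features + edge_features`. This is a minimal sketch assuming `edge_index` is a `(2, nEdges)` long tensor with one endpoint per row. `EdgeEmbeddingSketch` is hypothetical, and `torch.nn.Linear` (no bias) plus an optional `torch.nn.SiLU` is a stand-in for the library's `dense` layer and activation handling, not its actual implementation.

```python
from typing import Optional

import torch


class EdgeEmbeddingSketch(torch.nn.Module):
    """Hypothetical stand-in for EdgeEmbedding: concat(h_s, h_t, m) -> dense."""

    def __init__(self, atom_features: int, edge_features: int, out_features: int,
                 activation: Optional[str] = None) -> None:
        super().__init__()
        in_features = 2 * atom_features + edge_features
        # Assumption: a bias-free Linear layer stands in for the library's dense layer.
        self.dense = torch.nn.Linear(in_features, out_features, bias=False)
        if activation is None:
            self.act = torch.nn.Identity()  # no activation
        elif activation in ("silu", "swish"):
            self.act = torch.nn.SiLU()  # stand-in activation for this sketch
        else:
            raise ValueError(f"unsupported activation in this sketch: {activation}")

    def forward(self, h: torch.Tensor, m: torch.Tensor,
                edge_index: torch.Tensor) -> torch.Tensor:
        h_s = h[edge_index[0]]  # (nEdges, atom_features), first atom of each edge
        h_t = h[edge_index[1]]  # (nEdges, atom_features), second atom of each edge
        m_st = torch.cat([h_s, h_t, m], dim=-1)  # (nEdges, 2*atom_features + edge_features)
        return self.act(self.dense(m_st))  # (nEdges, out_features)


# Usage: 3 atoms, 4 directed edges, 8 radial-basis features per edge.
h = torch.randn(3, 16)
edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 0, 0]])
m = torch.randn(4, 8)
edge_emb = EdgeEmbeddingSketch(atom_features=16, edge_features=8,
                               out_features=32, activation="silu")
m_st = edge_emb(h, m, edge_index)
print(m_st.shape)  # torch.Size([4, 32])
```

Because the embeddings of both atoms enter the concatenation in a fixed order, the edges (s, t) and (t, s) generally receive different embeddings, which suits directed edge representations.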