core.models.gemnet_oc.layers.embedding_block
Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
- AtomEmbedding – Initial atom embeddings based on the atom type.
- EdgeEmbedding – Edge embedding based on the concatenation of atom embeddings and a subsequent dense layer.
Module Contents
- class core.models.gemnet_oc.layers.embedding_block.AtomEmbedding(emb_size: int, num_elements: int)
Bases: torch.nn.Module
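Initial atom embeddings based on the atom type.
- Parameters:
emb_size (int) – Size of the atom embeddings.
num_elements (int) – Number of element types, i.e. the number of rows in the embedding table.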
- emb_size
- embeddings
- forward(Z) → torch.Tensor
- Returns:
h – Atom embeddings.
- Return type:
torch.Tensor, shape=(nAtoms, emb_size)
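The lookup described above can be sketched as follows. This is a minimal, hypothetical re-implementation (`AtomEmbeddingSketch`), not the library class. It assumes that `Z` holds atomic numbers starting at 1 (hydrogen = 1) and that `embeddings` is a plain `torch.nn.Embedding` table with `num_elements` rows; the `num_elements=83` value in the usage lines is illustrative only.

```python
import torch


class AtomEmbeddingSketch(torch.nn.Module):
    """Hypothetical sketch of an atom-type embedding; not the library class."""

    def __init__(self, emb_size: int, num_elements: int):
        super().__init__()
        self.emb_size = emb_size
        # One learnable vector of length emb_size per element type.
        self.embeddings = torch.nn.Embedding(num_elements, emb_size)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Assumption: Z holds atomic numbers starting at 1, so shift by one
        # to obtain 0-based row indices into the embedding table.
        return self.embeddings(Z - 1)  # shape (nAtoms, emb_size)


# Usage: a water molecule (O, H, H); num_elements=83 is a hypothetical value.
Z = torch.tensor([8, 1, 1])
atom_emb = AtomEmbeddingSketch(emb_size=16, num_elements=83)
h = atom_emb(Z)
print(h.shape)  # torch.Size([3, 16])
```

Because the embedding depends only on the atom type, both hydrogen atoms receive the same initial vector; any distinction between them has to come from later message-passing layers.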
- class core.models.gemnet_oc.layers.embedding_block.EdgeEmbedding(atom_features: int, edge_features: int, out_features: int, activation: str | None = None)
Bases: torch.nn.Module
Edge embedding based on the concatenation of atom embeddings and a subsequent dense layer.
- Parameters:
atom_features (int) – Embedding size of the atom embedding.
edge_features (int) – Embedding size of the input edge embedding.
out_features (int) – Embedding size after the dense layer.
activation (str | None) – Activation function used in the dense layer. Defaults to None.
- dense
- forward(h: torch.Tensor, m: torch.Tensor, edge_index) → torch.Tensor
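- Parameters:
h (torch.Tensor, shape (num_atoms, atom_features)) – Atom embeddings.
m (torch.Tensor, shape (num_edges, edge_features)) – Input edge features: the radial basis when used in the embedding block, or the edge embedding when used in the interaction block.
edge_index (torch.Tensor, shape (2, num_edges)) – Indices of the two atoms of each edge; their embeddings are gathered from h and concatenated with m.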
- Returns:
m_st – Edge embeddings.
- Return type:
torch.Tensor, shape=(nEdges, out_features)
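A minimal sketch of the edge embedding, under stated assumptions: `EdgeEmbeddingSketch` is hypothetical, a bias-free `torch.nn.Linear` stands in for the library's dense layer, rows 0 and 1 of `edge_index` are taken as the two atoms s and t of each edge, and `"silu"` is mapped to plain `torch.nn.SiLU` (the library's own activation handling may differ). It shows why the dense layer's input width is `2 * atom_features + edge_features`.

```python
from typing import Optional

import torch


class EdgeEmbeddingSketch(torch.nn.Module):
    """Hypothetical sketch of an edge embedding; not the library class."""

    def __init__(self, atom_features: int, edge_features: int,
                 out_features: int, activation: Optional[str] = None):
        super().__init__()
        # Assumption: a bias-free linear layer stands in for the library's dense layer.
        # Input = [h_s, h_t, m], hence 2 * atom_features + edge_features columns.
        self.dense = torch.nn.Linear(2 * atom_features + edge_features,
                                     out_features, bias=False)
        # Assumption: only None (no activation) or "silu" are handled here.
        if activation is None:
            self.act = torch.nn.Identity()
        elif activation == "silu":
            self.act = torch.nn.SiLU()
        else:
            raise ValueError(f"unsupported activation: {activation}")

    def forward(self, h: torch.Tensor, m: torch.Tensor,
                edge_index: torch.Tensor) -> torch.Tensor:
        h_s = h[edge_index[0]]  # (nEdges, atom_features), atom s of each edge
        h_t = h[edge_index[1]]  # (nEdges, atom_features), atom t of each edge
        m_st = torch.cat([h_s, h_t, m], dim=-1)  # (nEdges, 2*atom_features + edge_features)
        return self.act(self.dense(m_st))  # (nEdges, out_features)


# Usage: 3 atoms, 4 directed edges, random features.
torch.manual_seed(0)
h = torch.randn(3, 16)
m = torch.randn(4, 8)
edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 0, 0]])
edge_emb = EdgeEmbeddingSketch(16, 8, 32, activation="silu")
m_st = edge_emb(h, m, edge_index)
print(m_st.shape)  # torch.Size([4, 32])
```

Concatenating the two atom embeddings, rather than summing them, keeps the edge direction visible to the dense layer: swapping s and t changes the input vector.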