core.models.gemnet_gp.layers.embedding_block

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

AtomEmbedding

Initial atom embeddings based on the atom type.

EdgeEmbedding

Edge embedding based on the concatenation of atom embeddings with edge features, followed by a dense layer.

Module Contents

class core.models.gemnet_gp.layers.embedding_block.AtomEmbedding(emb_size: int)

Bases: torch.nn.Module

Initial atom embeddings based on the atom type.

Parameters:

emb_size (int) – Atom embedding size.

emb_size
embeddings
forward(Z) → torch.Tensor
Returns:

h – Atom embeddings.

Return type:

torch.Tensor, shape=(nAtoms, emb_size)
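The lookup can be sketched as a learned table with one row per element type. This is a minimal illustration, not the library's code: the class name `AtomEmbeddingSketch`, the table size of 83 elements, and the assumption that `Z` holds 1-based atomic numbers (shifted by one to index the table) are all assumptions made for the example. The attribute names `emb_size` and `embeddings` simply mirror the ones documented above.

```python
import torch
from torch import nn


class AtomEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for AtomEmbedding: one learned vector per element."""

    def __init__(self, emb_size: int, num_elements: int = 83) -> None:
        super().__init__()
        self.emb_size = emb_size
        # Assumption: a plain lookup table with one row per element type.
        self.embeddings = nn.Embedding(num_elements, emb_size)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Assumption: Z holds atomic numbers starting at 1 (hydrogen),
        # so subtract 1 to obtain 0-based row indices.
        h = self.embeddings(Z - 1)  # shape=(nAtoms, emb_size)
        return h


emb = AtomEmbeddingSketch(emb_size=16)
Z = torch.tensor([8, 1, 1])  # e.g. the three atoms of a water molecule
h = emb(Z)
print(h.shape)  # torch.Size([3, 16])
```

Because the lookup depends only on the atom type, the two hydrogen atoms in this sketch receive identical rows of `h`.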

class core.models.gemnet_gp.layers.embedding_block.EdgeEmbedding(atom_features: int, edge_features: int, num_out_features: int, activation: str | None = None)

Bases: torch.nn.Module

Edge embedding based on the concatenation of atom embeddings with edge features, followed by a dense layer.

Parameters:
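  • atom_features (int) – Size of the atom embeddings h.

  • edge_features (int) – Size of the edge features m_rbf.

  • num_out_features (int) – Embedding size after the dense layer.

  • activation (str or None) – Activation function used in the dense layer (default: None).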

in_features
dense
forward(h, m_rbf, idx_s, idx_t) → torch.Tensor
Parameters:
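  • h (shape (nAtoms, atom_features)) – Atom embeddings.

  • m_rbf (shape (nEdges, nFeatures)) – Edge features. In the embedding block, m_rbf = rbf; in the interaction block, m_rbf = m_st.

  • idx_s (shape (nEdges,)) – Index of the source atom of each edge.

  • idx_t (shape (nEdges,)) – Index of the target atom of each edge.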
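To make the roles of `idx_s` and `idx_t` concrete, the sketch below gathers the atom embeddings at both ends of each edge and concatenates them with the edge features. The toy values and the concatenation order `[h_s, h_t, m_rbf]` are assumptions chosen for illustration, not a statement of the library's exact ordering.

```python
import torch

# Toy inputs (hypothetical values chosen for readability).
h = torch.tensor([[1.0, 1.0],   # atom 0
                  [2.0, 2.0],   # atom 1
                  [3.0, 3.0]])  # atom 2; shape=(nAtoms=3, atom_features=2)
idx_s = torch.tensor([0, 1, 2])  # source atom of each edge
idx_t = torch.tensor([1, 2, 0])  # target atom of each edge
m_rbf = torch.tensor([[0.5], [0.6], [0.7]])  # shape=(nEdges=3, nFeatures=1)

h_s = h[idx_s]  # row i is the embedding of edge i's source atom
h_t = h[idx_t]  # row i is the embedding of edge i's target atom
# Assumed order: source atom, target atom, edge features.
m_cat = torch.cat([h_s, h_t, m_rbf], dim=-1)  # shape=(nEdges, 2*2 + 1)

# Edge 0 (atom 0 -> atom 1) holds the values [1, 1, 2, 2, 0.5].
print(m_cat.shape)
```

The width of `m_cat` is twice the atom embedding size plus the edge feature size, which is the input width a dense layer applied to it must accept.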

Returns:

m_st – Edge embeddings.

Return type:

torch.Tensor, shape=(nEdges, num_out_features)
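A module with the same constructor arguments might look like the sketch below. It is hedged: `EdgeEmbeddingSketch` is a hypothetical name, and the bias-free `nn.Linear`, the mapping of `"silu"`/`"swish"` to `nn.SiLU`, and the concatenation order are assumptions rather than the library's confirmed implementation. The input width `in_features = 2 * atom_features + edge_features` follows from concatenating two atom embeddings with one edge feature vector.

```python
from __future__ import annotations

import torch
from torch import nn


class EdgeEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for EdgeEmbedding: concatenate, then one dense layer."""

    def __init__(
        self,
        atom_features: int,
        edge_features: int,
        num_out_features: int,
        activation: str | None = None,
    ) -> None:
        super().__init__()
        # Width of the concatenated input [h_s, h_t, m_rbf].
        self.in_features = 2 * atom_features + edge_features
        # Assumption: None means no activation; "silu"/"swish" map to SiLU.
        if activation is None:
            act: nn.Module = nn.Identity()
        elif activation.lower() in ("silu", "swish"):
            act = nn.SiLU()
        else:
            raise ValueError(f"unsupported activation: {activation!r}")
        # Assumption: a bias-free linear layer followed by the activation.
        self.dense = nn.Sequential(
            nn.Linear(self.in_features, num_out_features, bias=False), act
        )

    def forward(
        self,
        h: torch.Tensor,
        m_rbf: torch.Tensor,
        idx_s: torch.Tensor,
        idx_t: torch.Tensor,
    ) -> torch.Tensor:
        # Gather both endpoint embeddings per edge and append the edge features.
        m_st = torch.cat([h[idx_s], h[idx_t], m_rbf], dim=-1)
        return self.dense(m_st)  # shape=(nEdges, num_out_features)


atom_size, rbf_size, edge_size = 16, 8, 32
h = torch.randn(4, atom_size)  # 4 atoms
idx_s = torch.tensor([0, 1, 2, 3])
idx_t = torch.tensor([1, 2, 3, 0])

# Embedding-block style: m_rbf = rbf.
rbf = torch.randn(4, rbf_size)
edge_block = EdgeEmbeddingSketch(atom_size, rbf_size, edge_size, activation="silu")
m_st = edge_block(h, rbf, idx_s, idx_t)
print(m_st.shape)  # torch.Size([4, 32])

# Interaction-block style: m_rbf = m_st, so edge_features = edge_size.
inter_block = EdgeEmbeddingSketch(atom_size, edge_size, edge_size, activation="silu")
m_st_new = inter_block(h, m_st, idx_s, idx_t)
print(m_st_new.shape)  # torch.Size([4, 32])
```

The second call mirrors the interaction-block usage noted for m_rbf: when m_rbf = m_st, edge_features has to match the width of m_st.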