core.models.gemnet_gp.layers.interaction_block

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

InteractionBlockTripletsOnly

Interaction block for GemNet-T/dT.

TripletInteraction

Triplet-based message passing block.

Module Contents

class core.models.gemnet_gp.layers.interaction_block.InteractionBlockTripletsOnly(emb_size_atom: int, emb_size_edge: int, emb_size_trip: int, emb_size_rbf: int, emb_size_cbf: int, emb_size_bil_trip: int, num_before_skip: int, num_after_skip: int, num_concat: int, num_atom: int, activation: str | None = None, name: str = 'Interaction')

Bases: torch.nn.Module

Interaction block for GemNet-T/dT.

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_trip (int) – Down-projected embedding size used in the triplet message passing block.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).

  • emb_size_bil_trip (int) – Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer.

  • num_before_skip (int) – Number of residual blocks before the first skip connection.

  • num_after_skip (int) – Number of residual blocks after the first skip connection.

  • num_concat (int) – Number of residual blocks after the concatenation.

  • num_atom (int) – Number of residual blocks in the atom embedding blocks.

  • activation (str | None) – Name of the activation function used in all dense layers except the final one. Defaults to None.
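As a rough illustration of how these parameters fit together, the sketch below collects one plausible set of constructor arguments in a plain dictionary and checks it against the parameter names in the signature above. Every value, including the activation name "silu", is an assumption chosen for the example rather than a default documented on this page, and the size relationship it checks is an expectation drawn from the "down-projected" wording, not a constraint the class is known to enforce.

```python
# Illustrative constructor arguments for InteractionBlockTripletsOnly.
# All values are assumptions chosen for this example, not documented defaults.
interaction_kwargs = {
    "emb_size_atom": 512,
    "emb_size_edge": 512,
    "emb_size_trip": 64,
    "emb_size_rbf": 16,
    "emb_size_cbf": 16,
    "emb_size_bil_trip": 64,
    "num_before_skip": 1,
    "num_after_skip": 2,
    "num_concat": 1,
    "num_atom": 3,
    "activation": "silu",  # assumed activation name
    "name": "Interaction",
}

# Parameter names taken from the class signature.
expected_names = {
    "emb_size_atom", "emb_size_edge", "emb_size_trip", "emb_size_rbf",
    "emb_size_cbf", "emb_size_bil_trip", "num_before_skip",
    "num_after_skip", "num_concat", "num_atom", "activation", "name",
}
missing = expected_names - interaction_kwargs.keys()
unexpected = interaction_kwargs.keys() - expected_names

# The triplet block is described as working on a down-projected embedding,
# so one would expect it to be no wider than the edge embedding.
is_down_projection = (
    interaction_kwargs["emb_size_trip"] <= interaction_kwargs["emb_size_edge"]
)

# With the package installed, the block would presumably be built as:
# block = InteractionBlockTripletsOnly(**interaction_kwargs)
```

Keeping the arguments in a dictionary makes it easy to reuse the same sizes for several stacked interaction blocks.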

name
dense_ca
trip_interaction
layers_before_skip
layers_after_skip
atom_update
concat_layer
residual_m
inv_sqrt_2
forward(h: torch.Tensor, m: torch.Tensor, rbf3, cbf3, id3_ragged_idx, id_swap, id3_ba, id3_ca, rbf_h, idx_s, idx_t, edge_offset, Kmax, nAtoms)
The inputs are grouped by the graph level they index:

  • Node: h

  • Edge: m, rbf3, id_swap, rbf_h, idx_s, idx_t, cbf3[0], cbf3[1] (dense)

  • Triplet: id3_ragged_idx, id3_ba, id3_ca

Returns:

  • h (torch.Tensor, shape=(nAtoms, emb_size_atom)) – Atom embeddings.

  • m (torch.Tensor, shape=(nEdges, emb_size_edge)) – Edge embeddings (c->a).
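The `inv_sqrt_2` attribute suggests that when two embedding branches are summed (for example a skip path plus an update), the result is rescaled by 1/√2 so that its variance stays roughly constant. The sketch below demonstrates that effect on random numbers using only the standard library. The variable names are hypothetical, and the code does not reproduce the module's forward pass.

```python
import math
import random
import statistics

random.seed(0)
inv_sqrt_2 = 1 / math.sqrt(2.0)

# Two independent unit-variance "branches", standing in for a skip path
# and an update term (hypothetical names, not the module's tensors).
n_samples = 20000
skip = [random.gauss(0.0, 1.0) for _ in range(n_samples)]
update = [random.gauss(0.0, 1.0) for _ in range(n_samples)]

# A plain sum roughly doubles the variance ...
unscaled = [a + b for a, b in zip(skip, update)]
# ... while scaling the sum by 1/sqrt(2) brings it back to about one.
scaled = [(a + b) * inv_sqrt_2 for a, b in zip(skip, update)]

var_unscaled = statistics.pvariance(unscaled)
var_scaled = statistics.pvariance(scaled)
print(f"variance of sum:        {var_unscaled:.3f}")
print(f"variance of scaled sum: {var_scaled:.3f}")
```

Without this rescaling, the magnitude of the embeddings would grow with every residual addition in a deep stack of blocks.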

class core.models.gemnet_gp.layers.interaction_block.TripletInteraction(emb_size_edge: int, emb_size_trip: int, emb_size_bilinear: int, emb_size_rbf: int, emb_size_cbf: int, activation: str | None = None, name: str = 'TripletInteraction', **kwargs)

Bases: torch.nn.Module

Triplet-based message passing block.

Parameters:
  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_trip (int) – Down-projected embedding size of the edge embeddings after the Hadamard product with rbf.

  • emb_size_bilinear (int) – Embedding size of the edge embeddings after the bilinear layer.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).

  • activation (str | None) – Name of the activation function used in all dense layers except the final one. Defaults to None.
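Read together, the parameter descriptions imply a sequence of feature sizes for the edge embeddings as they pass through the block. The helper below, `trace_feature_sizes`, is a hypothetical bookkeeping function written for this illustration. It only records the sizes named in the descriptions and does not model the layers themselves; the example sizes are assumptions.

```python
def trace_feature_sizes(emb_size_edge, emb_size_trip, emb_size_bilinear):
    """Return (stage, feature size) pairs implied by the parameter descriptions.

    Hypothetical helper for illustration; it is not part of the library.
    """
    return [
        ("input edge embedding m", emb_size_edge),
        ("after Hadamard product with rbf, down-projected", emb_size_trip),
        ("after bilinear layer", emb_size_bilinear),
        ("output edge embedding m", emb_size_edge),
    ]


# Example sizes are assumptions, not documented defaults.
stages = trace_feature_sizes(
    emb_size_edge=512, emb_size_trip=64, emb_size_bilinear=64
)
for stage, size in stages:
    print(f"{stage:<50} {size}")
```

The first and last sizes agree because `forward` is documented to return edge embeddings of shape (nEdges, emb_size_edge).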

name
dense_ba
mlp_rbf
scale_rbf
mlp_cbf
scale_cbf_sum
down_projection
up_projection_ca
up_projection_ac
inv_sqrt_2
forward(m: torch.Tensor, rbf3, cbf3, id3_ragged_idx, id_swap, id3_ba, id3_ca, edge_offset, Kmax)
Returns:

m – Edge embeddings (c->a).

Return type:

torch.Tensor, shape=(nEdges, emb_size_edge)
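The `id_swap` argument and the paired `up_projection_ca` / `up_projection_ac` attributes hint at a symmetric update in which each directed edge combines its own message with the one computed on the reverse edge. The toy sketch below assumes that `id_swap` maps every edge to the index of its reverse and that the two contributions are summed and scaled by `inv_sqrt_2`. Both are assumptions for illustration, shown with plain Python lists rather than tensors.

```python
import math

# Directed edges as (source, target) pairs; every edge has its reverse.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
index_of = {edge: i for i, edge in enumerate(edges)}

# Assumed meaning of id_swap: the index of the reversed edge.
id_swap = [index_of[(t, s)] for (s, t) in edges]  # [1, 0, 3, 2]

# Toy scalar "embeddings" per edge for the two directions.
x_ca = [1.0, 2.0, 3.0, 4.0]
x_ac = [10.0, 20.0, 30.0, 40.0]

inv_sqrt_2 = 1 / math.sqrt(2.0)

# Each edge receives its own c->a value plus the a->c value that was
# computed on its reverse edge, rescaled to keep the magnitude stable.
m_out = [(x_ca[i] + x_ac[id_swap[i]]) * inv_sqrt_2 for i in range(len(edges))]
```

Applying `id_swap` twice returns the original ordering, which is a quick sanity check for any edge list built this way.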