core.models.gemnet.layers.interaction_block
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
InteractionBlockTripletsOnly – Interaction block for GemNet-T/dT.
TripletInteraction – Triplet-based message passing block.
Module Contents
- class core.models.gemnet.layers.interaction_block.InteractionBlockTripletsOnly(emb_size_atom: int, emb_size_edge: int, emb_size_trip: int, emb_size_rbf: int, emb_size_cbf: int, emb_size_bil_trip: int, num_before_skip: int, num_after_skip: int, num_concat: int, num_atom: int, activation: str | None = None, name: str = 'Interaction')
Bases: torch.nn.Module
Interaction block for GemNet-T/dT.
- Parameters:
emb_size_atom (int) – Embedding size of the atoms.
emb_size_edge (int) – Embedding size of the edges.
emb_size_trip (int) – (Down-projected) Embedding size in the triplet message passing block.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).
emb_size_bil_trip (int) – Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer.
num_before_skip (int) – Number of residual blocks before the first skip connection.
num_after_skip (int) – Number of residual blocks after the first skip connection.
num_concat (int) – Number of residual blocks after the concatenation.
num_atom (int) – Number of residual blocks in the atom embedding blocks.
activation (str, optional) – Name of the activation function used in the dense layers, except for the final dense layer.
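Because the constructor takes ten required size and depth arguments, it can help to collect them in one place. The sketch below is a plain dictionary of keyword arguments plus two small sanity checks. The numeric values and the activation name "silu" are illustrative assumptions chosen for this example, not defaults read from the module, and the helpers `missing_required` and `bad_embedding_sizes` are hypothetical.

```python
# Illustrative keyword arguments for InteractionBlockTripletsOnly.
# ASSUMPTION: all numeric values (and the activation name) are example
# choices for this sketch, not defaults taken from the module.
interaction_kwargs = {
    "emb_size_atom": 128,     # atom embedding size
    "emb_size_edge": 512,     # edge embedding size
    "emb_size_trip": 64,      # down-projected size inside the triplet block
    "emb_size_rbf": 16,       # radial basis embedding size
    "emb_size_cbf": 16,       # circular basis embedding size (one angle)
    "emb_size_bil_trip": 64,  # edge embedding size after the bilinear layer
    "num_before_skip": 1,     # residual blocks before the first skip connection
    "num_after_skip": 2,      # residual blocks after the first skip connection
    "num_concat": 1,          # residual blocks after the concatenation
    "num_atom": 3,            # residual blocks in the atom embedding blocks
    "activation": "silu",     # ASSUMPTION: a name the library accepts
}

# Parameters that have no default value in the documented signature.
REQUIRED = {
    "emb_size_atom", "emb_size_edge", "emb_size_trip", "emb_size_rbf",
    "emb_size_cbf", "emb_size_bil_trip", "num_before_skip",
    "num_after_skip", "num_concat", "num_atom",
}


def missing_required(kwargs: dict) -> set:
    """Return the required constructor arguments that kwargs does not provide."""
    return REQUIRED - set(kwargs)


def bad_embedding_sizes(kwargs: dict) -> list:
    """Return emb_size_* entries that are not positive integers (hypothetical check)."""
    return sorted(
        key for key, value in kwargs.items()
        if key.startswith("emb_size") and not (isinstance(value, int) and value > 0)
    )


print(missing_required(interaction_kwargs))     # set()
print(bad_embedding_sizes(interaction_kwargs))  # []
```

With the library importable, the same dictionary would be passed as `InteractionBlockTripletsOnly(**interaction_kwargs)`, leaving `name` at its default of 'Interaction'.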
- name
- dense_ca
- trip_interaction
- layers_before_skip
- layers_after_skip
- atom_update
- concat_layer
- residual_m
- inv_sqrt_2
- forward(h: torch.Tensor, m: torch.Tensor, rbf3, cbf3, id3_ragged_idx, id_swap, id3_ba, id3_ca, rbf_h, idx_s, idx_t)
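The `inv_sqrt_2` attribute suggests that residual sums in this block are multiplied by 1/√2. A plausible motivation, assumed here rather than stated on this page, is variance preservation: if the skip branch and the residual branch are roughly independent with equal variance, their plain sum doubles the variance, while the scaled sum keeps it unchanged. The pure-Python sketch below checks this numerically; `scaled_residual` is a hypothetical helper.

```python
import math
import random
import statistics

INV_SQRT_2 = 1.0 / math.sqrt(2.0)


def scaled_residual(x: list, y: list) -> list:
    """Element-wise (x + y) / sqrt(2), the assumed residual-scaling pattern."""
    return [(a + b) * INV_SQRT_2 for a, b in zip(x, y)]


rng = random.Random(0)  # fixed seed so the run is reproducible
n = 100_000
x = [rng.gauss(0.0, 1.0) for _ in range(n)]  # skip branch, unit variance
y = [rng.gauss(0.0, 1.0) for _ in range(n)]  # residual branch, unit variance

plain = [a + b for a, b in zip(x, y)]
scaled = scaled_residual(x, y)

print(round(statistics.pvariance(plain), 2))   # close to 2.0
print(round(statistics.pvariance(scaled), 2))  # close to 1.0
```

Without the factor, every stacked residual connection would roughly double the activation variance; with it, the scale stays stable no matter how many such sums are chained.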
- Returns:
h (torch.Tensor, shape=(nAtoms, emb_size_atom)) – Atom embeddings.
m (torch.Tensor, shape=(nEdges, emb_size_edge)) – Edge embeddings (c->a).
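`forward` consumes atom embeddings `h` (one row per atom) and edge embeddings `m` (one row per edge), together with the index arrays `idx_s` and `idx_t`. The pure-Python sketch below illustrates the two index operations that move information between these two sets of rows: summing edge messages into their target atoms (scatter-sum), and gathering the source and target atom features back onto each edge, as a concatenation step would. Reading `idx_s` and `idx_t` as the per-edge source and target atom indices is an assumption based on their names; `scatter_sum` and `gather_concat` are hypothetical helpers with toy 1-D features.

```python
from typing import List, Sequence

Row = List[float]


def scatter_sum(values: Sequence[Row], index: Sequence[int], dim_size: int) -> List[Row]:
    """Sum each row of `values` into bucket `index[i]`; returns `dim_size` rows."""
    width = len(values[0])
    out = [[0.0] * width for _ in range(dim_size)]
    for row, target in zip(values, index):
        for j, v in enumerate(row):
            out[target][j] += v
    return out


def gather_concat(h: Sequence[Row], m: Sequence[Row],
                  idx_s: Sequence[int], idx_t: Sequence[int]) -> List[Row]:
    """For each edge e, concatenate h[idx_s[e]], h[idx_t[e]] and m[e]."""
    return [list(h[s]) + list(h[t]) + list(m_e) for s, t, m_e in zip(idx_s, idx_t, m)]


# Toy graph: 3 atoms, 4 directed edges 0->1, 1->0, 1->2, 2->1.
idx_s = [0, 1, 1, 2]              # assumed: source atom of each edge
idx_t = [1, 0, 2, 1]              # assumed: target atom of each edge
h = [[10.0], [20.0], [30.0]]      # atom features, one row per atom (nAtoms = 3)
m = [[1.0], [2.0], [3.0], [4.0]]  # edge features, one row per edge (nEdges = 4)

agg = scatter_sum(m, idx_t, dim_size=len(h))  # edge -> atom: 3 rows
cat = gather_concat(h, m, idx_s, idx_t)       # atom -> edge: 4 rows

print(agg)     # [[2.0], [5.0], [3.0]]
print(cat[0])  # [10.0, 20.0, 1.0]
```

The aggregated result has one row per atom, not per edge, which is why the returned atom embeddings `h` are indexed by atom while `m` stays indexed by edge.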
- class core.models.gemnet.layers.interaction_block.TripletInteraction(emb_size_edge: int, emb_size_trip: int, emb_size_bilinear: int, emb_size_rbf: int, emb_size_cbf: int, activation: str | None = None, name: str = 'TripletInteraction', **kwargs)
Bases: torch.nn.Module
Triplet-based message passing block.
- Parameters:
emb_size_edge (int) – Embedding size of the edges.
emb_size_trip (int) – (Down-projected) Embedding size of the edge embeddings after the Hadamard product with rbf.
emb_size_bilinear (int) – Embedding size of the edge embeddings after the bilinear layer.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).
activation (str, optional) – Name of the activation function used in the dense layers, except for the final dense layer.
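The size parameters describe a chain of projections: edge embeddings of width `emb_size_edge` are down-projected to `emb_size_trip`, passed through a bilinear layer to width `emb_size_bilinear`, and projected back up to `emb_size_edge`. The sketch below traces one plausible sequence of tensor shapes through the block. The stage labels borrow the attribute names listed after the parameters, but the exact order of operations is an assumption rather than something read from the implementation, and `trace_triplet_shapes` is a hypothetical helper.

```python
def trace_triplet_shapes(n_edges: int, n_triplets: int, emb_size_edge: int,
                         emb_size_trip: int, emb_size_bilinear: int,
                         emb_size_rbf: int) -> list:
    """Return (stage, shape) pairs for an ASSUMED order of operations."""
    return [
        ("m (input edge embeddings)", (n_edges, emb_size_edge)),
        ("rbf3 (input radial basis)", (n_edges, emb_size_rbf)),
        ("dense_ba(m) * mlp_rbf(rbf3)", (n_edges, emb_size_edge)),
        ("down_projection", (n_edges, emb_size_trip)),
        ("gather per triplet via id3_ba", (n_triplets, emb_size_trip)),
        ("bilinear with cbf3, summed per edge via id3_ca", (n_edges, emb_size_bilinear)),
        ("up_projection_ca / up_projection_ac", (n_edges, emb_size_edge)),
    ]


for stage, shape in trace_triplet_shapes(n_edges=6, n_triplets=10, emb_size_edge=512,
                                         emb_size_trip=64, emb_size_bilinear=64,
                                         emb_size_rbf=16):
    print(f"{stage:<48} {shape}")
```

Under this reading, the output has the same shape as the input `m`, (nEdges, emb_size_edge), so it can be summed directly with other edge-embedding tensors.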
- name
- dense_ba
- mlp_rbf
- scale_rbf
- mlp_cbf
- scale_cbf_sum
- down_projection
- up_projection_ca
- up_projection_ac
- inv_sqrt_2
- forward(m: torch.Tensor, rbf3, cbf3, id3_ragged_idx, id_swap, id3_ba, id3_ca)
- Returns:
m – Edge embeddings (c->a).
- Return type:
torch.Tensor, shape=(nEdges, emb_size_edge)
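The output is described as edge embeddings for the direction c->a, while the block has two up-projections (`up_projection_ca`, `up_projection_ac`) and takes an `id_swap` argument. One reading, sketched below as an assumption, is that `id_swap[e]` is the index of the reverse of edge `e`, so each edge's output combines the message computed for that edge with the message computed for its reverse, again scaled by 1/√2. The helpers `build_id_swap` and `combine_directions` are hypothetical and operate on toy scalar messages.

```python
import math


def build_id_swap(edges):
    """Map each directed edge (s, t) to the index of its reverse edge (t, s)."""
    position = {edge: i for i, edge in enumerate(edges)}
    return [position[(t, s)] for (s, t) in edges]


def combine_directions(x_ca, x_ac, id_swap):
    """Add each edge's message to its reverse edge's message, scaled by 1/sqrt(2)."""
    inv_sqrt_2 = 1.0 / math.sqrt(2.0)
    return [(x_ca[e] + x_ac[id_swap[e]]) * inv_sqrt_2 for e in range(len(x_ca))]


edges = [(0, 1), (1, 0), (1, 2), (2, 1)]  # every edge has its reverse in the list
id_swap = build_id_swap(edges)            # [1, 0, 3, 2]
x_ca = [1.0, 2.0, 3.0, 4.0]               # toy message per edge
x_ac = [10.0, 20.0, 30.0, 40.0]           # toy message per edge, other direction
out = combine_directions(x_ca, x_ac, id_swap)
print(id_swap)
```

Because reversing an edge twice returns the original edge, `id_swap` is its own inverse; this only holds if the edge list is symmetric, which the sketch assumes.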