core.models.gemnet_oc.layers.interaction_block
Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
- InteractionBlock – Interaction block for GemNet-Q/dQ.
- QuadrupletInteraction – Quadruplet-based message passing block.
- TripletInteraction – Triplet-based message passing block.
- PairInteraction – Pair-based message passing block.
Module Contents
- class core.models.gemnet_oc.layers.interaction_block.InteractionBlock(emb_size_atom: int, emb_size_edge: int, emb_size_trip_in: int, emb_size_trip_out: int, emb_size_quad_in: int, emb_size_quad_out: int, emb_size_a2a_in: int, emb_size_a2a_out: int, emb_size_rbf: int, emb_size_cbf: int, emb_size_sbf: int, num_before_skip: int, num_after_skip: int, num_concat: int, num_atom: int, num_atom_emb_layers: int = 0, quad_interaction: bool = False, atom_edge_interaction: bool = False, edge_atom_interaction: bool = False, atom_interaction: bool = False, activation=None)
Bases: torch.nn.Module
Interaction block for GemNet-Q/dQ.
- Parameters:
emb_size_atom (int) – Embedding size of the atoms.
emb_size_edge (int) – Embedding size of the edges.
emb_size_trip_in (int) – (Down-projected) embedding size of the triplet edge embeddings before the bilinear layer.
emb_size_trip_out (int) – (Down-projected) embedding size of the triplet edge embeddings after the bilinear layer.
emb_size_quad_in (int) – (Down-projected) embedding size of the quadruplet edge embeddings before the bilinear layer.
emb_size_quad_out (int) – (Down-projected) embedding size of the quadruplet edge embeddings after the bilinear layer.
emb_size_a2a_in (int) – Embedding size in the atom interaction before the bilinear layer.
emb_size_a2a_out (int) – Embedding size in the atom interaction after the bilinear layer.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).
emb_size_sbf (int) – Embedding size of the spherical basis transformation (two angles).
num_before_skip (int) – Number of residual blocks before the first skip connection.
num_after_skip (int) – Number of residual blocks after the first skip connection.
num_concat (int) – Number of residual blocks after the concatenation.
num_atom (int) – Number of residual blocks in the atom embedding blocks.
num_atom_emb_layers (int) – Number of residual blocks for transforming atom embeddings.
quad_interaction (bool) – Whether to use quadruplet interactions.
atom_edge_interaction (bool) – Whether to use atom-to-edge interactions.
edge_atom_interaction (bool) – Whether to use edge-to-atom interactions.
atom_interaction (bool) – Whether to use atom-to-atom interactions.
activation (str) – Name of the activation function to use in the dense layers.
- dense_ca
- trip_interaction
- layers_before_skip
- layers_after_skip
- atom_emb_layers
- atom_update
- concat_layer
- residual_m
- inv_sqrt_2
- inv_sqrt_num_eint
- inv_sqrt_num_aint
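The three `inv_sqrt_*` attributes are normalization constants. The sketch below shows one plausible way they could be derived from the interaction flags, assuming `inv_sqrt_num_eint` counts the terms summed into the edge update (a skip term plus the edge-to-edge triplet term, plus one per enabled quadruplet or atom-to-edge interaction) and `inv_sqrt_num_aint` counts the terms summed into the atom update. The function `interaction_scales` is hypothetical and is not part of the module.

```python
import math


def interaction_scales(quad_interaction: bool = False,
                       atom_edge_interaction: bool = False,
                       edge_atom_interaction: bool = False,
                       atom_interaction: bool = False) -> dict:
    """Hypothetical helper: normalization constants as 1/sqrt(number of summed terms)."""
    # Edge update (assumed): skip term + edge-to-edge triplet term,
    # plus one term for each optional interaction that targets edges.
    num_eint = 2 + int(quad_interaction) + int(atom_edge_interaction)
    # Atom update (assumed): residual atom embedding,
    # plus one term for each optional interaction that targets atoms.
    num_aint = 1 + int(edge_atom_interaction) + int(atom_interaction)
    return {
        "inv_sqrt_2": 1 / math.sqrt(2.0),
        "inv_sqrt_num_eint": 1 / math.sqrt(num_eint),
        "inv_sqrt_num_aint": 1 / math.sqrt(num_aint),
    }


scales = interaction_scales(quad_interaction=True, atom_edge_interaction=True)
print(scales["inv_sqrt_num_eint"])  # 0.5, since four edge terms are summed
print(scales["inv_sqrt_num_aint"])  # 1.0, since only the residual term is present
```

Scaling a sum of n roughly independent, unit-variance terms by 1/sqrt(n) keeps its variance near 1, which is the usual motivation for constants of this form.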
- forward(h, m, bases_qint, bases_e2e, bases_a2e, bases_e2a, basis_a2a_rad, basis_atom_update, edge_index_main, a2ee2a_graph, a2a_graph, id_swap, trip_idx_e2e, trip_idx_a2e, trip_idx_e2a, quad_idx)
- Returns:
h (torch.Tensor, shape=(nAtoms, emb_size_atom)) – Atom embeddings.
m (torch.Tensor, shape=(nEdges, emb_size_edge)) – Edge embeddings (c->a).
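A minimal NumPy sketch of the scaled-sum and skip-connection pattern that the `inv_sqrt_num_eint` and `inv_sqrt_2` attributes suggest. NumPy arrays stand in for torch tensors, and the four random "terms" are hypothetical placeholders for the per-edge outputs that would be combined in the edge update; this is an illustration of the variance argument, not the module's `forward`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_edges, emb_size_edge = 10_000, 16

# Hypothetical stand-ins for the per-edge terms summed in the edge update
# (for example skip, triplet, quadruplet and atom-to-edge terms), each ~unit variance.
terms = [rng.standard_normal((n_edges, emb_size_edge)) for _ in range(4)]

# Sum the terms and rescale by 1/sqrt(number of terms) to keep the variance near 1.
x = sum(terms) * (1 / np.sqrt(len(terms)))

# Residual skip connection: add the incoming edge embedding, then rescale by 1/sqrt(2).
m = rng.standard_normal((n_edges, emb_size_edge))
m_new = (m + x) * (1 / np.sqrt(2.0))

print(round(float(sum(terms).var()), 2), round(float(x.var()), 2), round(float(m_new.var()), 2))
```

Without the rescaling the summed terms would have a variance of about 4; with it, both `x` and `m_new` stay close to 1.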
- class core.models.gemnet_oc.layers.interaction_block.QuadrupletInteraction(emb_size_edge, emb_size_quad_in, emb_size_quad_out, emb_size_rbf, emb_size_cbf, emb_size_sbf, symmetric_mp=True, activation=None)
Bases: torch.nn.Module
Quadruplet-based message passing block.
- Parameters:
emb_size_edge (int) – Embedding size of the edges.
emb_size_quad_in (int) – (Down-projected) embedding size of the quadruplet edge embeddings before the bilinear layer.
emb_size_quad_out (int) – (Down-projected) embedding size of the quadruplet edge embeddings after the bilinear layer.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).
emb_size_sbf (int) – Embedding size of the spherical basis transformation (two angles).
symmetric_mp (bool) – Whether to use symmetric message passing and update the edges in both directions.
activation (str) – Name of the activation function to use in the dense layers.
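The `emb_size_quad_in` and `emb_size_quad_out` parameters describe the sizes on either side of a bilinear layer. As a rough illustration only, the sketch below computes a generic bilinear form between a down-projected embedding and a basis vector with `np.einsum`. The weight tensor `W` and all array sizes are hypothetical, and the module's actual bilinear layer may be organized differently (for example, to aggregate over neighbors efficiently).

```python
import numpy as np

rng = np.random.default_rng(0)
n_quads, emb_size_sbf = 5, 8
emb_size_quad_in, emb_size_quad_out = 4, 6

# Down-projected edge embeddings and basis values, one row per quadruplet (hypothetical data).
x = rng.standard_normal((n_quads, emb_size_quad_in))
sbf = rng.standard_normal((n_quads, emb_size_sbf))
# Hypothetical bilinear weight tensor: (basis size, input size, output size).
W = rng.standard_normal((emb_size_sbf, emb_size_quad_in, emb_size_quad_out))

# Generic bilinear form: out[q, o] = sum_j sum_i sbf[q, j] * W[j, i, o] * x[q, i]
out = np.einsum("qj,jio,qi->qo", sbf, W, x)
print(out.shape)  # (5, 6)
```

The output width is `emb_size_quad_out`, which would then be up-projected back to `emb_size_edge`.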
- symmetric_mp
- dense_db
- mlp_rbf
- scale_rbf
- mlp_cbf
- scale_cbf
- mlp_sbf
- scale_sbf_sum
- down_projection
- up_projection_ca
- inv_sqrt_2
- forward(m, bases, idx, id_swap)
- Returns:
m – Edge embeddings (c->a).
- Return type:
torch.Tensor, shape=(nEdges, emb_size_edge)
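With `symmetric_mp=True`, each edge is updated in both directions, and `forward` takes an `id_swap` argument. The sketch below assumes `id_swap[e]` is the index of the reverse edge of `e` (so c->a maps to a->c) and shows how such an index could be built from a directed edge list and used to merge two per-direction outputs. `x_ca` and `x_ac` are hypothetical random stand-ins for up-projected messages; NumPy replaces torch.

```python
import numpy as np

# Directed edges as (source c, target a); every bond appears in both directions.
edge_index = np.array([
    [0, 1, 1, 2],  # sources
    [1, 0, 2, 1],  # targets
])

# Assumed meaning of id_swap: id_swap[e] is the index of the reverse of edge e.
lookup = {(int(c), int(a)): e for e, (c, a) in enumerate(edge_index.T)}
id_swap = np.array([lookup[(int(a), int(c))] for c, a in edge_index.T])
print(id_swap)  # [1 0 3 2]

# Hypothetical per-edge outputs for the two message directions.
rng = np.random.default_rng(0)
x_ca = rng.standard_normal((4, 3))
x_ac = rng.standard_normal((4, 3))

# Symmetric merge: each edge gets its own c->a message plus the a->c message
# computed on its reverse edge, rescaled by 1/sqrt(2).
m = (x_ca + x_ac[id_swap]) / np.sqrt(2.0)
```

Because reversing an edge twice returns the original edge, `id_swap[id_swap]` is the identity permutation.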
- class core.models.gemnet_oc.layers.interaction_block.TripletInteraction(emb_size_in: int, emb_size_out: int, emb_size_trip_in: int, emb_size_trip_out: int, emb_size_rbf: int, emb_size_cbf: int, symmetric_mp: bool = True, swap_output: bool = True, activation=None)
Bases: torch.nn.Module
Triplet-based message passing block.
- Parameters:
emb_size_in (int) – Embedding size of the input embeddings.
emb_size_out (int) – Embedding size of the output embeddings.
emb_size_trip_in (int) – (Down-projected) embedding size of the triplet edge embeddings before the bilinear layer.
emb_size_trip_out (int) – (Down-projected) embedding size of the triplet edge embeddings after the bilinear layer.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).
symmetric_mp (bool) – Whether to use symmetric message passing and update the edges in both directions.
swap_output (bool) – Whether to swap the output embedding directions. Only relevant if symmetric_mp is False.
activation (str) – Name of the activation function to use in the dense layers.
- symmetric_mp
- swap_output
- dense_ba
- mlp_rbf
- scale_rbf
- mlp_cbf
- scale_cbf_sum
- down_projection
- up_projection_ca
- inv_sqrt_2
- forward(m, bases, idx, id_swap, expand_idx=None, idx_agg2=None, idx_agg2_inner=None, agg2_out_size=None)
- Returns:
m – Edge embeddings.
- Return type:
torch.Tensor, shape=(nEdges, emb_size_out)
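The `symmetric_mp` and `swap_output` flags select how the final output is formed. The following is a minimal sketch of that branching under stated assumptions: `triplet_output` is a hypothetical function, plain matrices `W_ca` and `W_ac` stand in for the up-projection layers (mapping `emb_size_trip_out` to `emb_size_out`), and `id_swap` is assumed to map each edge to its reverse.

```python
import numpy as np


def triplet_output(x, id_swap, W_ca, W_ac=None, symmetric_mp=True, swap_output=True):
    """Hypothetical sketch of the output stage; W_ca and W_ac stand in for up-projections."""
    if symmetric_mp:
        # Project for both directions, move the a->c result onto the reverse edge,
        # and combine the two with a 1/sqrt(2) rescaling.
        x_ca = x @ W_ca
        x_ac = (x @ W_ac)[id_swap]
        return (x_ca + x_ac) / np.sqrt(2.0)
    # Non-symmetric case: optionally reorder onto the reverse edges, then project once.
    if swap_output:
        x = x[id_swap]
    return x @ W_ca


rng = np.random.default_rng(0)
id_swap = np.array([1, 0, 3, 2])
x = rng.standard_normal((4, 5))     # (nEdges, emb_size_trip_out)
W_ca = rng.standard_normal((5, 7))  # emb_size_trip_out -> emb_size_out
W_ac = rng.standard_normal((5, 7))

sym = triplet_output(x, id_swap, W_ca, W_ac)
plain = triplet_output(x, id_swap, W_ca, symmetric_mp=False, swap_output=False)
swapped = triplet_output(x, id_swap, W_ca, symmetric_mp=False, swap_output=True)
print(sym.shape)  # (4, 7)
```

In this sketch `swap_output` only permutes rows, so `swapped` equals `plain[id_swap]`; it has no effect when `symmetric_mp` is True, matching the parameter description.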
- class core.models.gemnet_oc.layers.interaction_block.PairInteraction(emb_size_atom, emb_size_pair_in, emb_size_pair_out, emb_size_rbf, activation=None)
Bases: torch.nn.Module
Pair-based message passing block.
- Parameters:
emb_size_atom (int) – Embedding size of the atoms.
emb_size_pair_in (int) – Embedding size of the atom pairs before the bilinear layer.
emb_size_pair_out (int) – Embedding size of the atom pairs after the bilinear layer.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
activation (str) – Name of the activation function to use in the dense layers.
- bilinear
- scale_rbf_sum
- down_projection
- up_projection
- inv_sqrt_2
- forward(h, rad_basis, edge_index, target_neighbor_idx)
- Returns:
h – Atom embeddings.
- Return type:
torch.Tensor, shape=(num_atoms, emb_size_atom)
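Pair-based message passing ends with per-atom outputs, so messages defined on edges must be aggregated onto their target atoms. The sketch below shows the simplest form of that step, a scatter-sum with `np.add.at`, assuming `edge_index[0]` holds the neighbor (source) and `edge_index[1]` the target atom. The `messages` array is hypothetical data, and the sketch ignores `target_neighbor_idx`, which the real `forward` also receives.

```python
import numpy as np

num_atoms, emb_size = 3, 2
# Directed edges: messages flow from edge_index[0] (neighbor) to edge_index[1] (target).
edge_index = np.array([
    [1, 2, 0, 1],
    [0, 0, 1, 2],
])
# Hypothetical per-edge messages, e.g. neighbor embeddings weighted by the radial basis.
messages = np.array([
    [1.0, 0.0],
    [2.0, 1.0],
    [0.5, 0.5],
    [3.0, 3.0],
])

# Scatter-sum: accumulate every message into the row of its target atom.
# np.add.at handles repeated target indices correctly (plain fancy-index += would not).
h_agg = np.zeros((num_atoms, emb_size))
np.add.at(h_agg, edge_index[1], messages)
print(h_agg)
```

Atom 0 receives two messages here, so its row is their sum; in PyTorch the analogous operation is `Tensor.index_add_`.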