core.models.gemnet_oc.layers.interaction_block

Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

InteractionBlock

Interaction block for GemNet-Q/dQ.

QuadrupletInteraction

Quadruplet-based message passing block.

TripletInteraction

Triplet-based message passing block.

PairInteraction

Pair-based message passing block.

Module Contents

class core.models.gemnet_oc.layers.interaction_block.InteractionBlock(emb_size_atom: int, emb_size_edge: int, emb_size_trip_in: int, emb_size_trip_out: int, emb_size_quad_in: int, emb_size_quad_out: int, emb_size_a2a_in: int, emb_size_a2a_out: int, emb_size_rbf: int, emb_size_cbf: int, emb_size_sbf: int, num_before_skip: int, num_after_skip: int, num_concat: int, num_atom: int, num_atom_emb_layers: int = 0, quad_interaction: bool = False, atom_edge_interaction: bool = False, edge_atom_interaction: bool = False, atom_interaction: bool = False, activation=None)

Bases: torch.nn.Module

Interaction block for GemNet-Q/dQ.

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_trip_in (int) – (Down-projected) embedding size of the triplet edge embeddings before the bilinear layer.

  • emb_size_trip_out (int) – (Down-projected) embedding size of the triplet edge embeddings after the bilinear layer.

  • emb_size_quad_in (int) – (Down-projected) embedding size of the quadruplet edge embeddings before the bilinear layer.

  • emb_size_quad_out (int) – (Down-projected) embedding size of the quadruplet edge embeddings after the bilinear layer.

  • emb_size_a2a_in (int) – Embedding size in the atom interaction before the bilinear layer.

  • emb_size_a2a_out (int) – Embedding size in the atom interaction after the bilinear layer.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).

  • emb_size_sbf (int) – Embedding size of the spherical basis transformation (two angles).

  • num_before_skip (int) – Number of residual blocks before the first skip connection.

  • num_after_skip (int) – Number of residual blocks after the first skip connection.

  • num_concat (int) – Number of residual blocks after the concatenation.

  • num_atom (int) – Number of residual blocks in the atom embedding blocks.

  • num_atom_emb_layers (int) – Number of residual blocks for transforming atom embeddings.

  • quad_interaction (bool) – Whether to use quadruplet interactions.

  • atom_edge_interaction (bool) – Whether to use atom-to-edge interactions.

  • edge_atom_interaction (bool) – Whether to use edge-to-atom interactions.

  • atom_interaction (bool) – Whether to use atom-to-atom interactions.

  • activation (str) – Name of the activation function to use in the dense layers.
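The attributes `num_eint`, `inv_sqrt_num_eint`, `num_aint` and `inv_sqrt_num_aint` suggest that the four boolean flags control how many edge-level and atom-level interaction outputs are combined, and that each combination is rescaled by one over the square root of its term count. The sketch below assumes two always-present edge-level terms and one always-present atom-level term; those base counts and the function name `interaction_counts` are assumptions for illustration, not taken from this page.

```python
import math


def interaction_counts(
    quad_interaction: bool = False,
    atom_edge_interaction: bool = False,
    edge_atom_interaction: bool = False,
    atom_interaction: bool = False,
) -> dict:
    """Count enabled edge/atom terms from the flags (sketch with assumed base counts)."""
    # Assumption: two edge-level terms are always present (e.g. a skip path and
    # the edge-to-edge triplet interaction); each optional flag adds one more.
    num_eint = 2 + int(quad_interaction) + int(atom_edge_interaction)
    # Assumption: one atom-level term (the atom update) is always present.
    num_aint = 1 + int(edge_atom_interaction) + int(atom_interaction)
    return {
        "num_eint": num_eint,
        "inv_sqrt_num_eint": 1 / math.sqrt(num_eint),
        "num_aint": num_aint,
        "inv_sqrt_num_aint": 1 / math.sqrt(num_aint),
    }


print(interaction_counts(quad_interaction=True, atom_interaction=True))
```

Under these assumptions, enabling every flag gives four edge-level and three atom-level terms.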

dense_ca
trip_interaction
layers_before_skip
layers_after_skip
atom_emb_layers
atom_update
concat_layer
residual_m
inv_sqrt_2
num_eint
inv_sqrt_num_eint
num_aint
inv_sqrt_num_aint
forward(h, m, bases_qint, bases_e2e, bases_a2e, bases_e2a, basis_a2a_rad, basis_atom_update, edge_index_main, a2ee2a_graph, a2a_graph, id_swap, trip_idx_e2e, trip_idx_a2e, trip_idx_e2a, quad_idx)
Returns:

  • h (torch.Tensor, shape=(nAtoms, emb_size_atom)) – Atom embeddings.

  • m (torch.Tensor, shape=(nEdges, emb_size_edge)) – Edge embeddings (c->a).
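The `inv_sqrt_2`, `inv_sqrt_num_eint` and `inv_sqrt_num_aint` attributes point to variance-preserving sums: several outputs are added and the result is multiplied by one over the square root of the number of terms, so that independent unit-variance inputs give a roughly unit-variance result. The framework-free sketch below shows only that arithmetic on plain lists; the helper name `scaled_sum` is hypothetical and this is not the block's actual `forward` code.

```python
import math
from typing import Sequence


def scaled_sum(*terms: Sequence[float]) -> list[float]:
    """Elementwise sum of equally sized vectors, scaled by 1/sqrt(number of terms)."""
    if not terms:
        raise ValueError("need at least one term")
    size = len(terms[0])
    if any(len(t) != size for t in terms):
        raise ValueError("all terms must have the same length")
    scale = 1 / math.sqrt(len(terms))
    return [scale * sum(values) for values in zip(*terms)]


# Two-term residual update, analogous to (m + x) * inv_sqrt_2.
m = [1.0, -2.0, 0.5]
x = [0.5, 1.0, -0.5]
print(scaled_sum(m, x))

# Three edge-level terms, analogous to scaling by inv_sqrt_num_eint with num_eint = 3.
print(scaled_sum([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]))
```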

class core.models.gemnet_oc.layers.interaction_block.QuadrupletInteraction(emb_size_edge, emb_size_quad_in, emb_size_quad_out, emb_size_rbf, emb_size_cbf, emb_size_sbf, symmetric_mp=True, activation=None)

Bases: torch.nn.Module

Quadruplet-based message passing block.

Parameters:
  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_quad_in (int) – (Down-projected) embedding size of the quadruplet edge embeddings before the bilinear layer.

  • emb_size_quad_out (int) – (Down-projected) embedding size of the quadruplet edge embeddings after the bilinear layer.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).

  • emb_size_sbf (int) – Embedding size of the spherical basis transformation (two angles).

  • symmetric_mp (bool) – Whether to use symmetric message passing and update the edges in both directions.

  • activation (str) – Name of the activation function to use in the dense layers.
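Several parameters above are defined relative to "the bilinear layer". A bilinear layer combines two inputs through a third-order weight tensor, out[k] = sum over i, j of W[i][j][k] * a[i] * b[j]. Here `a` would plausibly be a basis embedding (size `emb_size_sbf`) and `b` the down-projected edge embedding (size `emb_size_quad_in`), giving an output of size `emb_size_quad_out`; that role assignment is an assumption. The plain-Python sketch below shows the arithmetic and shape bookkeeping only, with a hypothetical `bilinear` helper.

```python
def bilinear(
    a: list[float], b: list[float], weight: list[list[list[float]]]
) -> list[float]:
    """out[k] = sum_{i,j} weight[i][j][k] * a[i] * b[j]  (plain-Python sketch)."""
    if len(weight) != len(a) or any(len(row) != len(b) for row in weight):
        raise ValueError("weight must have shape (len(a), len(b), out_size)")
    out_size = len(weight[0][0])
    out = [0.0] * out_size
    for i, a_i in enumerate(a):
        for j, b_j in enumerate(b):
            for k in range(out_size):
                out[k] += weight[i][j][k] * a_i * b_j
    return out


# Hypothetical sizes: emb_size_sbf = 2, emb_size_quad_in = 2, emb_size_quad_out = 1.
basis = [1.0, 2.0]
edge = [3.0, 4.0]
W = [[[1.0], [0.0]],
     [[0.0], [1.0]]]
print(bilinear(basis, edge, W))  # -> [11.0]
```

In a tensor framework the triple loop would be a single contraction, but the loop form makes explicit which index runs over which embedding size.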

symmetric_mp
dense_db
mlp_rbf
scale_rbf
mlp_cbf
scale_cbf
mlp_sbf
scale_sbf_sum
down_projection
up_projection_ca
inv_sqrt_2
forward(m, bases, idx, id_swap)
Returns:

m – Edge embeddings (c->a).

Return type:

torch.Tensor, shape=(nEdges, emb_size_edge)
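With `symmetric_mp=True`, both edge directions are updated. One plausible final step, assumed here rather than stated on this page, is to up-project the result once for c->a and once for a->c, reindex the a->c copy with `id_swap` so that row `e` holds the reverse of edge `e`, add the two and rescale by `inv_sqrt_2`. The sketch below applies that step to plain lists; it assumes `id_swap[e]` is the index of the reverse of edge `e`, and `combine_directions` is a hypothetical name.

```python
import math


def combine_directions(
    x_ca: list[list[float]], x_ac: list[list[float]], id_swap: list[int]
) -> list[list[float]]:
    """(x_ca + x_ac[id_swap]) / sqrt(2), row by row (sketch)."""
    inv_sqrt_2 = 1 / math.sqrt(2)
    out = []
    for e, row in enumerate(x_ca):
        rev = x_ac[id_swap[e]]  # embedding computed for the reverse edge
        out.append([inv_sqrt_2 * (u + v) for u, v in zip(row, rev)])
    return out


# Edges 0 and 1 are each other's reverse, so id_swap = [1, 0].
x_ca = [[1.0, 2.0], [3.0, 4.0]]
x_ac = [[10.0, 20.0], [30.0, 40.0]]
print(combine_directions(x_ca, x_ac, [1, 0]))
```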

class core.models.gemnet_oc.layers.interaction_block.TripletInteraction(emb_size_in: int, emb_size_out: int, emb_size_trip_in: int, emb_size_trip_out: int, emb_size_rbf: int, emb_size_cbf: int, symmetric_mp: bool = True, swap_output: bool = True, activation=None)

Bases: torch.nn.Module

Triplet-based message passing block.

Parameters:
  • emb_size_in (int) – Embedding size of the input embeddings.

  • emb_size_out (int) – Embedding size of the output embeddings.

  • emb_size_trip_in (int) – (Down-projected) embedding size of the triplet edge embeddings before the bilinear layer.

  • emb_size_trip_out (int) – (Down-projected) embedding size of the triplet edge embeddings after the bilinear layer.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).

  • symmetric_mp (bool) – Whether to use symmetric message passing and update the edges in both directions.

  • swap_output (bool) – Whether to swap the output embedding directions. Only relevant if symmetric_mp is False.

  • activation (str) – Name of the activation function to use in the dense layers.
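The `symmetric_mp` and `swap_output` flags determine which edge direction each output row refers to. The sketch below encodes the rule from the parameter list, namely that `swap_output` only matters when `symmetric_mp` is False. The symmetric case is delegated to a caller-supplied `combine` function, and `id_swap[e]` is assumed to be the index of the reverse of edge `e`; `orient_output` and `combine` are hypothetical names.

```python
from typing import Callable, Optional

Rows = list[list[float]]


def orient_output(
    x: Rows,
    id_swap: list[int],
    symmetric_mp: bool = True,
    swap_output: bool = True,
    combine: Optional[Callable[[Rows, list[int]], Rows]] = None,
) -> Rows:
    """Choose the output edge direction from the two flags (sketch)."""
    if symmetric_mp:
        # Both directions are merged; the merge itself is delegated to `combine`.
        if combine is None:
            raise ValueError("symmetric_mp=True needs a combine function")
        return combine(x, id_swap)
    if swap_output:
        # Reindex so that row e now holds the result for the reverse edge.
        return [x[j] for j in id_swap]
    return x  # keep the original edge order


x = [[1.0], [2.0], [3.0], [4.0]]
id_swap = [1, 0, 3, 2]  # hypothetical reverse-edge mapping
print(orient_output(x, id_swap, symmetric_mp=False, swap_output=True))   # -> [[2.0], [1.0], [4.0], [3.0]]
print(orient_output(x, id_swap, symmetric_mp=False, swap_output=False))  # -> [[1.0], [2.0], [3.0], [4.0]]
```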

symmetric_mp
swap_output
dense_ba
mlp_rbf
scale_rbf
mlp_cbf
scale_cbf_sum
down_projection
up_projection_ca
inv_sqrt_2
forward(m, bases, idx, id_swap, expand_idx=None, idx_agg2=None, idx_agg2_inner=None, agg2_out_size=None)
Returns:

m – Edge embeddings.

Return type:

torch.Tensor, shape=(nEdges, emb_size_out)
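At its core, a triplet interaction gathers the embedding of an input edge for every triplet, modulates it with that triplet's angular basis value, and sums the results into the output edge. The sketch below assumes that `idx` supplies, per triplet, an input-edge index and an output-edge index (called `idx_in` and `idx_out` here; both names are hypothetical), and it replaces the learned bilinear layer with one scalar basis weight per triplet. The `idx_agg2`, `idx_agg2_inner` and `agg2_out_size` arguments suggest an optional second aggregation stage, which this sketch does not model.

```python
def triplet_aggregate(
    m: list[list[float]],
    cbf: list[float],
    idx_in: list[int],
    idx_out: list[int],
    num_out_edges: int,
) -> list[list[float]]:
    """Scatter-sum cbf[t] * m[idx_in[t]] into row idx_out[t] for every triplet t."""
    if not (len(cbf) == len(idx_in) == len(idx_out)):
        raise ValueError("one cbf value and one index pair per triplet")
    dim = len(m[0]) if m else 0
    out = [[0.0] * dim for _ in range(num_out_edges)]
    for weight, e_in, e_out in zip(cbf, idx_in, idx_out):
        for d, value in enumerate(m[e_in]):  # gather the input-edge embedding
            out[e_out][d] += weight * value  # accumulate into the output edge
    return out


# Three edges with 2-dim embeddings; two triplets both feed edge 2.
m = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
print(triplet_aggregate(m, cbf=[2.0, 3.0], idx_in=[0, 1], idx_out=[2, 2], num_out_edges=3))
# -> [[0.0, 0.0], [0.0, 0.0], [2.0, 3.0]]
```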

class core.models.gemnet_oc.layers.interaction_block.PairInteraction(emb_size_atom, emb_size_pair_in, emb_size_pair_out, emb_size_rbf, activation=None)

Bases: torch.nn.Module

Pair-based message passing block.

Parameters:
  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_pair_in (int) – Embedding size of the atom pairs before the bilinear layer.

  • emb_size_pair_out (int) – Embedding size of the atom pairs after the bilinear layer.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • activation (str) – Name of the activation function to use in the dense layers.
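Pair-based message passing updates atom embeddings directly from neighboring atoms. A minimal version gathers the source atom's embedding for each edge, scales it by a radial weight for that edge, and sums the result into the target atom. The sketch assumes that `edge_index[0]` holds source atoms and `edge_index[1]` holds target atoms, and it collapses the radial basis to one scalar per edge; the documented block additionally uses down- and up-projections and a bilinear layer. `pair_messages` is a hypothetical name.

```python
def pair_messages(
    h: list[list[float]],
    edge_index: tuple[list[int], list[int]],
    radial: list[float],
) -> list[list[float]]:
    """out[t] = sum over edges (s -> t) of radial[e] * h[s]  (sketch)."""
    sources, targets = edge_index
    dim = len(h[0]) if h else 0
    out = [[0.0] * dim for _ in h]
    for e, (s, t) in enumerate(zip(sources, targets)):
        for d in range(dim):
            out[t][d] += radial[e] * h[s][d]
    return out


h = [[1.0], [10.0], [100.0]]
edge_index = ([1, 2, 0], [0, 0, 1])  # edges 1->0, 2->0, 0->1
print(pair_messages(h, edge_index, radial=[0.5, 0.25, 2.0]))
# -> [[30.0], [2.0], [0.0]]
```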

bilinear
scale_rbf_sum
down_projection
up_projection
inv_sqrt_2
forward(h, rad_basis, edge_index, target_neighbor_idx)
Returns:

h – Atom embeddings.

Return type:

torch.Tensor, shape=(num_atoms, emb_size_atom)
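The `target_neighbor_idx` argument suggests that edges are also arranged in a dense per-atom neighbor layout, where each edge needs its slot number among the edges that share its target atom. Assuming slots are assigned in order of appearance per target, such an index and the resulting zero-padded layout of shape (num_atoms, max_neighbors, dim) could be built as below. The function names and the ordering convention are assumptions for illustration.

```python
def neighbor_slots(targets: list[int]) -> list[int]:
    """Slot of each edge among the edges that share its target atom."""
    seen: dict[int, int] = {}
    slots = []
    for t in targets:
        slots.append(seen.get(t, 0))
        seen[t] = slots[-1] + 1
    return slots


def to_dense_neighbors(
    x_edge: list[list[float]], targets: list[int], slots: list[int], num_atoms: int
) -> list[list[list[float]]]:
    """Scatter per-edge rows into a zero-padded (num_atoms, max_slots, dim) array."""
    dim = len(x_edge[0]) if x_edge else 0
    max_slots = max(slots) + 1 if slots else 0
    dense = [[[0.0] * dim for _ in range(max_slots)] for _ in range(num_atoms)]
    for row, t, k in zip(x_edge, targets, slots):
        dense[t][k] = list(row)
    return dense


targets = [0, 1, 0]  # hypothetical target atom of each edge
slots = neighbor_slots(targets)  # -> [0, 0, 1]
dense = to_dense_neighbors([[1.0], [2.0], [3.0]], targets, slots, num_atoms=3)
print(dense)  # -> [[[1.0], [3.0]], [[2.0], [0.0]], [[0.0], [0.0]]]
```

A padded layout like this lets the per-atom neighbor sum be written as a batched matrix product with the radial basis instead of a scatter operation, at the cost of storing zeros for atoms with fewer neighbors.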