core.models.gemnet_oc.layers.efficient

Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

BasisEmbedding

Embed a basis (CBF, SBF), optionally using the efficient reformulation.

EfficientInteractionBilinear

Efficient reformulation of the bilinear layer and subsequent summation.

Module Contents

class core.models.gemnet_oc.layers.efficient.BasisEmbedding(num_radial: int, emb_size_interm: int, num_spherical: int | None = None)

Bases: torch.nn.Module

Embed a basis (CBF, SBF), optionally using the efficient reformulation.

Parameters:
  • num_radial (int) – Number of radial basis functions.

  • emb_size_interm (int) – Intermediate embedding size of triplets/quadruplets.

  • num_spherical (int | None) – Number of circular/spherical basis functions. Only required if there is a circular/spherical basis; defaults to None.

weight: torch.nn.Parameter
num_radial
num_spherical
reset_parameters() -> None
forward(rad_basis, sph_basis=None, idx_rad_outer=None, idx_rad_inner=None, idx_sph_outer=None, idx_sph_inner=None, num_atoms=None)
Parameters:
  • rad_basis (torch.Tensor, shape=(num_edges, num_radial or num_orders * num_radial)) – Raw radial basis.

  • sph_basis (torch.Tensor, shape=(num_triplets or num_quadruplets, num_spherical)) – Raw spherical or circular basis.

  • idx_rad_outer (torch.Tensor, shape=(num_edges)) – Atom associated with each radial basis value. Optional, used for efficient edge aggregation.

  • idx_rad_inner (torch.Tensor, shape=(num_edges)) – Enumerates radial basis values per atom. Optional, used for efficient edge aggregation.

  • idx_sph_outer (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Edge associated with each circular/spherical basis value. Optional, used for efficient triplet/quadruplet aggregation.

  • idx_sph_inner (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Enumerates circular/spherical basis values per edge. Optional, used for efficient triplet/quadruplet aggregation.

  • num_atoms (int) – Total number of atoms. Optional, used for efficient edge aggregation.
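The outer/inner index pairs above describe a zero-padded dense layout: the outer index selects the row (atom or edge) and the inner index selects the slot within that row. The following is a minimal sketch of that layout for the spherical basis, not the layer's actual code. The helper enumerate_per_group and all toy values are hypothetical, and the helper assumes the outer index is sorted so that equal values are contiguous.

```python
import torch


def enumerate_per_group(idx_outer: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: position of each entry within its group.

    Assumes idx_outer is sorted, so equal values are contiguous.
    """
    if idx_outer.numel() == 0:
        return idx_outer.clone()
    counts = torch.bincount(idx_outer)
    # Offset at which each group starts in the sorted index list.
    starts = torch.cumsum(counts, dim=0) - counts
    return torch.arange(idx_outer.numel()) - starts[idx_outer]


# Five triplets attached to three edges (toy values).
idx_sph_outer = torch.tensor([0, 0, 1, 2, 2])
idx_sph_inner = enumerate_per_group(idx_sph_outer)

num_edges, num_spherical = 3, 4
sph_basis = torch.arange(5 * num_spherical, dtype=torch.float32).reshape(5, num_spherical)

# Scatter the per-triplet rows into a dense tensor; unused slots stay zero.
Kmax = int(idx_sph_inner.max()) + 1
dense = sph_basis.new_zeros(num_edges, Kmax, num_spherical)
dense[idx_sph_outer, idx_sph_inner] = sph_basis
```

Here edge 1 has only one triplet, so its second slot remains zero padding. The same pattern would apply to idx_rad_outer and idx_rad_inner with num_atoms rows instead of num_edges.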

Returns:

  • rad_W1 (torch.Tensor, shape=(num_edges, emb_size_interm, num_spherical))

  • sph (torch.Tensor, shape=(num_edges, Kmax, num_spherical)) – Kmax is the maximum number of neighbors per edge.
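To make the documented shape of rad_W1 concrete, the sketch below contracts a raw radial basis with a weight tensor over the radial index. It illustrates the shape contract only. The weight layout (num_radial, emb_size_interm, num_spherical) is an assumption chosen for this example and may differ from how the layer stores its weight parameter.

```python
import torch

num_edges, num_radial, num_spherical, emb_size_interm = 6, 5, 3, 4

rad_basis = torch.randn(num_edges, num_radial)
# Hypothetical weight layout, chosen only so the output matches the documented shape.
weight = torch.randn(num_radial, emb_size_interm, num_spherical)

# Contract over the radial index r: (e, r) x (r, i, s) -> (e, i, s).
rad_W1 = torch.einsum("er,ris->eis", rad_basis, weight)
```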

class core.models.gemnet_oc.layers.efficient.EfficientInteractionBilinear(emb_size_in: int, emb_size_interm: int, emb_size_out: int)

Bases: torch.nn.Module

Efficient reformulation of the bilinear layer and subsequent summation.

Parameters:
  • emb_size_in (int) – Embedding size of input triplets/quadruplets.

  • emb_size_interm (int) – Intermediate embedding size of the basis transformation.

  • emb_size_out (int) – Embedding size of output triplets/quadruplets.

emb_size_in
emb_size_interm
emb_size_out
bilinear
forward(basis, m, idx_agg_outer, idx_agg_inner, idx_agg2_outer=None, idx_agg2_inner=None, agg2_out_size=None)
Parameters:
  • basis (Tuple[torch.Tensor, torch.Tensor], shapes=((num_edges, emb_size_interm, num_spherical), (num_edges, num_spherical, Kmax))) – First element: radial basis multiplied with the weight matrix. Second element: circular/spherical basis.

  • m (torch.Tensor, shape=(num_edges, emb_size_in)) – Input edge embeddings.

  • idx_agg_outer (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Output edge aggregating this intermediate triplet/quadruplet edge.

  • idx_agg_inner (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Enumerates intermediate edges per output edge.

  • idx_agg2_outer (torch.Tensor, shape=(num_edges)) – Output atom aggregating this edge. Optional, used only when aggregating twice.

  • idx_agg2_inner (torch.Tensor, shape=(num_edges)) – Enumerates edges per output atom. Optional, used only when aggregating twice.

  • agg2_out_size (int) – Number of output embeddings when aggregating twice. Typically the number of atoms. Optional.
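The aggregation indices above allow the per-triplet sum to be written as batched matrix products over a zero-padded dense layout. The sketch below checks that idea on toy data by comparing a scatter-sum reference with the dense variant. It is a hedged illustration rather than the layer's implementation: all tensors and index values are made up, the embeddings m_t are assumed to be already gathered per triplet, and the final learned map to emb_size_out is omitted.

```python
import torch

torch.manual_seed(0)

num_edges, num_triplets = 3, 5
emb_size_in, emb_size_interm, num_spherical = 2, 4, 3

# Toy inputs; every value and index list here is illustrative.
rad_W1 = torch.randn(num_edges, emb_size_interm, num_spherical)
sph_t = torch.randn(num_triplets, num_spherical)  # basis value per triplet
m_t = torch.randn(num_triplets, emb_size_in)      # embedding per triplet (assumed pre-gathered)
idx_agg_outer = torch.tensor([0, 0, 1, 2, 2])     # output edge of each triplet
idx_agg_inner = torch.tensor([0, 1, 0, 0, 1])     # slot of each triplet within its edge

# Reference: scatter-sum the per-triplet outer products, then apply rad_W1.
outer = sph_t[:, :, None] * m_t[:, None, :]       # (num_triplets, num_spherical, emb_size_in)
summed = torch.zeros(num_edges, num_spherical, emb_size_in)
summed.index_add_(0, idx_agg_outer, outer)
reference = rad_W1 @ summed                       # (num_edges, emb_size_interm, emb_size_in)

# Dense variant: zero-pad to Kmax slots per edge, then use batched matmuls.
Kmax = int(idx_agg_inner.max()) + 1
sph_padded = torch.zeros(num_edges, Kmax, num_spherical)
sph_padded[idx_agg_outer, idx_agg_inner] = sph_t
sph = sph_padded.transpose(1, 2)                  # (num_edges, num_spherical, Kmax)
m_padded = torch.zeros(num_edges, Kmax, emb_size_in)
m_padded[idx_agg_outer, idx_agg_inner] = m_t
dense = rad_W1 @ (sph @ m_padded)                 # (num_edges, emb_size_interm, emb_size_in)
```

Because padded slots are zero in both sph and m_padded, they contribute nothing to the product, so the two results agree up to floating-point error. The dense form trades some wasted memory on padding for regular batched matmuls in place of scattered additions.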

Returns:

m_ca – Aggregated edge/atom embeddings.

Return type:

torch.Tensor, shape=(num_edges, emb_size_out)