core.models.gemnet_oc.layers.efficient
Copyright (c) Meta, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
| Class | Description |
|---|---|
| BasisEmbedding | Embed a basis (CBF, SBF), optionally using the efficient reformulation. |
| EfficientInteractionBilinear | Efficient reformulation of the bilinear layer and subsequent summation. |
Module Contents
- class core.models.gemnet_oc.layers.efficient.BasisEmbedding(num_radial: int, emb_size_interm: int, num_spherical: int | None = None)
Bases: torch.nn.Module
Embed a basis (CBF, SBF), optionally using the efficient reformulation.
- Parameters:
num_radial (int) – Number of radial basis functions.
emb_size_interm (int) – Intermediate embedding size of triplets/quadruplets.
num_spherical (int, optional) – Number of circular/spherical basis functions. Only required if a circular/spherical basis is used.
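As a rough illustration of the radial part of this embedding, the sketch below contracts a raw radial basis with a learned weight so that every edge gets an `(emb_size_interm, num_spherical)` matrix, the shape documented for the `rad_W1` return value of `forward`. The weight layout `(num_radial, num_spherical, emb_size_interm)` and the helper `embed_radial` are assumptions made for illustration; they are not taken from this page.

```python
import torch


def embed_radial(rad_basis: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: project a radial basis onto per-edge matrices.

    rad_basis: (num_edges, num_radial)
    weight:    (num_radial, num_spherical, emb_size_interm)  # assumed layout
    returns:   (num_edges, emb_size_interm, num_spherical)
    """
    num_radial, num_spherical, emb_size_interm = weight.shape
    # Contract over num_radial with a single matmul on the flattened weight.
    rad_W1 = rad_basis @ weight.reshape(num_radial, -1)
    # (num_edges, num_spherical * emb_size_interm) -> (num_edges, num_spherical, emb_size_interm)
    rad_W1 = rad_W1.reshape(-1, num_spherical, emb_size_interm)
    # Put emb_size_interm before num_spherical to match the documented output shape.
    return rad_W1.transpose(1, 2)


num_edges, num_radial, num_spherical, emb_size_interm = 5, 6, 7, 4
rad_basis = torch.randn(num_edges, num_radial)
weight = torch.randn(num_radial, num_spherical, emb_size_interm)
rad_W1 = embed_radial(rad_basis, weight)
print(tuple(rad_W1.shape))  # (5, 4, 7)
```

Doing the contraction as one matmul over a flattened weight, rather than looping over spherical components, keeps the work in a single batched kernel.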
- weight: torch.nn.Parameter
- num_radial
- num_spherical
- reset_parameters() → None
- forward(rad_basis, sph_basis=None, idx_rad_outer=None, idx_rad_inner=None, idx_sph_outer=None, idx_sph_inner=None, num_atoms=None)
- Parameters:
rad_basis (torch.Tensor, shape=(num_edges, num_radial or num_orders * num_radial)) – Raw radial basis.
sph_basis (torch.Tensor, shape=(num_triplets or num_quadruplets, num_spherical)) – Raw spherical or circular basis.
idx_rad_outer (torch.Tensor, shape=(num_edges)) – Atom associated with each radial basis value. Optional, used for efficient edge aggregation.
idx_rad_inner (torch.Tensor, shape=(num_edges)) – Enumerates radial basis values per atom. Optional, used for efficient edge aggregation.
idx_sph_outer (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Edge associated with each circular/spherical basis value. Optional, used for efficient triplet/quadruplet aggregation.
idx_sph_inner (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Enumerates circular/spherical basis values per edge. Optional, used for efficient triplet/quadruplet aggregation.
num_atoms (int) – Total number of atoms. Optional, used for efficient edge aggregation.
- Returns:
rad_W1 (torch.Tensor, shape=(num_edges, emb_size_interm, num_spherical))
sph (torch.Tensor, shape=(num_edges, num_spherical, Kmax)) – Zero-padded circular/spherical basis, where Kmax is the maximum number of neighbors of any edge.
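The `idx_*_outer`/`idx_*_inner` pairs describe a ragged grouping: `outer` names the edge (or atom) that owns each value, and `inner` numbers the values within that group. A minimal sketch, assuming the entries are sorted by `outer`, of how such an `inner` index can be derived and then used to scatter a ragged spherical basis into a zero-padded dense tensor with `Kmax` slots per edge. The helpers `enumerate_within_groups` and `pad_by_index` are hypothetical names for illustration.

```python
import torch


def enumerate_within_groups(idx_outer: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Hypothetical helper: 0, 1, 2, ... within each group of a sorted idx_outer."""
    counts = torch.bincount(idx_outer, minlength=num_groups)
    # Start offset of each group in the flat (sorted) array.
    offsets = torch.cumsum(counts, dim=0) - counts
    return torch.arange(idx_outer.numel()) - offsets[idx_outer]


def pad_by_index(values, idx_outer, idx_inner, num_groups):
    """Hypothetical helper: scatter ragged rows into a zero-padded (num_groups, Kmax, ...) tensor."""
    Kmax = int(idx_inner.max()) + 1 if idx_inner.numel() > 0 else 0
    padded = values.new_zeros((num_groups, Kmax) + tuple(values.shape[1:]))
    padded[idx_outer, idx_inner] = values
    return padded


# Three edges owning 2, 0 and 3 triplets respectively (sorted by edge).
idx_sph_outer = torch.tensor([0, 0, 2, 2, 2])
idx_sph_inner = enumerate_within_groups(idx_sph_outer, num_groups=3)
print(idx_sph_inner.tolist())  # [0, 1, 0, 1, 2]

sph_basis = torch.randn(5, 4)  # (num_triplets, num_spherical)
sph = pad_by_index(sph_basis, idx_sph_outer, idx_sph_inner, num_groups=3)
# (num_edges, Kmax, num_spherical) -> (num_edges, num_spherical, Kmax)
sph = sph.transpose(1, 2)
print(tuple(sph.shape))  # (3, 4, 3)
```

Edge 1 owns no triplets, so its slice stays all zeros. The final transpose gives the `(num_edges, num_spherical, Kmax)` layout that `EfficientInteractionBilinear.forward` lists for the second element of its `basis` argument.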
- class core.models.gemnet_oc.layers.efficient.EfficientInteractionBilinear(emb_size_in: int, emb_size_interm: int, emb_size_out: int)
Bases: torch.nn.Module
Efficient reformulation of the bilinear layer and subsequent summation.
- Parameters:
emb_size_in (int) – Embedding size of input triplets/quadruplets.
emb_size_interm (int) – Intermediate embedding size of the basis transformation.
emb_size_out (int) – Embedding size of output triplets/quadruplets.
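A bilinear layer maps a vector `a` of size `emb_size_interm` and a vector `b` of size `emb_size_in` to `out_k = sum_ij W[k, i, j] * a_i * b_j`. Because this only depends on the outer product of `a` and `b`, the same map can be applied as an ordinary bias-free linear layer on the flattened `(emb_size_interm * emb_size_in)` product, which is the kind of reshaping an efficient reformulation can exploit. The sketch below checks this identity numerically; treating the `bilinear` attribute as such a linear layer is an assumption here, not something this page states.

```python
import torch

torch.manual_seed(0)
emb_size_interm, emb_size_in, emb_size_out, n = 3, 5, 2, 4

# Bias-free linear map on the flattened outer product (assumed form of `bilinear`).
linear = torch.nn.Linear(emb_size_interm * emb_size_in, emb_size_out, bias=False)

a = torch.randn(n, emb_size_interm)  # e.g. basis-derived factors
b = torch.randn(n, emb_size_in)      # e.g. input embeddings

# Route 1: explicit bilinear form, viewing the weight as (out, interm, in).
W = linear.weight.reshape(emb_size_out, emb_size_interm, emb_size_in)
out_bilinear = torch.einsum("kij,ni,nj->nk", W, a, b)

# Route 2: form the outer product, flatten it row-major, apply the linear layer.
outer = a[:, :, None] * b[:, None, :]  # (n, emb_size_interm, emb_size_in)
out_linear = linear(outer.reshape(n, -1))

print(torch.allclose(out_bilinear, out_linear, atol=1e-5))  # True
```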
- emb_size_in
- emb_size_interm
- emb_size_out
- bilinear
- forward(basis, m, idx_agg_outer, idx_agg_inner, idx_agg2_outer=None, idx_agg2_inner=None, agg2_out_size=None)
- Parameters:
basis (Tuple[torch.Tensor, torch.Tensor], shapes=((num_edges, emb_size_interm, num_spherical), (num_edges, num_spherical, Kmax))) – First element: radial basis multiplied with the weight matrix. Second element: circular/spherical basis.
m (torch.Tensor, shape=(num_edges, emb_size_in)) – Input edge embeddings
idx_agg_outer (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Output edge aggregating this intermediate triplet/quadruplet edge.
idx_agg_inner (torch.Tensor, shape=(num_triplets or num_quadruplets)) – Enumerates intermediate edges per output edge.
idx_agg2_outer (torch.Tensor, shape=(num_edges)) – Output atom aggregating this edge.
idx_agg2_inner (torch.Tensor, shape=(num_edges)) – Enumerates edges per output atom.
agg2_out_size (int) – Number of output embeddings when aggregating twice. Typically the number of atoms.
- Returns:
m_ca – Aggregated edge/atom embeddings.
- Return type:
torch.Tensor, shape=(num_edges or agg2_out_size, emb_size_out)
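To see what the reformulation saves, it helps to compare it with a naive computation. For each output edge `e` and each triplet `t` aggregated into it, the naive route forms `rad_W1[e] @ sph_t` (an `emb_size_interm` vector), takes its outer product with the triplet's input embedding `m_t`, sums over `t`, and applies the bilinear map. The padded route instead fills dense `(num_edges, Kmax, ...)` tensors and replaces the loop with two batched matmuls. This is a minimal sketch of the single-aggregation case only (`idx_agg2_*` omitted), assuming the input embeddings have already been gathered to one row per triplet and that `bilinear` is a bias-free `torch.nn.Linear`; `efficient_aggregate` and `naive_aggregate` are hypothetical names.

```python
import torch

torch.manual_seed(0)


def efficient_aggregate(rad_W1, sph, m, idx_outer, idx_inner, bilinear):
    """Padded/batched route.
    rad_W1: (num_edges, emb_size_interm, num_spherical)
    sph:    (num_edges, num_spherical, Kmax), zero-padded
    m:      (num_triplets, emb_size_in), one row per triplet (assumption)
    """
    num_edges, Kmax = sph.shape[0], sph.shape[2]
    m_padded = m.new_zeros(num_edges, Kmax, m.shape[-1])
    m_padded[idx_outer, idx_inner] = m  # (num_edges, Kmax, emb_size_in)
    sph_m = sph @ m_padded              # sum over Kmax -> (num_edges, num_spherical, emb_size_in)
    rad_W1_sph_m = rad_W1 @ sph_m       # sum over num_spherical -> (num_edges, emb_size_interm, emb_size_in)
    return bilinear(rad_W1_sph_m.reshape(num_edges, -1))


def naive_aggregate(rad_W1, sph_triplet, m, idx_outer, bilinear):
    """Reference route: one small outer product per triplet, summed per edge."""
    num_edges, emb_size_interm, _ = rad_W1.shape
    acc = m.new_zeros(num_edges, emb_size_interm, m.shape[-1])
    for t in range(m.shape[0]):
        e = int(idx_outer[t])
        v = rad_W1[e] @ sph_triplet[t]  # (emb_size_interm,)
        acc[e] += torch.outer(v, m[t])
    return bilinear(acc.reshape(num_edges, -1))


num_edges, num_spherical, emb_size_interm, emb_size_in, emb_size_out = 3, 4, 3, 5, 2
idx_outer = torch.tensor([0, 0, 2, 2, 2])  # output edge of each triplet
idx_inner = torch.tensor([0, 1, 0, 1, 2])  # position within that edge
Kmax = int(idx_inner.max()) + 1

sph_triplet = torch.randn(5, num_spherical)
sph = torch.zeros(num_edges, Kmax, num_spherical)
sph[idx_outer, idx_inner] = sph_triplet
sph = sph.transpose(1, 2)  # (num_edges, num_spherical, Kmax)

rad_W1 = torch.randn(num_edges, emb_size_interm, num_spherical)
m = torch.randn(5, emb_size_in)
bilinear = torch.nn.Linear(emb_size_interm * emb_size_in, emb_size_out, bias=False)

fast = efficient_aggregate(rad_W1, sph, m, idx_outer, idx_inner, bilinear)
slow = naive_aggregate(rad_W1, sph_triplet, m, idx_outer, bilinear)
print(tuple(fast.shape), torch.allclose(fast, slow, atol=1e-4))  # (3, 2) True
```

Padded slots hold zeros in both `sph` and `m_padded`, so they add nothing to the sums; the price of the batched form is memory proportional to `num_edges * Kmax` rather than to the number of triplets.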