core.models.gemnet#

Submodules#

Classes#

GemNetT

GemNet-T, triplets-only variant of GemNet

Package Contents#

class core.models.gemnet.GemNetT(num_spherical: int, num_radial: int, num_blocks: int, emb_size_atom: int, emb_size_edge: int, emb_size_trip: int, emb_size_rbf: int, emb_size_cbf: int, emb_size_bil_trip: int, num_before_skip: int, num_after_skip: int, num_concat: int, num_atom: int, regress_forces: bool = True, direct_forces: bool = False, cutoff: float = 6.0, max_neighbors: int = 50, rbf: dict | None = None, envelope: dict | None = None, cbf: dict | None = None, extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, use_pbc_single: bool = False, output_init: str = 'HeOrthogonal', activation: str = 'swish', num_elements: int = 83, scale_file: str | None = None)#

Bases: torch.nn.Module, fairchem.core.models.base.GraphModelMixin

GemNet-T, triplets-only variant of GemNet

Parameters:
  • num_spherical (int) – Controls the maximum frequency of the angular (circular) basis.

  • num_radial (int) – Controls the maximum frequency of the radial basis.

  • num_blocks (int) – Number of building blocks to be stacked.

  • emb_size_atom (int) – Embedding size of the atoms.

  • emb_size_edge (int) – Embedding size of the edges.

  • emb_size_trip (int) – (Down-projected) Embedding size in the triplet message passing block.

  • emb_size_rbf (int) – Embedding size of the radial basis transformation.

  • emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).

  • emb_size_bil_trip (int) – Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer.

  • num_before_skip (int) – Number of residual blocks before the first skip connection.

  • num_after_skip (int) – Number of residual blocks after the first skip connection.

  • num_concat (int) – Number of residual blocks after the concatenation.

  • num_atom (int) – Number of residual blocks in the atom embedding blocks.

  • regress_forces (bool) – Whether to predict forces. Default: True.

  • direct_forces (bool) – If True, predict forces directly by aggregating interatomic directions. If False, compute forces as the negative gradient of the predicted energy with respect to the atomic positions. Default: False.

  • cutoff (float) – Embedding cutoff for interatomic directions, in Angstrom. Default: 6.0.

  • rbf (dict) – Name and hyperparameters of the radial basis function.

  • envelope (dict) – Name and hyperparameters of the envelope function.

  • cbf (dict) – Name and hyperparameters of the circular (angular) basis function.

  • extensive (bool) – Whether the output should be extensive (proportional to the number of atoms). Default: True.

  • output_init (str) – Initialization method for the final dense layer.

  • activation (str) – Name of the activation function.

  • scale_file (str) – Path to the JSON file containing the scaling factors.
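The constructor has thirteen required arguments and a number of optional ones, so it can help to collect them in a dictionary first. The following is a minimal sketch: the numeric values are illustrative placeholders rather than tuned settings, the helper name build_gemnet_t is hypothetical, and the import path fairchem.core.models.gemnet is an assumption based on the module name shown on this page.

```python
# Required arguments (those without defaults in the signature above).
# Values are illustrative placeholders, not recommended settings.
gemnet_t_kwargs = {
    "num_spherical": 7,
    "num_radial": 128,
    "num_blocks": 3,
    "emb_size_atom": 512,
    "emb_size_edge": 512,
    "emb_size_trip": 64,
    "emb_size_rbf": 16,
    "emb_size_cbf": 16,
    "emb_size_bil_trip": 64,
    "num_before_skip": 1,
    "num_after_skip": 2,
    "num_concat": 1,
    "num_atom": 3,
}
REQUIRED_KEYS = tuple(gemnet_t_kwargs)

# Optional arguments, written out with the defaults from the signature
# so the force and graph settings are explicit.
gemnet_t_kwargs.update(
    regress_forces=True,   # predict forces as well as energy
    direct_forces=False,   # False: forces from the energy gradient
    cutoff=6.0,            # Angstrom
    max_neighbors=50,
    extensive=True,
    otf_graph=False,
    use_pbc=True,
)


def build_gemnet_t(kwargs):
    """Hypothetical helper: instantiate GemNetT from a kwargs dict.

    The import is done lazily so that the dictionary above can be
    inspected without the library installed. The module path is an
    assumption taken from this page's title.
    """
    from fairchem.core.models.gemnet import GemNetT

    return GemNetT(**kwargs)
```

Calling build_gemnet_t(gemnet_t_kwargs) would then construct the model, provided the package is installed and the assumed import path is correct.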

num_blocks#
extensive#
cutoff#
max_neighbors#
regress_forces#
otf_graph#
use_pbc#
use_pbc_single#
direct_forces#
radial_basis#
cbf_basis3#
mlp_rbf3#
mlp_cbf3#
mlp_rbf_h#
mlp_rbf_out#
atom_emb#
edge_emb#
out_blocks#
int_blocks#
shared_parameters#
get_triplets(edge_index, num_atoms)#

For each edge c->a, get all edges b->a that point to the same atom a. b = c is possible, as long as the two edges are distinct.

Returns:

  • id3_ba (torch.Tensor, shape (num_triplets,)) – Indices of input edge b->a of each triplet b->a<-c

  • id3_ca (torch.Tensor, shape (num_triplets,)) – Indices of output edge c->a of each triplet b->a<-c

  • id3_ragged_idx (torch.Tensor, shape (num_triplets,)) – Indices enumerating the copies of id3_ca for creating a padded matrix
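The three returned index arrays can be illustrated with a small dependency-free sketch. It is not the library's implementation: the function name get_triplets_sketch is hypothetical, and edges are passed as two plain Python lists (sources and targets) instead of an edge_index tensor plus num_atoms. Edge e runs from sources[e] to targets[e].

```python
from collections import defaultdict


def get_triplets_sketch(sources, targets):
    """Enumerate triplets b->a<-c as pairs of edge indices.

    Returns three lists of equal length:
      id3_ba         index of the edge b->a
      id3_ca         index of the edge c->a
      id3_ragged_idx position of the triplet within its id3_ca group
    """
    # Group edge indices by their target atom a.
    incoming = defaultdict(list)
    for edge_id, atom in enumerate(targets):
        incoming[atom].append(edge_id)

    id3_ba, id3_ca, id3_ragged_idx = [], [], []
    for ca, atom in enumerate(targets):
        position = 0
        for ba in incoming[atom]:
            if ba == ca:
                continue  # an edge does not form a triplet with itself
            id3_ba.append(ba)
            id3_ca.append(ca)
            id3_ragged_idx.append(position)
            position += 1
    return id3_ba, id3_ca, id3_ragged_idx


# Edges: 0: 1->0, 1: 2->0, 2: 0->1, 3: 2->0 (a second, distinct 2->0 edge).
sources = [1, 2, 0, 2]
targets = [0, 0, 1, 0]
id3_ba, id3_ca, id3_ragged_idx = get_triplets_sketch(sources, targets)
# id3_ba         == [1, 3, 0, 3, 0, 1]
# id3_ca         == [0, 0, 1, 1, 3, 3]
# id3_ragged_idx == [0, 1, 0, 1, 0, 1]
```

The pair (id3_ba = 3, id3_ca = 1) is a triplet with b = c = 2: both edges start at atom 2, but they are different edges, so the triplet is kept. Edge 2 is the only edge into atom 1 and therefore appears in no triplet.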

select_symmetric_edges(tensor: torch.Tensor, mask: torch.Tensor, reorder_idx: torch.Tensor, inverse_neg) → torch.Tensor#
reorder_symmetric_edges(edge_index, cell_offsets, neighbors, edge_dist, edge_vector)#

Reorder edges to make finding counter-directional edges easier.

Some edges are present in only one direction in the data, because every atom has a maximum number of neighbors. Since we only use i->j edges here, making the graph symmetric drops some j->i edges and adds others. This could be fixed by merging edge_index with its counter-edges, including the cell_offsets, and then running torch.unique, but that does not seem worth it.
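The alternative mentioned above (merging the edges with their counter-edges and deduplicating) can be sketched in plain Python. This illustrates the idea only and is not what reorder_symmetric_edges does. The function names are hypothetical, a Python set stands in for torch.unique, and the sketch assumes that the counter-edge of i->j with cell offset o is j->i with offset -o.

```python
def counter_edge(edge):
    """Return the reversed edge; assumes the cell offset flips sign."""
    src, dst, offset = edge
    return (dst, src, tuple(-o for o in offset))


def missing_counter_edges(edges):
    """List the counter-edges that are absent from the edge set."""
    present = {(s, d, tuple(o)) for s, d, o in edges}
    return sorted(counter_edge(e) for e in present
                  if counter_edge(e) not in present)


def symmetrize_edges(edges):
    """Merge edges with their counter-edges and drop duplicates."""
    merged = set()
    for s, d, o in edges:
        edge = (s, d, tuple(o))
        merged.add(edge)
        merged.add(counter_edge(edge))
    return sorted(merged)


# Each edge is (source, target, cell_offset). The 0->2 edge crosses a
# periodic boundary and has no 2->0 partner, as can happen when atom 2
# has already reached its neighbor limit.
edges = [
    (0, 1, (0, 0, 0)),
    (1, 0, (0, 0, 0)),
    (0, 2, (1, 0, 0)),
]
missing = missing_counter_edges(edges)   # [(2, 0, (-1, 0, 0))]
symmetric = symmetrize_edges(edges)      # 4 edges, every one paired
```

After symmetrize_edges, every edge has its counter-edge in the set, so missing_counter_edges(symmetric) is empty.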

select_edges(data, edge_index, cell_offsets, neighbors, edge_dist, edge_vector, cutoff=None)#
generate_interaction_graph(data)#
forward(data)#
property num_params#