core.models.gemnet
Submodules
Classes
- GemNetT – GemNet-T, triplets-only variant of GemNet
Package Contents
- class core.models.gemnet.GemNetT(num_spherical: int, num_radial: int, num_blocks: int, emb_size_atom: int, emb_size_edge: int, emb_size_trip: int, emb_size_rbf: int, emb_size_cbf: int, emb_size_bil_trip: int, num_before_skip: int, num_after_skip: int, num_concat: int, num_atom: int, regress_forces: bool = True, direct_forces: bool = False, cutoff: float = 6.0, max_neighbors: int = 50, rbf: dict | None = None, envelope: dict | None = None, cbf: dict | None = None, extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, use_pbc_single: bool = False, output_init: str = 'HeOrthogonal', activation: str = 'swish', num_elements: int = 83, scale_file: str | None = None)
Bases: torch.nn.Module, fairchem.core.models.base.GraphModelMixin
GemNet-T, triplets-only variant of GemNet
- Parameters:
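num_spherical (int) – Controls the maximum frequency of the angular basis (the number of angular basis functions).
num_radial (int) – Controls the maximum frequency of the radial basis (the number of radial basis functions).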
num_blocks (int) – Number of building blocks to be stacked.
emb_size_atom (int) – Embedding size of the atoms.
emb_size_edge (int) – Embedding size of the edges.
emb_size_trip (int) – (Down-projected) Embedding size in the triplet message passing block.
emb_size_rbf (int) – Embedding size of the radial basis transformation.
emb_size_cbf (int) – Embedding size of the circular basis transformation (one angle).
emb_size_bil_trip (int) – Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer.
num_before_skip (int) – Number of residual blocks before the first skip connection.
num_after_skip (int) – Number of residual blocks after the first skip connection.
num_concat (int) – Number of residual blocks after the concatenation.
num_atom (int) – Number of residual blocks in the atom embedding blocks.
regress_forces (bool) – Whether to predict forces. Default: True.
direct_forces (bool) – If True, predict forces based on aggregation of interatomic directions. If False, predict forces as the negative gradient of the energy potential.
cutoff (float) – Embedding cutoff for interatomic directions, in Angstrom.
rbf (dict) – Name and hyperparameters of the radial basis function.
envelope (dict) – Name and hyperparameters of the envelope function.
cbf (dict) – Name and hyperparameters of the circular basis function.
extensive (bool) – Whether the output should be extensive (proportional to the number of atoms).
output_init (str) – Initialization method for the final dense layer.
activation (str) – Name of the activation function.
scale_file (str) – Path to the JSON file containing the scaling factors.
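Several of these options work together: radial features are damped by an envelope that vanishes smoothly at the cutoff, an extensive output adds up contributions over the system, and with direct_forces=False the forces are the negative gradient of the predicted energy. The toy below sketches these ideas in plain Python. The envelope follows the polynomial form used in DimeNet and GemNet, u(d) = 1 + a d^p + b d^(p+1) + c d^(p+2) with d = r / cutoff; assuming that this is what a polynomial `envelope` dict selects is an assumption. `toy_energy` and `gradient_forces` are hypothetical helpers: the pair energy stands in for the network, and the real model uses automatic differentiation rather than finite differences.

```python
import math
from typing import Callable, List, Sequence

Positions = List[List[float]]


def polynomial_envelope(d: float, exponent: int = 5) -> float:
    """Smooth cutoff u(d) on the scaled distance d = r / cutoff.

    u(0) = 1, and both u and its first derivative reach 0 at d = 1.
    """
    if d >= 1.0:
        return 0.0
    p = exponent
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * d**p + b * d ** (p + 1) + c * d ** (p + 2)


def toy_energy(positions: Sequence[Sequence[float]], cutoff: float = 6.0) -> float:
    """Hypothetical extensive energy: a sum of enveloped pair terms (not GemNet)."""
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            r = math.dist(positions[i], positions[j])
            # Pairs beyond the cutoff contribute exactly zero.
            energy -= polynomial_envelope(r / cutoff)
    return energy


def gradient_forces(
    positions: Sequence[Sequence[float]],
    energy_fn: Callable[[Positions], float],
    h: float = 1e-5,
) -> Positions:
    """Forces F = -dE/dx, estimated with central finite differences."""
    forces = []
    for i in range(len(positions)):
        f_i = []
        for k in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][k] += h
            minus[i][k] -= h
            f_i.append(-(energy_fn(plus) - energy_fn(minus)) / (2 * h))
        forces.append(f_i)
    return forces
```

Because the energy is a sum over pairs, two well-separated copies of a dimer have exactly twice the dimer energy, and gradient-based forces on an isolated system sum to (numerically) zero.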
- num_blocks
- extensive
- cutoff
- max_neighbors
- regress_forces
- otf_graph
- use_pbc
- use_pbc_single
- direct_forces
- radial_basis
- cbf_basis3
- mlp_rbf3
- mlp_cbf3
- mlp_rbf_h
- mlp_rbf_out
- atom_emb
- edge_emb
- out_blocks
- int_blocks
- get_triplets(edge_index, num_atoms)
For each edge c->a, get all edges b->a that end at the same atom a. It is possible that b=c (for example, two periodic images of the same atom), as long as the two edges are distinct.
- Returns:
id3_ba (torch.Tensor, shape (num_triplets,)) – Indices of input edge b->a of each triplet b->a<-c
id3_ca (torch.Tensor, shape (num_triplets,)) – Indices of output edge c->a of each triplet b->a<-c
id3_ragged_idx (torch.Tensor, shape (num_triplets,)) – Indices enumerating the copies of id3_ca for creating a padded matrix
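The triplet construction described above can be sketched in plain Python, assuming `edge_index = (sources, targets)` so that edge e runs from `sources[e]` to `targets[e]`. `get_triplets_sketch` and `to_padded` are hypothetical helpers, not fairchem's tensor implementation, and the triplet ordering may differ from the library's. Edges are compared by edge id rather than by atom id, which is what allows b=c for distinct edges.

```python
from typing import List, Sequence, Tuple


def get_triplets_sketch(
    edge_index: Tuple[Sequence[int], Sequence[int]], num_atoms: int
) -> Tuple[List[int], List[int], List[int]]:
    """For every edge c->a, list every other edge b->a into the same atom a."""
    _, targets = edge_index
    # Group edge ids by their target atom.
    incoming: List[List[int]] = [[] for _ in range(num_atoms)]
    for e, a in enumerate(targets):
        incoming[a].append(e)

    id3_ba, id3_ca, id3_ragged_idx = [], [], []
    for e_ca, a in enumerate(targets):
        k = 0  # position of this triplet among the copies of e_ca
        for e_ba in incoming[a]:
            if e_ba == e_ca:  # b->a must be a different edge (b == c is allowed)
                continue
            id3_ba.append(e_ba)
            id3_ca.append(e_ca)
            id3_ragged_idx.append(k)
            k += 1
    return id3_ba, id3_ca, id3_ragged_idx


def to_padded(id3_ba, id3_ca, id3_ragged_idx, num_edges, fill=-1):
    """Scatter b->a edge ids into a (num_edges x max_triplets) matrix."""
    width = max(id3_ragged_idx, default=-1) + 1
    padded = [[fill] * width for _ in range(num_edges)]
    for ba, ca, k in zip(id3_ba, id3_ca, id3_ragged_idx):
        # id3_ragged_idx supplies the column for each copy of id3_ca.
        padded[ca][k] = ba
    return padded
```

For three edges 1->0, 2->0 and 3->0, each edge pairs with the other two, so every row of the padded matrix holds two b->a edge ids; an edge with no partner produces a row of padding only.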
- select_symmetric_edges(tensor: torch.Tensor, mask: torch.Tensor, reorder_idx: torch.Tensor, inverse_neg) → torch.Tensor
- reorder_symmetric_edges(edge_index, cell_offsets, neighbors, edge_dist, edge_vector)
Reorder edges to make finding counter-directional edges easier.
Because every atom has a maximum number of neighbors, some edges are present in only one direction in the data. Since only i->j edges are used here, making the graph symmetric loses some j->i edges and adds others. This could be fixed by merging edge_index with its counter-edges, including the cell_offsets, and then running torch.unique, but that does not seem worth it.
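The symmetrization idea above can be sketched in plain Python, assuming edges are given as `(source, target, cell_offset, edge_vector)` tuples and that the kept direction is source < target (for self-edges between periodic images, the copy with a lexicographically negative cell offset). `symmetrize_edges` is a hypothetical helper, not fairchem's `reorder_symmetric_edges`, which operates on batched tensors and also reorders edges per image.

```python
from typing import List, Tuple

Vec3 = Tuple[float, float, float]
Edge = Tuple[int, int, Tuple[int, int, int], Vec3]  # (src, tgt, cell_offset, edge_vector)


def symmetrize_edges(edges: List[Edge]) -> Tuple[List[Edge], List[int]]:
    """Keep one canonical direction per edge, then append its reversed copy.

    Edge k and edge k + n (n = number of kept edges) are counter-directional,
    so the counter-edge index is a fixed shift.
    """
    kept = []
    for src, tgt, offset, vec in edges:
        # Canonical direction: src < tgt; for self-edges between periodic
        # images, keep the copy whose cell offset is lexicographically negative.
        if src < tgt or (src == tgt and tuple(offset) < (0, 0, 0)):
            kept.append((src, tgt, tuple(offset), tuple(vec)))
    # Reversed copies swap endpoints and negate both the offset and the vector.
    flipped = [
        (tgt, src, tuple(-o for o in offset), tuple(-v for v in vec))
        for src, tgt, offset, vec in kept
    ]
    n = len(kept)
    id_swap = [(k + n) % (2 * n) for k in range(2 * n)]
    return kept + flipped, id_swap
```

With input edges 0->1, 1->2, 2->1 and 2->0, the output contains 0->1, 1->2, 1->0 and 2->1: the one-directional edge 2->0 is lost and 1->0 is added, which is exactly the trade-off the note above accepts.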
- select_edges(data, edge_index, cell_offsets, neighbors, edge_dist, edge_vector, cutoff=None)
- generate_interaction_graph(data)
- forward(data)
- property num_params