core.models.escn.escn_exportable
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
| Class | Description |
|---|---|
| eSCN | Equivariant Spherical Channel Network |
| LayerBlock | Layer block: Perform one layer (message passing and aggregation) of the GNN |
| MessageBlock | Message block: Perform message passing |
| SO2Block | SO(2) Block: Perform SO(2) convolutions for all m (orders) |
| SO2Conv | SO(2) Conv: Perform an SO(2) convolution |
| EdgeBlock | Edge Block: Compute invariant edge representation from edge distances and atomic numbers |
| EnergyBlock | Energy Block: Output block computing the energy |
| ForceBlock | Force Block: Output block computing the per-atom forces |
Module Contents
- class core.models.escn.escn_exportable.eSCN(max_neighbors: int = 300, cutoff: float = 8.0, max_num_elements: int = 100, num_layers: int = 8, lmax: int = 4, mmax: int = 2, sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, num_sphere_samples: int = 128, distance_function: str = 'gaussian', basis_width_scalar: float = 1.0, distance_resolution: float = 0.02, resolution: int | None = None, compile: bool = False, export: bool = False, rescale_grid: bool = False)
Bases: torch.nn.Module, fairchem.core.models.base.GraphModelMixin
Equivariant Spherical Channel Network. Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs
- Parameters:
max_neighbors (int) – Maximum number of neighbors per node when generating the graph
cutoff (float) – Maximum distance between neighboring atoms in Angstroms
max_num_elements (int) – Maximum atomic number
num_layers (int) – Number of layers in the GNN
lmax (int) – Maximum degree of the spherical harmonics (1 to 10)
mmax (int) – Maximum order of the spherical harmonics (0 to lmax)
sphere_channels (int) – Number of spherical channels (one set per resolution)
hidden_channels (int) – Number of hidden units in message passing
num_sphere_samples (int) – Number of samples used to approximate the integral over the sphere in the output blocks
edge_channels (int) – Number of channels for the edge invariant features
distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu") – Basis function used for distances
basis_width_scalar (float) – Scale factor for the width of the distance basis functions
distance_resolution (float) – Distance between distance basis functions in Angstroms
compile (bool) – Whether to apply torch.compile to the forward pass
export (bool) – Whether to use the exportable version of the module
- max_neighbors
- cutoff
- max_num_elements
- num_layers
- num_sphere_samples
- sphere_channels
- edge_channels
- distance_resolution
- lmax
- mmax
- basis_width_scalar
- distance_function
- compile
Compile this Module’s forward using torch.compile(). This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile(). See torch.compile() for details on the arguments for this function.
- export
- rescale_grid
- act
- sphere_embedding
- num_gaussians
- SO3_grid
- layer_blocks
- energy_block
- force_block
- sphere_points
- sphharm_weights: torch.nn.Parameter
- sph_feature_size
- forward_trainable(data: torch_geometric.data.batch.Batch) → dict[str, torch.Tensor]
- forward(pos: torch.Tensor, batch_idx: torch.Tensor, natoms: torch.Tensor, atomic_numbers: torch.Tensor, edge_index: torch.Tensor, edge_distance: torch.Tensor, edge_distance_vec: torch.Tensor) → list[torch.Tensor]
N: number of atoms, B: batch size, E: number of edges
pos: [N, 3] atom positions
batch_idx: [N] batch index of each atom
natoms: [B] number of atoms in each structure of the batch
atomic_numbers: [N] atomic number of each atom
edge_index: [2, E] edges between source and target atoms
edge_distance: [E] Cartesian distance of each edge
edge_distance_vec: [E, 3] direction vector of each edge (includes periodic boundary offsets)
- _init_edge_rot_mat(edge_distance_vec)
- property num_params: int
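The paper title names the key idea: once an edge is rotated so that it points along a fixed axis, the only remaining symmetry is rotation about that axis, so SO(3) convolutions reduce to SO(2) ones. `_init_edge_rot_mat(edge_distance_vec)` builds such per-edge rotations. The pure-Python sketch below constructs one rotation for a single edge vector. It assumes the edge is mapped onto the +y axis and picks the two perpendicular axes from a fixed helper vector; the actual `_init_edge_rot_mat` may choose the axes differently, and the helper names here are hypothetical.

```python
import math

Vector = list[float]
Matrix = list[list[float]]


def normalize(v: Vector) -> Vector:
    """Scale a 3-vector to unit length."""
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]


def cross(a: Vector, b: Vector) -> Vector:
    """Cross product of two 3-vectors."""
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a 3x3 matrix (list of rows) by a 3-vector."""
    return [sum(r * c for r, c in zip(row, v)) for row in m]


def edge_rotation_matrix(edge_vec: Vector) -> Matrix:
    """Hypothetical helper: rotation R (rows) with R @ unit(edge_vec) == (0, 1, 0)."""
    y = normalize(edge_vec)
    # Any helper vector that is not parallel to the edge works.
    helper = [1.0, 0.0, 0.0] if abs(y[0]) < 0.9 else [0.0, 0.0, 1.0]
    x = normalize(cross(y, helper))  # unit vector perpendicular to the edge
    z = cross(x, y)                  # completes a right-handed frame (det = +1)
    return [x, y, z]


R = edge_rotation_matrix([1.0, 2.0, 2.0])
```

Applying `R` to the unit edge direction gives (0, 1, 0) up to rounding. Because any rotation about the y axis leaves that image unchanged, the choice of perpendicular axes is arbitrary, which is the SO(2) freedom the remaining convolutions must respect.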
- class core.models.escn.escn_exportable.LayerBlock(layer_idx: int, sphere_channels: int, hidden_channels: int, edge_channels: int, lmax: int, mmax: int, distance_expansion, max_num_elements: int, SO3_grid: fairchem.core.models.escn.so3_exportable.SO3_Grid, act)
Bases: torch.nn.Module
Layer block: Perform one layer (message passing and aggregation) of the GNN
- Parameters:
layer_idx (int) – Layer number
sphere_channels (int) – Number of spherical channels
hidden_channels (int) – Number of hidden channels used during the SO(2) conv
edge_channels (int) – Size of invariant edge embedding
lmax (int) – degrees (l) for each resolution
mmax (int) – orders (m) for each resolution
distance_expansion (func) – Function used to compute distance embedding
max_num_elements (int) – Maximum number of atomic numbers
SO3_grid (SO3_grid) – Class used to convert spherical harmonic representations to and from the grid
act (function) – Non-linear activation function
- layer_idx
- act
- lmax
- mmax
- sphere_channels
- SO3_grid
- message_block
- fc1_sphere
- fc2_sphere
- fc3_sphere
- forward(x: torch.Tensor, atomic_numbers: torch.Tensor, edge_distance: torch.Tensor, edge_index: torch.Tensor, wigner: torch.Tensor) → torch.Tensor
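After the message block, the layer mixes each atom's features with its incoming message pointwise on the sphere: `fc1_sphere`, `fc2_sphere`, and `fc3_sphere` act on each grid sample produced by `SO3_grid`. The sketch below shows such a pointwise MLP for a single atom in pure Python. It assumes the node grid and message grid are concatenated along the channel axis (so the first layer takes `2 * sphere_channels` inputs), that SiLU is the activation, and that no activation follows the last layer; these are illustrative assumptions, not a transcription of `LayerBlock.forward`.

```python
import math

Grid = list[list[float]]                        # [num_grid_points][channels]
Layer = tuple[list[list[float]], list[float]]   # (weight [out][in], bias [out])


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def linear(vec: list[float], layer: Layer) -> list[float]:
    weight, bias = layer
    return [sum(w * x for w, x in zip(row, vec)) + b for row, b in zip(weight, bias)]


def pointwise_grid_mlp(x_grid: Grid, msg_grid: Grid, layers: list[Layer]) -> Grid:
    """Apply the same small MLP independently at every grid point."""
    out = []
    for x_pt, m_pt in zip(x_grid, msg_grid):
        h = x_pt + m_pt  # concatenate node and message channels -> 2 * C values
        for i, layer in enumerate(layers):
            h = linear(h, layer)
            if i < len(layers) - 1:  # activation between layers, none after the last
                h = [silu(v) for v in h]
        out.append(h)
    return out


# Toy example with C = 2 channels and 3 grid points.
fc1 = ([[0.5, 0.5, 0.5, 0.5], [0.0, 0.0, 0.0, 0.0]], [0.0, 0.0])  # 2C -> C
identity = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])                  # C -> C
x_grid = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
msg_grid = [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
y_grid = pointwise_grid_mlp(x_grid, msg_grid, [fc1, identity, identity])
```

Working on grid samples rather than on spherical harmonic coefficients lets an ordinary elementwise nonlinearity be applied; the result would then be projected back to coefficients.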
- class core.models.escn.escn_exportable.MessageBlock(layer_idx: int, sphere_channels: int, hidden_channels: int, edge_channels: int, lmax: int, mmax: int, distance_expansion, max_num_elements: int, SO3_grid: fairchem.core.models.escn.so3_exportable.SO3_Grid, act)
Bases: torch.nn.Module
Message block: Perform message passing
- Parameters:
layer_idx (int) – Layer number
sphere_channels (int) – Number of spherical channels
hidden_channels (int) – Number of hidden channels used during the SO(2) conv
edge_channels (int) – Size of invariant edge embedding
lmax (int) – degrees (l) for each resolution
mmax (int) – orders (m) for each resolution
distance_expansion (func) – Function used to compute distance embedding
max_num_elements (int) – Maximum number of atomic numbers
SO3_grid (SO3_grid) – Class used to convert spherical harmonic representations to and from the grid
act (function) – Non-linear activation function
- layer_idx
- act
- sphere_channels
- SO3_grid
- lmax
- mmax
- edge_channels
- out_mask
- edge_block
- so2_block_source
- so2_block_target
- forward(x: torch.Tensor, atomic_numbers: torch.Tensor, edge_distance: torch.Tensor, edge_index: torch.Tensor, wigner: torch.Tensor) → torch.Tensor
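Message passing ends with an aggregation step in which each edge's message is added to one of its endpoints. The sketch below is a generic scatter-sum in pure Python. It assumes messages are summed onto the target atom given by `edge_index[1]`, the usual source/target convention but an assumption here, and the helper name is hypothetical.

```python
def aggregate_messages(
    messages: list[list[float]], target_index: list[int], num_nodes: int
) -> list[list[float]]:
    """Sum per-edge messages into their target nodes (a scatter-sum)."""
    num_channels = len(messages[0]) if messages else 0
    out = [[0.0] * num_channels for _ in range(num_nodes)]
    for msg, target in zip(messages, target_index):
        for c, value in enumerate(msg):
            out[target][c] += value
    return out


edge_index = [[1, 2, 0], [0, 0, 1]]               # row 0: source atoms, row 1: target atoms
messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # one message per edge
node_update = aggregate_messages(messages, edge_index[1], num_nodes=3)
# node_update == [[4.0, 6.0], [5.0, 6.0], [0.0, 0.0]]
```

An atom with no incoming edges receives a zero update. In PyTorch the same reduction is commonly written with `Tensor.index_add_` along dimension 0.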
- class core.models.escn.escn_exportable.SO2Block(sphere_channels: int, hidden_channels: int, edge_channels: int, lmax: int, mmax: int, act)
Bases: torch.nn.Module
SO(2) Block: Perform SO(2) convolutions for all m (orders)
- Parameters:
sphere_channels (int) – Number of spherical channels
hidden_channels (int) – Number of hidden channels used during the SO(2) conv
edge_channels (int) – Size of invariant edge embedding
lmax (int) – degrees (l) for each resolution
mmax (int) – orders (m) for each resolution
act (function) – Non-linear activation function
- sphere_channels
- lmax
- mmax
- act
- mappingReduced
- fc1_dist0
- fc1_m0
- fc2_m0
- so2_conv
- forward(x: torch.Tensor, x_edge: torch.Tensor)
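Restricting the convolutions to orders |m| <= mmax reduces the number of coefficients processed per channel. Order m appears in every degree l from |m| to lmax, so it contributes lmax - |m| + 1 coefficients. The pure-Python sketch below counts them. The function name is hypothetical, and the assumption that `mappingReduced` performs the corresponding index bookkeeping in the module is unconfirmed.

```python
def coefficients_per_order(lmax: int, mmax: int) -> dict[int, int]:
    """Number of (l, m) coefficients for each order m with |m| <= mmax."""
    if not 0 <= mmax <= lmax:
        raise ValueError("expected 0 <= mmax <= lmax")
    # Order m appears in every degree l with |m| <= l <= lmax.
    return {m: lmax - abs(m) + 1 for m in range(-mmax, mmax + 1)}


counts = coefficients_per_order(lmax=4, mmax=2)
reduced = sum(counts.values())   # coefficients kept per channel
full = (4 + 1) ** 2              # all (l, m) pairs up to lmax = 4
# counts == {-2: 3, -1: 4, 0: 5, 1: 4, 2: 3}; reduced == 19; full == 25
```

With the defaults lmax = 4 and mmax = 2, 19 of the 25 coefficients per channel remain. Setting mmax = lmax recovers all (lmax + 1)^2 coefficients.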
- class core.models.escn.escn_exportable.SO2Conv(m: int, sphere_channels: int, hidden_channels: int, edge_channels: int, lmax: int, mmax: int, act)
Bases: torch.nn.Module
SO(2) Conv: Perform an SO(2) convolution
- Parameters:
m (int) – Order of the spherical harmonic coefficients
sphere_channels (int) – Number of spherical channels
hidden_channels (int) – Number of hidden channels used during the SO(2) conv
edge_channels (int) – Size of invariant edge embedding
lmax (int) – degrees (l) for each resolution
mmax (int) – orders (m) for each resolution
act (function) – Non-linear activation function
- lmax
- mmax
- sphere_channels
- m
- act
- fc1_dist
- fc1_r
- fc2_r
- fc1_i
- fc2_i
- forward(x_m, x_edge) → torch.Tensor
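For an order m > 0, the coefficients for +m and -m form a pair that a rotation about the edge axis mixes like a 2D rotation by m times the angle. Treating the pair as a complex number and multiplying it by learned weights therefore commutes with that rotation, which makes the operation SO(2)-equivariant. The sketch below uses one scalar channel and fixed weights `w_r`, `w_i` in place of `fc1_r`/`fc1_i`. In the module these are linear layers over channels and, we assume, are also conditioned on the edge embedding `x_edge`; the sign convention is likewise an assumption.

```python
import math


def so2_conv_pair(x_pos: float, x_neg: float, w_r: float, w_i: float) -> tuple[float, float]:
    """Multiply the complex number (x_pos + i*x_neg) by (w_r + i*w_i)."""
    y_pos = w_r * x_pos - w_i * x_neg
    y_neg = w_r * x_neg + w_i * x_pos
    return y_pos, y_neg


def rotate_pair(x_pos: float, x_neg: float, m: int, angle: float) -> tuple[float, float]:
    """Rotate an order-m coefficient pair about the edge axis by `angle` radians."""
    c, s = math.cos(m * angle), math.sin(m * angle)
    return c * x_pos - s * x_neg, s * x_pos + c * x_neg


m, angle = 2, 0.7
x = (0.3, -1.2)
rotate_then_conv = so2_conv_pair(*rotate_pair(*x, m, angle), w_r=0.5, w_i=-0.8)
conv_then_rotate = rotate_pair(*so2_conv_pair(*x, w_r=0.5, w_i=-0.8), m, angle)
# The two results agree up to floating-point rounding.
```

Both operations are multiplications by complex numbers, and complex multiplication is commutative, so the order does not matter. A general 2x2 weight matrix on the pair would not have this property.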
- class core.models.escn.escn_exportable.EdgeBlock(edge_channels, distance_expansion, max_num_elements, act)
Bases: torch.nn.Module
Edge Block: Compute invariant edge representation from edge distances and atomic numbers
- Parameters:
edge_channels (int) – Size of invariant edge embedding
distance_expansion (func) – Function used to compute distance embedding
max_num_elements (int) – Maximum number of atomic numbers
act (function) – Non-linear activation function
- in_channels
- distance_expansion
- act
- edge_channels
- max_num_elements
- fc1_dist
- source_embedding
- target_embedding
- fc1_edge_attr
- forward(edge_distance, source_element, target_element)
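EdgeBlock first expands each scalar `edge_distance` into a vector with `distance_expansion`. With the default `distance_function='gaussian'`, this is a set of Gaussians whose centers lie `distance_resolution` apart and whose widths are scaled by `basis_width_scalar`. The pure-Python sketch below implements a standard Gaussian basis under those assumptions (evenly spaced centers from `start` to `stop`, width equal to `basis_width_scalar` times the spacing); the exact centers and normalization in fairchem may differ. In the module the expansion is presumably then projected by `fc1_dist` and combined with `source_embedding` and `target_embedding` of the two atomic numbers; the sketch stops at the expansion.

```python
import math


def gaussian_smearing(
    distance: float,
    start: float,
    stop: float,
    num_gaussians: int,
    basis_width_scalar: float = 1.0,
) -> list[float]:
    """Expand a scalar distance into num_gaussians Gaussian basis values."""
    spacing = (stop - start) / (num_gaussians - 1)
    centers = [start + k * spacing for k in range(num_gaussians)]
    width = basis_width_scalar * spacing
    return [math.exp(-0.5 * ((distance - mu) / width) ** 2) for mu in centers]


basis = gaussian_smearing(2.0, start=0.0, stop=4.0, num_gaussians=5)
# Centers are 0, 1, 2, 3, 4; the Gaussian centered at 2.0 evaluates to 1.0.
```

A smooth, overlapping basis lets the following linear layer represent smooth functions of distance, which a single raw scalar input would not do as easily.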
- class core.models.escn.escn_exportable.EnergyBlock(num_channels: int, num_sphere_samples: int, act)
Bases: torch.nn.Module
Energy Block: Output block computing the energy
- Parameters:
num_channels (int) – Number of channels
num_sphere_samples (int) – Number of samples used to approximate the integral on the sphere
act (function) – Non-linear activation function
- num_channels
- num_sphere_samples
- act
- fc1
- fc2
- fc3
- forward(x_pt) → torch.Tensor
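EnergyBlock evaluates an MLP (`fc1`, `fc2`, `fc3`) at `num_sphere_samples` points and uses those samples to approximate an integral over the sphere. The sketch below shows that reduction in pure Python. It assumes the per-atom energy is the mean of the scalar outputs over the sample points and that per-structure energies are then summed using `batch_idx`; that summation may happen in `eSCN.forward` rather than in this block, and the function names are hypothetical.

```python
def atom_energy(point_values: list[float]) -> float:
    """Approximate the sphere integral by averaging the scalar output over samples."""
    return sum(point_values) / len(point_values)


def structure_energies(
    atom_energies: list[float], batch_idx: list[int], num_structures: int
) -> list[float]:
    """Sum per-atom energies into the structure each atom belongs to."""
    totals = [0.0] * num_structures
    for energy, b in zip(atom_energies, batch_idx):
        totals[b] += energy
    return totals


# Two sphere samples per atom; atoms 0 and 1 belong to structure 0, atom 2 to structure 1.
per_atom = [atom_energy([1.0, 3.0]), atom_energy([0.0, 1.0]), atom_energy([4.0, 4.0])]
energies = structure_energies(per_atom, batch_idx=[0, 0, 1], num_structures=2)
# per_atom == [2.0, 0.5, 4.0]; energies == [2.5, 4.0]
```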
- class core.models.escn.escn_exportable.ForceBlock(num_channels: int, num_sphere_samples: int, act)
Bases: torch.nn.Module
Force Block: Output block computing the per-atom forces
- Parameters:
num_channels (int) – Number of channels
num_sphere_samples (int) – Number of samples used to approximate the integral on the sphere
act (function) – Non-linear activation function
- num_channels
- num_sphere_samples
- act
- fc1
- fc2
- fc3
- forward(x_pt, sphere_points) → torch.Tensor
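ForceBlock also takes `sphere_points`, which suggests the per-atom force is estimated by weighting each sample direction by the MLP's scalar output at that point and averaging over samples. The sketch below assumes exactly that weighted average with no further normalization, and uses six axis-aligned points as a toy stand-in for the module's `num_sphere_samples` points; the function name is hypothetical.

```python
Vec3 = tuple[float, float, float]


def atom_force(point_values: list[float], sphere_points: list[Vec3]) -> Vec3:
    """Average of value-weighted unit directions over the sphere samples."""
    n = len(point_values)
    return tuple(
        sum(v * p[axis] for v, p in zip(point_values, sphere_points)) / n
        for axis in range(3)
    )


points = [(1.0, 0.0, 0.0), (-1.0, 0.0, 0.0), (0.0, 1.0, 0.0),
          (0.0, -1.0, 0.0), (0.0, 0.0, 1.0), (0.0, 0.0, -1.0)]
isotropic = atom_force([1.0] * 6, points)                     # same value everywhere
pushed_x = atom_force([6.0, 0.0, 0.0, 0.0, 0.0, 0.0], points)  # large only along +x
# isotropic == (0.0, 0.0, 0.0); pushed_x == (1.0, 0.0, 0.0)
```

A direction-independent output yields zero net force, while an output concentrated in one direction yields a force along it. This is how a vector quantity can be read off a scalar function sampled on the sphere.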