core.models.scn

Submodules

Classes

SphericalChannelNetwork – Spherical Channel Network

Package Contents

class core.models.scn.SphericalChannelNetwork(use_pbc: bool = True, use_pbc_single: bool = True, regress_forces: bool = True, otf_graph: bool = False, max_num_neighbors: int = 20, cutoff: float = 8.0, max_num_elements: int = 90, num_interactions: int = 8, lmax: int = 6, mmax: int = 1, num_resolutions: int = 2, sphere_channels: int = 128, sphere_channels_reduce: int = 128, hidden_channels: int = 256, num_taps: int = -1, use_grid: bool = True, num_bands: int = 1, num_sphere_samples: int = 128, num_basis_functions: int = 128, distance_function: str = 'gaussian', basis_width_scalar: float = 1.0, distance_resolution: float = 0.02, show_timing_info: bool = False, direct_forces: bool = True)

Bases: torch.nn.Module, fairchem.core.models.base.GraphModelMixin

Spherical Channel Network

Paper: Spherical Channels for Modeling Atomic Interactions

Parameters:
  • use_pbc (bool) – Use periodic boundary conditions

  • use_pbc_single (bool) – Process batch PBC graphs one at a time

  • regress_forces (bool) – Compute forces

  • otf_graph (bool) – Compute graph On The Fly (OTF)

  • max_num_neighbors (int) – Maximum number of neighbors per atom

  • cutoff (float) – Maximum distance between neighboring atoms in Angstroms

  • max_num_elements (int) – Maximum atomic number

  • num_interactions (int) – Number of layers in the GNN

  • lmax (int) – Maximum degree of the spherical harmonics (1 to 10)

  • mmax (int) – Maximum order of the spherical harmonics (0 or 1)

  • num_resolutions (int) – Number of resolutions used to compute messages; atoms farther away use a lower resolution (1 or 2)

  • sphere_channels (int) – Number of spherical channels

  • sphere_channels_reduce (int) – Number of spherical channels used during message passing (downsample or upsample)

  • hidden_channels (int) – Number of hidden units in message passing

  • num_taps (int) – Number of taps or rotations used during message passing (1 or otherwise set automatically based on mmax)

  • use_grid (bool) – Use non-linear pointwise convolution during aggregation

  • num_bands (int) – Number of bands used during message aggregation for the 1x1 pointwise convolution (1 or 2)

  • num_sphere_samples (int) – Number of samples used to approximate integration over the sphere in the output blocks

  • num_basis_functions (int) – Number of basis functions used for distance and atomic number blocks

  • distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu") – Basis function used for distances

  • basis_width_scalar (float) – Scaling factor for the width of the distance basis functions

  • distance_resolution (float) – Spacing between adjacent distance basis functions in Angstroms

  • show_timing_info (bool) – Show timing and memory info
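The distance parameters can be illustrated with a small Gaussian basis expansion. The sketch below is a minimal pure-Python version for distance_function="gaussian", and it rests on assumptions this page does not confirm: the number of Gaussians is int(cutoff / distance_resolution), their centers are evenly spaced from 0 to cutoff, and basis_width_scalar multiplies the spacing between centers to set each Gaussian's width. The helper name gaussian_basis is hypothetical.

```python
import math
from typing import List


def gaussian_basis(
    distance: float,
    cutoff: float = 8.0,
    distance_resolution: float = 0.02,
    basis_width_scalar: float = 1.0,
) -> List[float]:
    """Expand one distance (in Angstroms) into evenly spaced Gaussian features.

    Hypothetical helper: the number of Gaussians, their placement from 0 to
    cutoff, and width = basis_width_scalar * spacing are assumptions.
    """
    num_gaussians = int(cutoff / distance_resolution)
    if num_gaussians < 2:
        raise ValueError("cutoff / distance_resolution must give at least 2 Gaussians")
    spacing = cutoff / (num_gaussians - 1)  # gap between neighboring centers
    width = basis_width_scalar * spacing    # standard deviation of each Gaussian
    coeff = -0.5 / width**2
    centers = [i * spacing for i in range(num_gaussians)]
    return [math.exp(coeff * (distance - c) ** 2) for c in centers]


features = gaussian_basis(2.5)  # default cutoff and resolution
print(len(features), max(features))
```

Under these assumptions, halving distance_resolution doubles the number of features, while basis_width_scalar changes only how much neighboring Gaussians overlap.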

energy_fc1: torch.nn.Linear
energy_fc2: torch.nn.Linear
energy_fc3: torch.nn.Linear
force_fc1: torch.nn.Linear
force_fc2: torch.nn.Linear
force_fc3: torch.nn.Linear
regress_forces
use_pbc
use_pbc_single
cutoff
otf_graph
show_timing_info
max_num_elements
hidden_channels
num_interactions
num_atoms = 0
num_sphere_samples
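num_sphere_samples sets how many points approximate integration over the sphere in the output blocks. As a hedged illustration, the sketch below spreads points over the unit sphere with a Fibonacci (golden-ratio spiral) lattice and averages a function over them with equal weights. Both the lattice and the equal weighting are assumptions chosen for simplicity, and fibonacci_sphere and sphere_mean are hypothetical names.

```python
import math
from typing import Callable, List, Tuple

Point = Tuple[float, float, float]


def fibonacci_sphere(num_points: int) -> List[Point]:
    """Return num_points roughly evenly spread unit vectors (Fibonacci lattice)."""
    golden_ratio = (1 + 5**0.5) / 2
    points = []
    for i in range(num_points):
        # Azimuth steps by an irrational fraction of a turn, so points never line up.
        theta = 2 * math.pi * i / golden_ratio
        # z = cos(phi) sits at the midpoint of num_points equal-area bands.
        phi = math.acos(1 - 2 * (i + 0.5) / num_points)
        points.append((math.cos(theta) * math.sin(phi),
                       math.sin(theta) * math.sin(phi),
                       math.cos(phi)))
    return points


def sphere_mean(f: Callable[[Point], float], points: List[Point]) -> float:
    """Approximate the average of f over the unit sphere with equal weights."""
    return sum(f(p) for p in points) / len(points)


pts = fibonacci_sphere(128)
# The exact spherical average of z**2 is 1/3.
print(sphere_mean(lambda p: p[2] ** 2, pts))
```

Because z is sampled at the midpoints of N equal bands, the estimate of the average of z**2 is low by exactly 1/(3N**2) in exact arithmetic, about 2e-5 for 128 samples.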
sphere_channels
sphere_channels_reduce
num_basis_functions
distance_resolution
grad_forces = False
lmax
mmax
basis_width_scalar
sphere_basis
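Each degree l = 0..lmax has 2l + 1 spherical-harmonic orders, giving (lmax + 1)**2 coefficients per channel. Assuming that sphere_basis holds this total and that mmax restricts message passing to orders with |m| ≤ mmax (both readings are inferred from the parameter descriptions, not stated on this page), the counts work out as in the sketch below; coefficient_counts is a hypothetical helper.

```python
from typing import Tuple


def coefficient_counts(lmax: int, mmax: int) -> Tuple[int, int]:
    """Return (all coefficients, coefficients with |m| <= mmax) for degrees 0..lmax."""
    if not 0 <= mmax <= lmax:
        raise ValueError("expected 0 <= mmax <= lmax")
    # Degree l has orders m = -l..l, i.e. 2l + 1 of them; the sum equals (lmax + 1) ** 2.
    full = sum(2 * l + 1 for l in range(lmax + 1))
    # Keeping only |m| <= mmax leaves orders -min(l, mmax)..min(l, mmax) per degree.
    reduced = sum(2 * min(l, mmax) + 1 for l in range(lmax + 1))
    return full, reduced


print(coefficient_counts(6, 1))  # (49, 19)
```

With the defaults lmax=6 and mmax=1, 19 of the 49 coefficients per channel satisfy |m| ≤ 1.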
use_grid
distance_function
counter = 0
act
sphere_embedding
num_gaussians
sphharm_list = []
edge_blocks
forward(data)
_forward_helper(data)
_init_edge_rot_mat(data, edge_index, edge_distance_vec)
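Judging by its name and arguments, _init_edge_rot_mat builds a rotation matrix for each edge from edge_distance_vec, so that messages can be computed in a frame aligned with the edge. The sketch below constructs one such matrix for a single edge vector in pure Python. Mapping the edge direction onto the y-axis and choosing the least-aligned coordinate axis as the helper vector are assumptions for illustration, and edge_rotation_matrix and matvec are hypothetical names.

```python
import math
from typing import List, Sequence

Vec = List[float]
Mat = List[Vec]


def _normalize(v: Sequence[float]) -> Vec:
    n = math.sqrt(sum(c * c for c in v))
    if n == 0.0:
        raise ValueError("edge vector must be non-zero")
    return [c / n for c in v]


def _cross(a: Sequence[float], b: Sequence[float]) -> Vec:
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def edge_rotation_matrix(edge_vec: Sequence[float]) -> Mat:
    """Rotation R (list of rows) with R @ (edge_vec / |edge_vec|) == (0, 1, 0)."""
    d = _normalize(edge_vec)
    # Helper: the coordinate axis least aligned with d, so the cross product is well conditioned.
    k = min(range(3), key=lambda i: abs(d[i]))
    helper = [1.0 if i == k else 0.0 for i in range(3)]
    a = _normalize(_cross(d, helper))  # unit vector perpendicular to d
    b = _cross(a, d)                   # unit vector completing a right-handed frame (a, d, b)
    return [a, d, b]                   # rows: new x, y, z axes in old coordinates


def matvec(m: Mat, v: Sequence[float]) -> Vec:
    return [sum(r[i] * v[i] for i in range(3)) for r in m]


R = edge_rotation_matrix([1.0, 2.0, 2.0])
print(matvec(R, [1 / 3, 2 / 3, 2 / 3]))  # approximately [0.0, 1.0, 0.0]
```

Any helper vector that is not parallel to the edge would work; different choices give matrices that differ only by a spin about the edge axis.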
_rank_edge_distances(edge_distance, edge_index, max_num_neighbors: int) → torch.Tensor
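The return value of _rank_edge_distances is not documented here. One plausible reading, consistent with num_resolutions giving farther atoms a lower resolution, is that it ranks each edge by length among the edges sharing its target atom, with rank 0 for the nearest neighbor. The sketch below implements that interpretation on plain lists; treating edge_index[1] as the target row, keeping ties in their original order, and the name rank_edges_by_distance are all assumptions.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def rank_edges_by_distance(
    edge_distance: Sequence[float],
    edge_index: Sequence[Sequence[int]],
    max_num_neighbors: int,
) -> List[int]:
    """For each edge, return its distance rank among edges with the same target atom."""
    sources, targets = edge_index
    if not len(sources) == len(targets) == len(edge_distance):
        raise ValueError("edge_index rows and edge_distance must have equal length")
    # Group edge ids by their (assumed) target atom, edge_index[1].
    by_target: Dict[int, List[int]] = defaultdict(list)
    for e, t in enumerate(targets):
        by_target[t].append(e)

    ranks = [0] * len(edge_distance)
    for t, edges in by_target.items():
        if len(edges) > max_num_neighbors:
            raise ValueError(f"atom {t} has more than max_num_neighbors edges")
        # sorted() is stable, so equal distances keep their original edge order.
        for rank, e in enumerate(sorted(edges, key=lambda idx: edge_distance[idx])):
            ranks[e] = rank
    return ranks


dist = [1.5, 0.9, 2.0, 1.1, 3.0]
edges = [[1, 2, 3, 0, 0],   # source atoms
         [0, 0, 0, 1, 1]]   # target atoms
print(rank_edges_by_distance(dist, edges, max_num_neighbors=20))  # [1, 0, 2, 0, 1]
```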
property num_params: int