core.models.schnet

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

SchNetWrap

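Wrapper around the continuous-filter convolutional neural network SchNet from the paper “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions”.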

Module Contents

class core.models.schnet.SchNetWrap(use_pbc: bool = True, use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, readout: str = 'add')

Bases: torch_geometric.nn.SchNet, fairchem.core.models.base.GraphModelMixin

Wrapper around the continuous-filter convolutional neural network SchNet from the paper “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions”. Each layer uses an interaction block of the form:

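\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} \left( \exp\left(-\gamma (\mathbf{e}_{j,i} - \mathbf{\mu})^2\right) \right),\]

where \(h_{\mathbf{\Theta}}\) denotes an MLP and \(\mathbf{e}_{j,i}\) denotes the interatomic distance between atoms \(j\) and \(i\).
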
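The interaction block can be illustrated with a minimal pure-Python sketch of one continuous-filter convolution. The sketch assumes that the Gaussian centers \(\mu_k\) are evenly spaced on [0, cutoff] and that \(\gamma = 0.5 / \Delta\mu^2\), in the spirit of PyTorch Geometric's Gaussian smearing. The MLP \(h_{\mathbf{\Theta}}\) is replaced by `toy_filter`, a hypothetical single linear map. None of the helper names (`gaussian_expand`, `toy_filter`, `cfconv`) are fairchem or PyG APIs.

```python
import math
from typing import Dict, List, Tuple


def gaussian_expand(d: float, cutoff: float = 10.0, num_gaussians: int = 50) -> List[float]:
    """Expand distance d into exp(-gamma * (d - mu_k)^2) for evenly spaced centers mu_k."""
    spacing = cutoff / (num_gaussians - 1)  # gap between neighboring centers on [0, cutoff]
    gamma = 0.5 / spacing ** 2  # assumption: width tied to the center spacing
    return [math.exp(-gamma * (d - k * spacing) ** 2) for k in range(num_gaussians)]


def toy_filter(rbf: List[float], weights: List[List[float]]) -> List[float]:
    """Hypothetical stand-in for the MLP h_Theta: one linear map from Gaussians to channels."""
    return [sum(w * r for w, r in zip(row, rbf)) for row in weights]


def cfconv(x: Dict[int, List[float]], edges: List[Tuple[int, int, float]],
           weights: List[List[float]], cutoff: float, num_gaussians: int) -> Dict[int, List[float]]:
    """x'_i = sum over edges (j, i, e_ji) of x_j * h(expand(e_ji)), elementwise."""
    out = {i: [0.0] * len(feat) for i, feat in x.items()}
    for j, i, dist in edges:  # message from neighbor j to center atom i
        filt = toy_filter(gaussian_expand(dist, cutoff, num_gaussians), weights)
        out[i] = [acc + xj * fk for acc, xj, fk in zip(out[i], x[j], filt)]
    return out


# Usage: two atoms 1.5 apart, two feature channels, six Gaussians on [0, 5].
x = {0: [1.0, 2.0], 1: [0.5, -1.0]}
edges = [(1, 0, 1.5), (0, 1, 1.5)]
weights = [[1.0] * 6, [0.0] * 6]  # channel 0 sums the Gaussians, channel 1 is zeroed out
out = cfconv(x, edges, weights, cutoff=5.0, num_gaussians=6)
```

The filter depends only on the distance, so the same learned function weights every neighbor. The elementwise product then gates the neighbor features channel by channel.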
Parameters:
  • use_pbc (bool, optional) – If set to True, account for periodic boundary conditions. (default: True)

  • use_pbc_single (bool, optional) – If set to True, process the PBC graphs in a batch one at a time. (default: False)

  • regress_forces (bool, optional) – If set to True, predict forces by differentiating energy with respect to positions. (default: True)

  • otf_graph (bool, optional) – If set to True, compute graph edges on the fly. (default: False)

  • hidden_channels (int, optional) – Number of hidden channels. (default: 128)

  • num_filters (int, optional) – Number of filters to use. (default: 128)

  • num_interactions (int, optional) – Number of interaction blocks. (default: 6)

  • num_gaussians (int, optional) – Number of Gaussian centers \(\mu\) used to expand interatomic distances. (default: 50)

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)

  • readout (str, optional) – Global aggregation to apply, either "add" or "mean". (default: "add")
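When regress_forces is True, forces come from differentiating the predicted energy with respect to positions: \(\mathbf{F}_i = -\partial E / \partial \mathbf{r}_i\). The sketch below shows that relationship on a hypothetical toy pair energy, `pair_energy`. It is not the SchNet energy. The sketch also uses central finite differences where a PyTorch model would use automatic differentiation.

```python
import math
from typing import Callable, List

Vec = List[float]


def pair_energy(pos: List[Vec]) -> float:
    """Hypothetical toy energy: sum of (r_ab - 1)^2 over all atom pairs."""
    e = 0.0
    for a in range(len(pos)):
        for b in range(a + 1, len(pos)):
            e += (math.dist(pos[a], pos[b]) - 1.0) ** 2
    return e


def forces_fd(energy: Callable[[List[Vec]], float], pos: List[Vec], h: float = 1e-5) -> List[Vec]:
    """F = -dE/dpos, estimated per atom and coordinate with central finite differences."""
    forces = []
    for a in range(len(pos)):
        f_atom = []
        for k in range(3):
            plus = [p[:] for p in pos]   # copy positions, nudge atom a along axis k by +h
            minus = [p[:] for p in pos]  # copy positions, nudge atom a along axis k by -h
            plus[a][k] += h
            minus[a][k] -= h
            f_atom.append(-(energy(plus) - energy(minus)) / (2 * h))
        forces.append(f_atom)
    return forces


# Usage: two atoms 2.0 apart; the toy energy prefers distance 1.0, so they attract.
f = forces_fd(pair_energy, [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
```

Because forces are derived from a single scalar energy, they are conservative by construction, and the forces on an isolated pair sum to zero.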

num_targets = 1
regress_forces
use_pbc
use_pbc_single
cutoff
otf_graph
max_neighbors = 50
reduce
_forward(data)
forward(data)

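Forward pass. The per-atom quantities below are taken from the batched graph data rather than passed as separate arguments.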

Parameters:
  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch index assigning each atom to the molecule it belongs to, with shape [num_atoms]. (default: None)
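The batch vector maps each atom to its molecule, and the readout option determines how per-atom contributions are pooled into one value per molecule. The following is a minimal pure-Python sketch of "add" versus "mean" pooling over batch indices. The function name `readout_pool` is hypothetical, and PyTorch code would use a scatter-style reduction instead.

```python
from typing import List


def readout_pool(atom_values: List[float], batch: List[int], reduce: str = "add") -> List[float]:
    """Pool per-atom scalars into per-molecule values using batch indices."""
    if reduce not in ("add", "mean"):
        raise ValueError(f"unsupported readout: {reduce!r}")
    num_graphs = max(batch) + 1
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, g in zip(atom_values, batch):  # accumulate each atom into its molecule g
        sums[g] += value
        counts[g] += 1
    if reduce == "add":
        return sums
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


# Usage: molecule 0 has two atoms, molecule 1 has three.
values = [1.0, 2.0, 3.0, 4.0, 5.0]
batch = [0, 0, 1, 1, 1]
energies_add = readout_pool(values, batch, "add")
energies_mean = readout_pool(values, batch, "mean")
```

With "add", the molecular value grows with the number of atoms (an extensive quantity such as total energy). With "mean", it does not.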

property num_params: int
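The cutoff and max_neighbors attributes bound each atom's neighborhood when edges are built from positions (presumably the case when otf_graph is True). The following is a minimal, non-periodic sketch of that idea: each atom is connected to the other atoms within cutoff, and at most max_neighbors of the closest ones are kept. The function `radius_graph_capped` is hypothetical and runs in O(N²) time. It ignores periodic boundary conditions and is not fairchem's graph-generation routine.

```python
import math
from typing import List, Tuple


def radius_graph_capped(pos: List[List[float]], cutoff: float,
                        max_neighbors: int) -> List[Tuple[int, int, float]]:
    """Return (j, i, distance) edges from neighbor j to center i, without PBC."""
    edges = []
    for i, pi in enumerate(pos):
        candidates = []
        for j, pj in enumerate(pos):
            if j == i:
                continue
            d = math.dist(pi, pj)
            if d <= cutoff:  # keep only atoms inside the cutoff sphere
                candidates.append((d, j))
        # Keep the max_neighbors closest candidates (ties broken by index).
        for d, j in sorted(candidates)[:max_neighbors]:
            edges.append((j, i, d))
    return edges


# Usage: the fourth atom is farther than the cutoff from every other atom.
pos = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [20.0, 0.0, 0.0]]
capped = radius_graph_capped(pos, cutoff=10.0, max_neighbors=1)
full = radius_graph_capped(pos, cutoff=10.0, max_neighbors=50)
```

Capping the neighbor count keeps memory bounded in dense structures, at the cost of dropping the farthest in-cutoff neighbors.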