
Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.




Module Contents

class core.models.schnet.SchNetWrap(use_pbc: bool = True, use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, readout: str = 'add')

Bases: torch_geometric.nn.SchNet, fairchem.core.models.base.GraphModelMixin

Wrapper around the continuous-filter convolutional neural network SchNet from the paper “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions”. Each layer uses an interaction block of the form:

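\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),\]

where \(h_{\mathbf{\Theta}}\) denotes an MLP and \(\mathbf{e}_{j,i}\) denotes the interatomic distances between atoms.
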
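As a rough illustration of the interaction formula, the following pure-Python sketch expands each interatomic distance on a set of Gaussian centers and sums the filtered neighbor features. It is not the PyTorch Geometric implementation: `gaussian_expansion`, `interaction`, and `toy_filter` are hypothetical names, `toy_filter` merely stands in for the learned MLP, and the squared-distance Gaussian with width 0.5 / spacing² is an assumption made for this sketch.

```python
import math


def gaussian_expansion(distance, cutoff=10.0, num_gaussians=50):
    """Expand a scalar distance on evenly spaced Gaussian centers in [0, cutoff]."""
    spacing = cutoff / (num_gaussians - 1)
    gamma = 0.5 / spacing**2  # assumed width, not taken from the documentation
    centers = [k * spacing for k in range(num_gaussians)]
    return [math.exp(-gamma * (distance - mu) ** 2) for mu in centers]


def interaction(x, neighbors, distances, filter_fn, cutoff=10.0, num_gaussians=50):
    """Compute x'_i = sum over neighbors j of x_j * filter_fn(expansion(e_ji))."""
    out = []
    for i in range(len(x)):
        acc = [0.0] * len(x[i])
        for j in neighbors[i]:
            rbf = gaussian_expansion(distances[(j, i)], cutoff, num_gaussians)
            weights = filter_fn(rbf)  # one weight per feature channel
            acc = [a + xj * w for a, xj, w in zip(acc, x[j], weights)]
        out.append(acc)
    return out


def toy_filter(rbf):
    """Hypothetical stand-in for the learned MLP: maps the expansion to 2 weights."""
    return [max(rbf), sum(rbf) / len(rbf)]


# Two atoms with two feature channels each, 2.5 distance units apart.
x = [[1.0, 2.0], [3.0, 4.0]]
neighbors = {0: [1], 1: [0]}
distances = {(1, 0): 2.5, (0, 1): 2.5}
updated = interaction(x, neighbors, distances, toy_filter, cutoff=10.0, num_gaussians=5)
```

In the real model the filter weights are learned and the sum runs over all neighbors within `cutoff`; here each atom has a single neighbor so the result is easy to check by hand.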
  • use_pbc (bool, optional) – If set to True, account for periodic boundary conditions. (default: True)

  • use_pbc_single (bool, optional) – If set to True, process the periodic graphs in a batch one at a time. (default: False)

  • regress_forces (bool, optional) – If set to True, predict forces by differentiating energy with respect to positions. (default: True)

  • otf_graph (bool, optional) – If set to True, compute graph edges on the fly. (default: False)

  • hidden_channels (int, optional) – Number of hidden channels. (default: 128)

  • num_filters (int, optional) – Number of filters to use. (default: 128)

  • num_interactions (int, optional) – Number of interaction blocks. (default: 6)

  • num_gaussians (int, optional) – Number of Gaussians \(\mu\). (default: 50)

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)

  • readout (str, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")
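The `regress_forces` option describes forces as the derivative of the energy with respect to atomic positions, with a minus sign by the usual convention. The sketch below shows that relationship on a toy energy function. It swaps in central finite differences so that it runs without PyTorch, whereas a neural network model would typically use automatic differentiation; `pair_energy` and `forces_from_energy` are hypothetical helpers, not part of fairchem.

```python
import math


def pair_energy(pos, r0=1.0, k=1.0):
    """Toy energy: a harmonic spring between every pair of atoms (illustrative only)."""
    energy = 0.0
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            r = math.dist(pos[i], pos[j])
            energy += 0.5 * k * (r - r0) ** 2
    return energy


def forces_from_energy(energy_fn, pos, h=1e-5):
    """Approximate F = -dE/dpos with central finite differences."""
    forces = []
    for i in range(len(pos)):
        f_atom = []
        for axis in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[i][axis] += h
            minus[i][axis] -= h
            f_atom.append(-(energy_fn(plus) - energy_fn(minus)) / (2 * h))
        forces.append(f_atom)
    return forces


# A bond stretched beyond its rest length r0 = 1.0 along the x axis.
pos = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
forces = forces_from_energy(pair_energy, pos)
```

Because the bond is stretched, the two atoms are pulled toward each other with equal and opposite forces along x.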

num_targets = 1
max_neighbors = 50

Forward pass.

  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
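To make the argument shapes concrete, the sketch below lays out `z`, `pos`, and `batch` for two molecules packed into one batch, and shows how per-atom values could be aggregated per molecule in the spirit of the "add" and "mean" readout options. Plain Python lists stand in for `torch.Tensor`, the coordinates are approximate and purely illustrative, and `readout` is a hypothetical helper rather than the model's own code.

```python
# Two molecules in one batch: water (O, H, H) followed by H2 (H, H).
z = [8, 1, 1, 1, 1]  # atomic numbers, shape [num_atoms]
pos = [              # coordinates, shape [num_atoms, 3]
    [0.000, 0.000, 0.000],
    [0.757, 0.586, 0.000],
    [-0.757, 0.586, 0.000],
    [5.000, 0.000, 0.000],
    [5.740, 0.000, 0.000],
]
batch = [0, 0, 0, 1, 1]  # molecule index of each atom, shape [num_atoms]


def readout(per_atom, batch, mode="add"):
    """Aggregate per-atom scalars into one value per molecule."""
    num_graphs = max(batch) + 1
    totals = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, graph in zip(per_atom, batch):
        totals[graph] += value
        counts[graph] += 1
    if mode == "add":
        return totals
    if mode == "mean":
        return [t / c for t, c in zip(totals, counts)]
    raise ValueError(f"unknown readout mode: {mode!r}")


# Hypothetical per-atom contributions, one per entry of z.
per_atom = [1.0, 2.0, 3.0, 4.0, 6.0]
summed = readout(per_atom, batch, mode="add")
averaged = readout(per_atom, batch, mode="mean")
```

With "add" the result grows with the number of atoms in a molecule, while "mean" normalizes by molecule size.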

property num_params: int