core.models.schnet
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes
SchNetWrap – Wrapper around the continuous-filter convolutional neural network SchNet.
Module Contents
- class core.models.schnet.SchNetWrap(use_pbc: bool = True, use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, readout: str = 'add')
Bases: torch_geometric.nn.SchNet, fairchem.core.models.base.GraphModelMixin
Wrapper around the continuous-filter convolutional neural network SchNet from the paper “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions”. Each layer uses an interaction block of the form:
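\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} \left( \exp\left(-\gamma (\mathbf{e}_{j,i} - \mathbf{\mu})^2\right) \right),\]
where \(h_{\mathbf{\Theta}}\) denotes an MLP, \(\mathbf{e}_{j,i}\) denotes the interatomic distance between atoms \(j\) and \(i\), and \(\mathbf{\mu}\) collects the centers of the num_gaussians Gaussians.
- Parameters: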
use_pbc (bool, optional) – If set to True, account for periodic boundary conditions. (default: True)
use_pbc_single (bool, optional) – If set to True, process the PBC graphs in a batch one at a time. (default: False)
regress_forces (bool, optional) – If set to True, predict forces by differentiating the energy with respect to atomic positions. (default: True)
otf_graph (bool, optional) – If set to True, compute graph edges on the fly. (default: False)
hidden_channels (int, optional) – Number of hidden channels. (default: 128)
num_filters (int, optional) – Number of filters to use. (default: 128)
num_interactions (int, optional) – Number of interaction blocks. (default: 6)
num_gaussians (int, optional) – Number of Gaussians \(\mu\). (default: 50)
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)
readout (str, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")
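The interaction formula and the cutoff, num_gaussians and use_pbc parameters can be made concrete with a small pure-Python sketch of a single interaction step. This is not the fairchem or torch_geometric implementation: radius_graph, gaussian_expansion, filter_net and cfconv are hypothetical helpers, the filter network \(h_{\mathbf{\Theta}}\) is reduced to one linear layer followed by tanh, and periodic boundary conditions are approximated by a minimum-image convention in an orthorhombic box (the hypothetical box argument), which only holds when the cutoff is at most half of each box length. Capping each atom at max_neighbors nearest neighbors is likewise an assumed reading of the max_neighbors attribute. The Gaussian width γ = 0.5 / Δμ², with Δμ the spacing between centers, is the spacing-based choice we understand torch_geometric's Gaussian smearing to use.

```python
import math
from typing import List, Optional, Sequence, Tuple

# Hypothetical helpers for illustration only; not the fairchem/torch_geometric API.
Vec = Sequence[float]
Edge = Tuple[int, int, float]  # (source j, target i, distance e_ji)


def min_image_delta(a: Vec, b: Vec, box: Optional[Vec]) -> List[float]:
    """Displacement b - a; with a box, wrap each component to the nearest image."""
    delta = [bk - ak for ak, bk in zip(a, b)]
    if box is not None:
        delta = [dk - length * round(dk / length) for dk, length in zip(delta, box)]
    return delta


def radius_graph(pos: Sequence[Vec], cutoff: float = 10.0, max_neighbors: int = 50,
                 box: Optional[Vec] = None) -> List[Edge]:
    """Edges j -> i with distance < cutoff, keeping the max_neighbors nearest j per i."""
    edges: List[Edge] = []
    for i, pi in enumerate(pos):
        # Sorted by distance, so slicing keeps the nearest candidates.
        candidates = sorted(
            (math.hypot(*min_image_delta(pi, pj, box)), j)
            for j, pj in enumerate(pos)
            if j != i
        )
        edges.extend((j, i, d) for d, j in candidates[:max_neighbors] if d < cutoff)
    return edges


def gaussian_expansion(distance: float, cutoff: float = 10.0,
                       num_gaussians: int = 50) -> List[float]:
    """exp(-gamma * (d - mu_k)**2) for centers mu_k evenly spaced on [0, cutoff]."""
    if num_gaussians < 2:
        raise ValueError("num_gaussians must be at least 2")
    step = cutoff / (num_gaussians - 1)
    gamma = 0.5 / step**2  # width tied to the spacing of the centers
    return [math.exp(-gamma * (distance - k * step) ** 2) for k in range(num_gaussians)]


def filter_net(rbf: Sequence[float], weights: Sequence[Sequence[float]]) -> List[float]:
    """Stand-in for h_Theta: one linear layer (one weight row per channel) + tanh."""
    return [math.tanh(sum(w * r for w, r in zip(row, rbf))) for row in weights]


def cfconv(x: Sequence[Sequence[float]], edges: Sequence[Edge],
           weights: Sequence[Sequence[float]], cutoff: float = 10.0,
           num_gaussians: int = 50) -> List[List[float]]:
    """x'_i = sum over edges (j -> i) of x_j * h(rbf(e_ji)), elementwise per channel."""
    out = [[0.0] * len(row) for row in x]
    for j, i, d in edges:
        w = filter_net(gaussian_expansion(d, cutoff, num_gaussians), weights)
        out[i] = [o + xj * wk for o, xj, wk in zip(out[i], x[j], w)]
    return out


if __name__ == "__main__":
    pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (9.5, 0.0, 0.0)]
    x = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]  # 3 atoms, 2 channels
    weights = [[0.01] * 50, [-0.02] * 50]       # 2 channels x 50 Gaussians
    edges = radius_graph(pos, cutoff=2.0, box=(10.0, 10.0, 10.0))
    print(cfconv(x, edges, weights, cutoff=2.0))
```

A full SchNet interaction block also applies dense layers before and after this convolution (mapping between hidden_channels and num_filters), and torch_geometric additionally scales the filter by a smooth cosine cutoff; both are left out to keep the sketch short.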
- num_targets = 1
- regress_forces
- use_pbc
- use_pbc_single
- cutoff
- otf_graph
- max_neighbors = 50
- reduce
- _forward(data)
- forward(data)
Forward pass.
- Parameters:
z (torch.Tensor) – Atomic number of each atom, with shape [num_atoms].
pos (torch.Tensor) – Coordinates of each atom, with shape [num_atoms, 3].
batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule, with shape [num_atoms]. (default: None)
- property num_params: int
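The readout option and regress_forces can be illustrated together with a toy forward pass. In the hypothetical sketch below, per-atom energies come from a made-up harmonic pair term rather than a trained network, readout pools them per molecule using the batch indices described for forward (mirroring the "add" and "mean" options), and forces are taken as the negative gradient of the pooled energy with respect to pos. The real model obtains that gradient with autograd; central finite differences are used here only so the example stays dependency-free. All function names are illustrative assumptions, not part of the fairchem API.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical helpers for illustration only; not the fairchem API.
Positions = List[List[float]]


def readout(atom_values: Sequence[float], batch: Sequence[int],
            reduce: str = "add") -> Dict[int, float]:
    """Pool per-atom values into one value per molecule, keyed by batch index."""
    totals: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for value, b in zip(atom_values, batch):
        totals[b] = totals.get(b, 0.0) + value
        counts[b] = counts.get(b, 0) + 1
    if reduce == "add":
        return totals
    if reduce == "mean":
        return {b: totals[b] / counts[b] for b in totals}
    raise ValueError(f"unsupported readout: {reduce!r}")


def toy_atom_energies(pos: Positions, r0: float = 1.0) -> List[float]:
    """Made-up per-atom energies: each ordered pair (i, j) adds 0.25 * (r_ij - r0)**2
    to atom i, so every unordered pair contributes 0.5 * (r - r0)**2 in total."""
    energies = [0.0] * len(pos)
    for i, pi in enumerate(pos):
        for j, pj in enumerate(pos):
            if i != j:
                energies[i] += 0.25 * (math.dist(pi, pj) - r0) ** 2
    return energies


def finite_difference_forces(energy_fn: Callable[[Positions], float],
                             pos: Positions, eps: float = 1e-5) -> Positions:
    """F = -dE/dpos by central differences (a stand-in for autograd)."""
    forces: Positions = []
    for a in range(len(pos)):
        row = []
        for k in range(3):
            plus = [p[:] for p in pos]
            minus = [p[:] for p in pos]
            plus[a][k] += eps
            minus[a][k] -= eps
            row.append(-(energy_fn(plus) - energy_fn(minus)) / (2 * eps))
        forces.append(row)
    return forces


def molecule_energy(pos: Positions) -> float:
    """Single-molecule energy: "add" readout with every atom in batch 0."""
    return readout(toy_atom_energies(pos), [0] * len(pos))[0]


if __name__ == "__main__":
    print(readout([0.1, 0.2, 0.3], batch=[0, 0, 1], reduce="mean"))
    pos = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
    print(molecule_energy(pos), finite_difference_forces(molecule_energy, pos))
```

Switching from "add" to "mean" divides each molecule's total by its atom count, which makes the pooled value independent of system size; "add" keeps the extensive total, which is how an energy normally behaves.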