core.common.relaxation.optimizable
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Code based on ase.optimize
Attributes
Classes
- OptimizableBatch: A batch version of ASE Optimizable Atoms.
- OptimizableUnitCellBatch: Modify the supercell and the atom positions in relaxations.
Functions
- compare_batches: Compare properties between two batches.
Module Contents
- class core.common.relaxation.optimizable.Optimizable
- core.common.relaxation.optimizable.ALL_CHANGES: set[str]
- core.common.relaxation.optimizable.compare_batches(batch1: torch_geometric.data.Batch | None, batch2: torch_geometric.data.Batch, tol: float = 1e-06, excluded_properties: set[str] | None = None) -> list[str]
Compare properties between two batches.
- Parameters:
batch1 – atoms batch
batch2 – atoms batch
tol – tolerance used to compare the equality of floating-point properties
excluded_properties – set of properties to exclude from the comparison
- Returns:
list of system changes, i.e. the names of the properties that differ between batch1 and batch2
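To make the comparison semantics concrete, here is a minimal numpy sketch that compares two batches represented as plain dicts of arrays. It is not the fairchem implementation. The dict representation, the helper name compare_property_dicts, the placeholder contents of ALL_CHANGES, and reporting every change when batch1 is None are all assumptions.

```python
import numpy as np

# Placeholder for ALL_CHANGES (assumption): the real module defines its own set.
ALL_CHANGES = {"pos", "cell", "natoms", "pbc"}


def compare_property_dicts(batch1, batch2, tol=1e-6, excluded_properties=None):
    """Return the sorted names of properties that differ between two batches.

    Hypothetical helper: batches are dicts mapping property names to arrays.
    """
    if batch1 is None:
        # Assumed behaviour: without a previous batch, everything has changed.
        return sorted(ALL_CHANGES)
    excluded = excluded_properties or set()
    changes = []
    for name in sorted(set(batch1) | set(batch2)):
        if name in excluded:
            continue
        if name not in batch1 or name not in batch2:
            changes.append(name)  # property added or removed
            continue
        a, b = np.asarray(batch1[name]), np.asarray(batch2[name])
        if a.shape != b.shape:
            changes.append(name)
        elif np.issubdtype(a.dtype, np.floating) or np.issubdtype(b.dtype, np.floating):
            # Floating-point properties count as equal if they agree within tol.
            if not np.allclose(a, b, rtol=0.0, atol=tol):
                changes.append(name)
        elif not np.array_equal(a, b):
            # Integer/boolean properties must match exactly.
            changes.append(name)
    return changes


old = {"pos": np.zeros((2, 3)), "cell": np.eye(3)[None], "natoms": np.array([2])}
new = {"pos": np.zeros((2, 3)) + 1e-8, "cell": 2 * np.eye(3)[None], "natoms": np.array([2])}
print(compare_property_dicts(old, new))  # ['cell']
```

The tiny position shift (1e-8) stays below the default tolerance, so only the changed cell is reported.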
- class core.common.relaxation.optimizable.OptimizableBatch(batch: torch_geometric.data.Batch, trainer: fairchem.core.trainers.BaseTrainer, transform: torch.nn.Module | None = None, mask_converged: bool = True, numpy: bool = False, masked_eps: float = 1e-08)
Bases: ase.optimize.optimize.Optimizable
A batch version of ASE Optimizable Atoms.
This class can be used with the ML relaxations in fairchem.core.relaxations.ml_relaxation or with ASE relaxation classes, e.g. ase.optimize.lbfgs.
- ignored_changes: ClassVar[set[str]]
- batch
- trainer
- transform
- numpy
- mask_converged
- _cached_batch = None
- _update_mask = None
- torch_results
- results
- _eps
- otf_graph = True
- property device
- property batch_indices
Get the batch indices specifying which system in the batch each position/force row belongs to.
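The batch indices are the usual PyG-style vector with one entry per atom row. The sketch below assumes that the atoms of all systems are stored back to back and that natoms holds the per-system atom counts. Under that assumption the vector can be built as follows, and for tensors torch.repeat_interleave plays the role of np.repeat.

```python
import numpy as np


def batch_indices_from_natoms(natoms):
    """Map each atom row to the index of the system it belongs to.

    Sketch assuming atoms are concatenated system by system and natoms[i]
    is the number of atoms in system i.
    """
    natoms = np.asarray(natoms)
    return np.repeat(np.arange(len(natoms)), natoms)


print(batch_indices_from_natoms([2, 3]))  # [0 0 1 1 1]
```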
- property converged_mask
- property update_mask
- check_state(batch: torch_geometric.data.Batch, tol: float = 1e-12) -> bool
Check for any system changes since the last calculation.
- _predict() -> None
Run a prediction if the batch has any changes.
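The check_state and _predict pair acts as a cache: the model is rerun only when the batch differs from the last evaluated one. Below is a minimal numpy sketch of that pattern. The class name CachedPredictor, the model callable, and the choice to compare positions only (instead of all batch properties) are assumptions made for illustration.

```python
import numpy as np


class CachedPredictor:
    """Hypothetical illustration of the check_state/_predict caching pattern."""

    def __init__(self, model, tol=1e-12):
        self.model = model  # any callable mapping positions -> results dict
        self.tol = tol
        self._cached_positions = None
        self.results = None
        self.n_calls = 0

    def check_state(self, positions):
        """Return True if positions changed since the last prediction."""
        if self._cached_positions is None:
            return True
        if positions.shape != self._cached_positions.shape:
            return True
        return not np.allclose(positions, self._cached_positions, rtol=0.0, atol=self.tol)

    def predict(self, positions):
        """Run the model only when check_state reports a change."""
        positions = np.asarray(positions, dtype=float)
        if self.check_state(positions):
            self.results = self.model(positions)
            self._cached_positions = positions.copy()
            self.n_calls += 1
        return self.results


predictor = CachedPredictor(lambda pos: {"energy": float(np.sum(pos**2))})
pos = np.ones((2, 3))
predictor.predict(pos)
predictor.predict(pos.copy())  # unchanged input: served from the cache
print(predictor.n_calls)  # 1
```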
- get_property(name, no_numpy: bool = False) -> torch.Tensor | numpy.typing.NDArray
Get a predicted property by name.
- get_positions() -> torch.Tensor | numpy.typing.NDArray
Get the batch positions.
- set_positions(positions: torch.Tensor | numpy.typing.NDArray) -> None
Set the atom positions in the batch.
- get_forces(apply_constraint: bool = False, no_numpy: bool = False) -> torch.Tensor | numpy.typing.NDArray
Get predicted batch forces.
- get_potential_energy(**kwargs) -> torch.Tensor | numpy.typing.NDArray
Get the predicted energy as the sum of all batch energies.
- get_potential_energies() -> torch.Tensor | numpy.typing.NDArray
Get the predicted energy for each system in the batch.
- get_cells() -> torch.Tensor
Get the batch crystallographic cells.
- set_cells(cells: torch.Tensor | numpy.typing.NDArray) -> None
Set the batch cells.
- get_volumes() -> torch.Tensor
Get a tensor of volumes for each cell in the batch.
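When cells are stored as rows of lattice vectors, the volume of each cell is the absolute value of its determinant. The following is a batched numpy sketch. The documented method returns a torch.Tensor, and torch.linalg.det would be the tensor counterpart.

```python
import numpy as np


def cell_volumes(cells):
    """Volume of each cell in a (ncells, 3, 3) stack of lattice-vector rows."""
    cells = np.asarray(cells, dtype=float)
    # det is computed per 3x3 matrix in the stack; abs removes the handedness sign.
    return np.abs(np.linalg.det(cells))


cells = np.stack([np.eye(3), np.diag([2.0, 3.0, 4.0])])
print(cell_volumes(cells))  # volumes of a unit cube and a 2 x 3 x 4 box
```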
- iterimages() -> torch_geometric.data.Batch
- get_max_forces(forces: torch.Tensor | None = None, apply_constraint: bool = False) -> torch.Tensor
Get the maximum forces per structure in the batch.
- converged(forces: torch.Tensor | numpy.typing.NDArray | None, fmax: float, max_forces: torch.Tensor | None = None) -> bool
Check whether the norms of all predicted forces are below fmax.
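The per-structure maximum is a scatter-max of the per-atom force norms over the batch indices. The batch counts as converged only when the maximum of every structure is below fmax. The numpy sketch below uses hypothetical helper names, and its strict `<` comparison is an assumption.

```python
import numpy as np


def max_forces_per_structure(forces, batch_indices, n_structures):
    """Largest per-atom force norm within each structure."""
    norms = np.linalg.norm(forces, axis=1)
    max_forces = np.zeros(n_structures)
    # Unbuffered scatter-max: every repeated index is taken into account.
    np.maximum.at(max_forces, batch_indices, norms)
    return max_forces


def all_converged(max_forces, fmax):
    """True only if every structure's maximum force is below fmax."""
    return bool(np.all(max_forces < fmax))


forces = np.array([[0.03, 0.04, 0.0], [0.0, 0.0, 0.01], [0.3, 0.4, 0.0]])
max_forces = max_forces_per_structure(forces, np.array([0, 0, 1]), 2)
print(all_converged(max_forces, fmax=0.1))  # False: structure 1 has |F| = 0.5
```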
- get_atoms_list() -> list[ase.Atoms]
Get the ASE Atoms objects corresponding to the batch.
- update_graph()
Update the graph if the model does not use otf_graph.
- __len__() -> int
- class core.common.relaxation.optimizable.OptimizableUnitCellBatch(batch: torch_geometric.data.Batch, trainer: fairchem.core.trainers.BaseTrainer, transform: torch.nn.Module | None = None, numpy: bool = False, mask_converged: bool = True, mask: collections.abc.Sequence[bool] | None = None, cell_factor: float | torch.Tensor | None = None, hydrostatic_strain: bool = False, constant_volume: bool = False, scalar_pressure: float = 0.0, masked_eps: float = 1e-08)
Bases: OptimizableBatch
Modify the supercell and the atom positions in relaxations.
Based on the ASE UnitCellFilter, adapted to work on data batches.
- orig_cells
- stress = None
- hydrostatic_strain
- constant_volume
- pressure
- cell_factor
- _batch_trace
- _batch_diag
- property batch_indices
Get the batch indices specifying which system in the batch each position/force row belongs to.
These indices are extended to also cover the augmented (cell) rows of the positions and forces.
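The layout described for set_positions places the atom rows first, followed by three deformation rows per cell. Under that layout, the augmented indices append three copies of each system index after the atom indices. A numpy sketch that assumes this layout:

```python
import numpy as np


def augmented_batch_indices(natoms):
    """Batch indices for a (sum(natoms) + 3 * ncells, 3) augmented array."""
    natoms = np.asarray(natoms)
    systems = np.arange(len(natoms))
    atom_rows = np.repeat(systems, natoms)  # one index per atom
    cell_rows = np.repeat(systems, 3)  # three deformation rows per cell
    return np.concatenate([atom_rows, cell_rows])


print(augmented_batch_indices([2, 1]))  # [0 0 1 0 0 0 1 1 1]
```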
- deform_grad()
Get the cell deformation matrix.
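The deformation gradient F of each cell relates the current cell to the original one through cell = orig_cell @ F.T. ASE's UnitCellFilter computes it as np.linalg.solve(orig_cell, cell).T. Below is a batched numpy sketch of the same formula. Whether fairchem uses exactly this convention is an assumption.

```python
import numpy as np


def deform_grad(orig_cells, cells):
    """Deformation gradient per cell, so that cells == orig_cells @ F^T.

    orig_cells and cells are (ncells, 3, 3) stacks of lattice-vector rows.
    """
    # solve gives X with orig_cells @ X == cells; F is the transpose of X.
    return np.swapaxes(np.linalg.solve(orig_cells, cells), -1, -2)


orig = np.stack([np.eye(3), np.array([[2.0, 0.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 4.0]])])
F_true = np.array([[1.1, 0.0, 0.0], [0.02, 1.0, 0.0], [0.0, 0.0, 0.9]])
cells = orig @ F_true.T  # apply the same strain to both cells
print(np.allclose(deform_grad(orig, cells), F_true))  # True
```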
- get_positions()
Get the positions and the cell deformation gradient.
- set_positions(positions: torch.Tensor | numpy.typing.NDArray)
Set the positions and the cell.
positions has shape (natoms + ncells * 3, 3): the first natoms rows are the positions of the atoms, and the last ncells * 3 rows are the deformation tensor of each cell.
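The packing can be sketched after ASE's UnitCellFilter. In that scheme, the atom rows hold the positions with each system's deformation removed, and each cell contributes its deformation gradient scaled by cell_factor. The batched numpy version below is an assumption modeled on that filter, and the helper names pack_positions and unpack_positions are hypothetical. Unpacking recovers the original positions and cells.

```python
import numpy as np


def _deform_grad(orig_cells, cells):
    # F such that cells == orig_cells @ F^T (ASE UnitCellFilter convention).
    return np.swapaxes(np.linalg.solve(orig_cells, cells), -1, -2)


def pack_positions(positions, cells, orig_cells, natoms, cell_factor):
    """Build the (natoms_total + 3 * ncells, 3) augmented position array."""
    idx = np.repeat(np.arange(len(natoms)), natoms)
    F = _deform_grad(orig_cells, cells)
    cf = np.asarray(cell_factor, dtype=float).reshape(-1, 1, 1)  # scalar or per cell
    # Atom rows: positions with each system's deformation removed (F^-1 @ r).
    atom_rows = np.einsum("nij,nj->ni", np.linalg.inv(F)[idx], positions)
    cell_rows = (cf * F).reshape(-1, 3)  # three rows per cell
    return np.concatenate([atom_rows, cell_rows])


def unpack_positions(packed, orig_cells, natoms, cell_factor):
    """Invert pack_positions and return (positions, cells)."""
    n_total = int(np.sum(natoms))
    idx = np.repeat(np.arange(len(natoms)), natoms)
    cf = np.asarray(cell_factor, dtype=float).reshape(-1, 1, 1)
    F = packed[n_total:].reshape(-1, 3, 3) / cf
    cells = orig_cells @ np.swapaxes(F, -1, -2)
    # Reapply each system's deformation to its atom rows (F @ r).
    positions = np.einsum("nij,nj->ni", F[idx], packed[:n_total])
    return positions, cells


natoms = np.array([2, 1])
orig_cells = np.stack([np.eye(3), 2.0 * np.eye(3)])
cells = orig_cells @ np.diag([1.05, 1.0, 0.98])  # strained cells
positions = np.array([[0.1, 0.2, 0.3], [0.5, 0.5, 0.5], [1.0, 0.4, 0.2]])
packed = pack_positions(positions, cells, orig_cells, natoms, cell_factor=3.0)
print(packed.shape)  # (9, 3)
```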
- get_potential_energy(**kwargs)
Return the potential energy including the enthalpy PV term.
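Given volumes V = |det(cell)| and an external scalar pressure p, the enthalpy of each system is E + pV. The numpy sketch below computes it per system and then sums over the batch. Summing like the base-class get_potential_energy is an assumption, and so is using consistent units for energy and pressure times volume (for example eV and eV/Å^3).

```python
import numpy as np


def enthalpies(energies, cells, pressure):
    """Per-system E + p * V for a (ncells, 3, 3) stack of cells."""
    volumes = np.abs(np.linalg.det(np.asarray(cells, dtype=float)))
    return np.asarray(energies, dtype=float) + pressure * volumes


energies = np.array([-1.0, -2.0])
cells = np.stack([np.eye(3), 2.0 * np.eye(3)])  # volumes 1 and 8
H = enthalpies(energies, cells, pressure=0.1)
print(H.sum())  # batch total (assumed to match what get_potential_energy returns)
```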
- get_forces(apply_constraint: bool = False, no_numpy: bool = False) -> torch.Tensor | numpy.typing.NDArray
Get the forces and the unit cell stress.
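The cell rows of the augmented forces come from the stress. The sketch below follows the recipe of ASE's UnitCellFilter.get_forces for a single system. The virial -V (stress + p I) is mapped through the deformation gradient. It is then optionally reduced to its hydrostatic part, or made traceless for constant volume, and finally divided by cell_factor. The sketch assumes a full 3x3 stress tensor and omits the per-component mask. It is not the fairchem batched implementation.

```python
import numpy as np


def unit_cell_forces(forces, stress, cell, F, pressure=0.0, cell_factor=1.0,
                     hydrostatic_strain=False, constant_volume=False):
    """Return the (natoms + 3, 3) augmented force array for one system."""
    volume = abs(np.linalg.det(cell))
    # Virial including the external scalar pressure term.
    virial = -volume * (stress + pressure * np.eye(3))
    # Chain rule (as in ASE): express forces and virial w.r.t. the filter's DOFs.
    atom_forces = forces @ F
    virial = np.linalg.solve(F, virial.T).T
    if hydrostatic_strain:
        # Keep only the isotropic part: the cell can only scale uniformly.
        virial = np.eye(3) * np.trace(virial) / 3.0
    if constant_volume:
        # Remove the isotropic part so the volume does not change.
        virial = virial - np.eye(3) * np.trace(virial) / 3.0
    return np.concatenate([atom_forces, virial / cell_factor])


forces = np.array([[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]])
stress = -0.01 * np.eye(3)
out = unit_cell_forces(forces, stress, cell=2.0 * np.eye(3), F=np.eye(3))
print(out.shape)  # (5, 3)
```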
- __len__()