
core.common.relaxation.optimizable#

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Code based on ase.optimize

Attributes#

Classes#

Optimizable

OptimizableBatch

A Batch version of ase Optimizable Atoms

OptimizableUnitCellBatch

Modify the supercell and the atom positions in relaxations.

Functions#

compare_batches(→ list[str])

Compare properties between two batches

Module Contents#

class core.common.relaxation.optimizable.Optimizable#
core.common.relaxation.optimizable.ALL_CHANGES: set[str]#
core.common.relaxation.optimizable.compare_batches(batch1: torch_geometric.data.Batch | None, batch2: torch_geometric.data.Batch, tol: float = 1e-06, excluded_properties: set[str] | None = None) list[str]#

Compare properties between two batches

Parameters:
  • batch1 – atoms batch

  • batch2 – atoms batch

  • tol – tolerance used to compare equality of floating point properties

  • excluded_properties – list of properties to exclude from comparison

Returns:

list of system changes: names of the properties that differ between batch1 and batch2
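The comparison logic can be sketched in plain Python. This is a hypothetical, simplified stand-in for compare_batches: it compares dicts of float lists rather than torch_geometric batches, but mirrors the tolerance test and the excluded-properties filter described above.

```python
def compare_properties(props1, props2, tol=1e-6, excluded=None):
    """Return names of properties that differ between two property dicts.

    Simplified stand-in for compare_batches: real batches hold torch
    tensors; here each property is a flat list of floats.
    """
    excluded = excluded or set()
    changed = []
    for name in props1:
        if name in excluded:
            continue
        a, b = props1[name], props2[name]
        # a property "changed" if lengths differ or any entry moved by more than tol
        if len(a) != len(b) or any(abs(x - y) > tol for x, y in zip(a, b)):
            changed.append(name)
    return changed

before = {"pos": [0.0, 1.0, 2.0], "cell": [10.0, 10.0, 10.0]}
after = {"pos": [0.0, 1.0, 2.5], "cell": [10.0, 10.0, 10.0]}
print(compare_properties(before, after))  # -> ['pos']
```

The returned list plays the role of the "system changes" above: a caller can skip re-running the model when it comes back empty.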

class core.common.relaxation.optimizable.OptimizableBatch(batch: torch_geometric.data.Batch, trainer: fairchem.core.trainers.BaseTrainer, transform: torch.nn.Module | None = None, mask_converged: bool = True, numpy: bool = False, masked_eps: float = 1e-08)#

Bases: ase.optimize.optimize.Optimizable

A Batch version of ase Optimizable Atoms

This class can be used for ML relaxations in fairchem.core.relaxations.ml_relaxation or with ase relaxation classes, e.g. ase.optimize.LBFGS
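For example, an OptimizableBatch can be handed to a standard ase optimizer in place of an Atoms object. The sketch below is illustrative only: it assumes a trained fairchem trainer and a prepared torch_geometric Batch (both constructed elsewhere, not shown) and is not runnable on its own.

```python
from ase.optimize import LBFGS

# `batch` (torch_geometric.data.Batch) and `trainer` (fairchem BaseTrainer)
# are assumed to exist already -- this is a usage sketch, not a recipe.
optimizable = OptimizableBatch(batch, trainer=trainer, numpy=True)
dyn = LBFGS(optimizable)                # any ase optimizer taking an Optimizable
dyn.run(fmax=0.05, steps=100)           # relax until max force per structure < fmax
relaxed = optimizable.get_atoms_list()  # back to a list of ase Atoms
```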

ignored_changes: ClassVar[set[str]]#
batch#
trainer#
transform#
numpy#
mask_converged#
_cached_batch = None#
_update_mask = None#
torch_results#
results#
_eps#
otf_graph = True#
property device#
property batch_indices#

Get the batch indices specifying which position/force corresponds to which batch.

property converged_mask#
property update_mask#
check_state(batch: torch_geometric.data.Batch, tol: float = 1e-12) bool#

Check for any system changes since last calculation.

_predict() None#

Run prediction if batch has any changes.

get_property(name, no_numpy: bool = False) torch.Tensor | numpy.typing.NDArray#

Get a predicted property by name.

get_positions() torch.Tensor | numpy.typing.NDArray#

Get the batch positions.

set_positions(positions: torch.Tensor | numpy.typing.NDArray) None#

Set the atom positions in the batch.

get_forces(apply_constraint: bool = False, no_numpy: bool = False) torch.Tensor | numpy.typing.NDArray#

Get predicted batch forces.

get_potential_energy(**kwargs) torch.Tensor | numpy.typing.NDArray#

Get predicted energy as the sum of all batch energies.

get_potential_energies() torch.Tensor | numpy.typing.NDArray#

Get the predicted energy for each system in batch.

get_cells() torch.Tensor#

Get batch crystallographic cells.

set_cells(cells: torch.Tensor | numpy.typing.NDArray) None#

Set batch cells.

get_volumes() torch.Tensor#

Get a tensor of volumes for each cell in batch.

iterimages() torch_geometric.data.Batch#
get_max_forces(forces: torch.Tensor | None = None, apply_constraint: bool = False) torch.Tensor#

Get the maximum forces per structure in batch.

converged(forces: torch.Tensor | numpy.typing.NDArray | None, fmax: float, max_forces: torch.Tensor | None = None) bool#

Check whether the norms of all predicted forces are below fmax.
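The per-structure convergence test implied by get_max_forces and converged can be sketched in plain Python. This is a simplified stand-in for the tensor implementation: forces are (fx, fy, fz) tuples and batch_idx assigns each atom to a structure, mirroring batch_indices above.

```python
import math

def max_forces_per_structure(forces, batch_idx, nsystems):
    """Maximum force norm per structure (pure-Python scatter-max sketch).

    forces: list of (fx, fy, fz) per atom; batch_idx: structure index per atom.
    """
    fmax = [0.0] * nsystems
    for (fx, fy, fz), i in zip(forces, batch_idx):
        norm = math.sqrt(fx * fx + fy * fy + fz * fz)
        fmax[i] = max(fmax[i], norm)
    return fmax

def is_converged(forces, batch_idx, nsystems, fmax):
    """True only when every structure's maximum force norm is below fmax."""
    return all(f < fmax for f in max_forces_per_structure(forces, batch_idx, nsystems))

forces = [(0.01, 0.0, 0.0), (0.0, 0.02, 0.0), (0.3, 0.0, 0.0)]
batch_idx = [0, 0, 1]  # two atoms in structure 0, one in structure 1
print(is_converged(forces, batch_idx, 2, fmax=0.05))  # -> False: structure 1 has |F| = 0.3
```

This is why a batch only counts as converged when its worst structure is converged; mask_converged lets already-relaxed structures stop moving while the rest continue.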

get_atoms_list() list[ase.Atoms]#

Get ase Atoms objects corresponding to the batch.

update_graph()#

Update the graph if model does not use otf_graph.

__len__() int#
class core.common.relaxation.optimizable.OptimizableUnitCellBatch(batch: torch_geometric.data.Batch, trainer: fairchem.core.trainers.BaseTrainer, transform: torch.nn.Module | None = None, numpy: bool = False, mask_converged: bool = True, mask: collections.abc.Sequence[bool] | None = None, cell_factor: float | torch.Tensor | None = None, hydrostatic_strain: bool = False, constant_volume: bool = False, scalar_pressure: float = 0.0, masked_eps: float = 1e-08)#

Bases: OptimizableBatch

Modify the supercell and the atom positions in relaxations.

Based on ase's UnitCellFilter, adapted to work on data batches

orig_cells#
stress = None#
hydrostatic_strain#
constant_volume#
pressure#
cell_factor#
_batch_trace#
_batch_diag#
property batch_indices#

Get the batch indices specifying which position/force corresponds to which batch.

We augment this to specify the batch indices for augmented positions and forces.

deform_grad()#

Get the cell deformation matrix

get_positions()#

Get positions and cell deformation gradient.

set_positions(positions: torch.Tensor | numpy.typing.NDArray)#

Set positions and cell.

positions has shape (natoms + 3 * nsystems, 3). The first natoms rows are the atomic positions; the last 3 * nsystems rows are the cell deformation tensor for each system.
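The augmented layout can be illustrated with a plain-Python split. This helper is hypothetical (the real method works on torch tensors and also applies the cell_factor scaling), but it shows how the single array decomposes into atomic positions plus one 3x3 deformation block per system.

```python
def split_augmented_positions(positions, natoms):
    """Split a (natoms + 3*nsystems, 3) augmented array into atomic
    positions and a list of per-system 3x3 deformation tensors."""
    atom_rows = positions[:natoms]
    deform_rows = positions[natoms:]
    assert len(deform_rows) % 3 == 0, "trailing rows must come in groups of 3"
    deforms = [deform_rows[i:i + 3] for i in range(0, len(deform_rows), 3)]
    return atom_rows, deforms

# Two atoms and one system: 2 + 3*1 = 5 rows of length 3
aug = [
    [0.0, 0.0, 0.0],  # atom 1 position
    [1.0, 1.0, 1.0],  # atom 2 position
    [1.0, 0.0, 0.0],  # deformation row 1
    [0.0, 1.0, 0.0],  # deformation row 2
    [0.0, 0.0, 1.0],  # deformation row 3 (identity = undeformed cell)
]
pos, deforms = split_augmented_positions(aug, natoms=2)
print(len(pos), len(deforms))  # -> 2 1
```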

get_potential_energy(**kwargs)#

Returns the potential energy, including the enthalpy PV term.
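With a nonzero scalar_pressure, the returned quantity is an enthalpy-like sum rather than a bare energy. A minimal sketch of the PV correction, assuming plain scalar energies and volumes per system instead of tensors:

```python
def potential_energy_with_pv(energies, volumes, pressure):
    """Sum of per-system energies plus a P*V enthalpy term for each cell."""
    return sum(e + pressure * v for e, v in zip(energies, volumes))

print(potential_energy_with_pv([1.0, 2.0], [10.0, 20.0], pressure=0.1))  # -> 6.0
```

At pressure = 0 this reduces to the plain sum of batch energies returned by OptimizableBatch.get_potential_energy.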

get_forces(apply_constraint: bool = False, no_numpy: bool = False) torch.Tensor | numpy.typing.NDArray#

Get forces and unit cell stress.

__len__()#