core.modules.normalization.element_references

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

LinearReferences

Represents an elemental linear references model for a target property.

Functions

create_element_references(→ LinearReferences)

Create an element reference module.

fit_linear_references(→ dict[str, LinearReferences])

Fit a set of linear references for a list of targets using a given number of batches.

load_references_from_config(→ dict[str, LinearReferences])

Create a dictionary with element references from a config.

Module Contents

class core.modules.normalization.element_references.LinearReferences(element_references: torch.Tensor | None = None, max_num_elements: int = 118)

Bases: torch.nn.Module

Represents an elemental linear references model for a target property.

An elemental reference associates a value with each chemical element present in the dataset. Elemental references define a chemical composition model, i.e. a rough approximation of a target property (e.g. energy) obtained by summing the elemental references, each multiplied by the number of times the corresponding element is present.
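The composition model described above reduces to a weighted sum over element counts. The sketch below assumes the references are stored in a flat table indexed directly by atomic number (slot 0 unused, length max_num_elements + 1); that layout, the function name composition_energy, and the H and O values are illustrative assumptions, not taken from the module.

```python
from collections import Counter

MAX_NUM_ELEMENTS = 118  # same value as the documented max_num_elements default


def composition_energy(atomic_numbers: list[int], element_references: list[float]) -> float:
    """Approximate a structure's energy as the sum over elements of count * reference."""
    counts = Counter(atomic_numbers)
    return sum(count * element_references[z] for z, count in counts.items())


# Reference table indexed by atomic number; slot 0 is unused (assumed layout).
# The values are illustrative placeholders, not fitted or tabulated energies.
refs = [0.0] * (MAX_NUM_ELEMENTS + 1)
refs[1] = -3.5  # H
refs[8] = -7.0  # O

water = [8, 1, 1]  # atomic numbers of the atoms in H2O
print(composition_energy(water, refs))  # -14.0
```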

Elemental reference energies can be taken as:
  • the energy of a chemical species in its elemental state (e.g. the lowest-energy polymorph of the single-element crystal structure, for solids)

  • the coefficients obtained by fitting a linear model to a dataset, where the features are the counts of each element in each data point. See the function fit_linear_references below for details.
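The second option is an ordinary least-squares problem whose feature matrix holds, for each structure, the count of every atomic number. The following is a minimal NumPy sketch of that idea, not the module's implementation: it assumes no intercept term, uses hypothetical helper names, and fits synthetic energies generated from illustrative H and O values.

```python
import numpy as np


def element_count_matrix(structures: list[list[int]], max_num_elements: int = 118) -> np.ndarray:
    """Row i counts how often each atomic number (column index) occurs in structure i."""
    X = np.zeros((len(structures), max_num_elements + 1))
    for i, atomic_numbers in enumerate(structures):
        X[i] = np.bincount(atomic_numbers, minlength=max_num_elements + 1)
    return X


def fit_element_references(structures, energies, max_num_elements=118) -> np.ndarray:
    """Least-squares fit of one reference value per element (no intercept term)."""
    X = element_count_matrix(structures, max_num_elements)
    y = np.asarray(energies, dtype=float)
    # lstsq returns the minimum-norm solution, so elements that never appear
    # in the data (all-zero columns) end up with a reference of 0.
    coeffs, _residuals, _rank, _sv = np.linalg.lstsq(X, y, rcond=None)
    return coeffs


# Synthetic energies built from H = -3.5 and O = -7.0 (illustrative values).
structures = [[1, 1], [8, 8], [8, 1, 1], [8, 1]]
energies = [-7.0, -14.0, -14.0, -10.5]
coeffs = fit_element_references(structures, energies)
print(round(coeffs[1], 3), round(coeffs[8], 3))  # -3.5 -7.0
```

Because unused elements produce all-zero columns, the system is rank-deficient; relying on the minimum-norm solution keeps their references at zero instead of arbitrary values.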

Training GNNs to predict the difference between the DFT target and the prediction of a chemical composition model is a useful normalization scheme that can improve model accuracy. See, for example, the “Alternative reference scheme” section of the OC22 manuscript: https://arxiv.org/pdf/2206.08917

_apply_refs(target: torch.Tensor, batch: torch_geometric.data.Batch, sign: int, reshaped: bool = True) → torch.Tensor

Apply references batch-wise.

dereference(target: torch.Tensor, batch: torch_geometric.data.Batch, reshaped: bool = True) → torch.Tensor

Remove linear references.

forward(target: torch.Tensor, batch: torch_geometric.data.Batch, reshaped: bool = True) → torch.Tensor

Add linear references.
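Adding or removing references batch-wise amounts to a scatter-add: each atom contributes its element's reference, with sign +1 or -1, to the target of the structure it belongs to. The sketch below uses NumPy's np.add.at in place of a PyTorch scatter (Tensor.index_add_ would be the torch analogue); the plain arrays stand in for the per-atom atomic numbers and the per-atom structure index that a torch_geometric Batch carries. Function names, the omission of the reshaped option, and all numbers are illustrative assumptions.

```python
import numpy as np


def apply_refs(target, atomic_numbers, batch, element_references, sign):
    """Add (sign=+1) or remove (sign=-1) composition references per structure.

    target:          shape (num_structures,), one target value per structure
    atomic_numbers:  shape (num_atoms,), atomic number of every atom in the batch
    batch:           shape (num_atoms,), index of the structure each atom belongs to
    """
    out = np.array(target, dtype=float)  # copy so the caller's array is untouched
    per_atom_refs = np.asarray(element_references, dtype=float)[atomic_numbers]
    # Unbuffered scatter-add: each atom's reference lands in its structure's slot.
    np.add.at(out, batch, sign * per_atom_refs)
    return out


def dereference(target, atomic_numbers, batch, element_references):
    return apply_refs(target, atomic_numbers, batch, element_references, sign=-1)


def reference(target, atomic_numbers, batch, element_references):
    return apply_refs(target, atomic_numbers, batch, element_references, sign=+1)


refs = np.zeros(119)
refs[1], refs[8] = -3.5, -7.0  # illustrative values for H and O
atomic_numbers = np.array([8, 1, 1, 8, 8])  # an H2O molecule followed by O2
batch = np.array([0, 0, 0, 1, 1])
energies = np.array([-14.5, -14.2])  # hypothetical DFT energies
residual = dereference(energies, atomic_numbers, batch, refs)
print(residual)  # close to [-0.5, -0.2]
```

A model would be trained on the residual; adding the references back with sign +1 recovers the original scale.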

core.modules.normalization.element_references.create_element_references(file: str | pathlib.Path | None = None, state_dict: dict | None = None) → LinearReferences

Create an element reference module.

Parameters:
  • type (str) – type of reference (only linear implemented)

  • file (str or Path) – path to a .pt or .npz file

  • state_dict (dict) – a state dict of an element reference module

Returns:

LinearReferences
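For the .npz route, a reference table can be stored and reloaded with NumPy alone. This is a minimal round-trip sketch; the array key REFS_KEY is a hypothetical name, since the key that create_element_references expects inside an .npz file is not documented here, and .pt files (which would need torch.load) are deliberately not handled.

```python
import tempfile
from pathlib import Path

import numpy as np

REFS_KEY = "element_references"  # hypothetical array name, not taken from the library


def save_references(path, element_references) -> None:
    np.savez(path, **{REFS_KEY: np.asarray(element_references, dtype=float)})


def load_references(path) -> np.ndarray:
    path = Path(path)
    if path.suffix != ".npz":
        # A .pt file would need torch.load instead; not covered by this sketch.
        raise ValueError(f"expected an .npz file, got {path.name!r}")
    with np.load(path) as data:
        return data[REFS_KEY]


with tempfile.TemporaryDirectory() as tmp:
    file = Path(tmp) / "element_refs.npz"
    table = np.zeros(119)
    table[1] = -3.5  # illustrative value for H
    save_references(file, table)
    loaded = load_references(file)

print(loaded.shape, loaded[1])
```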

core.modules.normalization.element_references.fit_linear_references(targets: list[str], dataset: torch.utils.data.Dataset, batch_size: int, num_batches: int | None = None, num_workers: int = 0, max_num_elements: int = 118, log_metrics: bool = True, use_numpy: bool = True, driver: str | None = None, shuffle: bool = True, seed: int = 0) → dict[str, LinearReferences]

Fit a set of linear references for a list of targets using a given number of batches.

Parameters:
  • targets – list of target names

  • dataset – data set to fit linear references with

  • batch_size – size of batch

  • num_batches – number of batches to use in the fit. If not given, all batches are used.

  • num_workers – number of workers to use in the data loader. Note that setting num_workers > 1 leads to finicky multiprocessing issues when this function is used in distributed mode. The issue has to do with pickling the functions in load_references_from_config (see the function below).

  • max_num_elements – maximum number of elements in the dataset. If not given, an ambitious default of 118 is used.

  • log_metrics – if True, compute and log the MAE, RMSE, and R2 score of the fit.

  • use_numpy – use numpy.linalg.lstsq instead of torch.linalg.lstsq. This tends to give better solutions.

  • driver – backend used to solve the linear system; see the torch.linalg.lstsq docs. Ignored if use_numpy=True.

  • shuffle – whether to shuffle when loading the dataset

  • seed – random seed used to shuffle the sampler if shuffle=True

Returns:

dict mapping each target name to its fitted LinearReferences object
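With log_metrics=True the fit quality is summarized by MAE, RMSE, and R2. The exact way the function computes and logs these is not stated here, so the sketch below simply uses the standard definitions; in the least-squares sketch earlier, y_pred would be the count matrix multiplied by the fitted coefficients. The helper name fit_metrics and the toy numbers are illustrative.

```python
import numpy as np


def fit_metrics(y_true, y_pred) -> dict[str, float]:
    """MAE, RMSE and coefficient of determination (R2) of a fit."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    ss_res = np.sum(residuals**2)                   # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares (assumed non-zero)
    return {
        "mae": float(np.mean(np.abs(residuals))),
        "rmse": float(np.sqrt(np.mean(residuals**2))),
        "r2": float(1.0 - ss_res / ss_tot),
    }


print(fit_metrics([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0]))
```

An R2 close to 1 means the composition model alone already explains most of the target's variance, which is when referencing removes the most scale from what the GNN has to learn.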

core.modules.normalization.element_references.load_references_from_config(config: dict[str, Any], dataset: torch.utils.data.Dataset, seed: int = 0, checkpoint_dir: str | pathlib.Path | None = None) → dict[str, LinearReferences]

Create a dictionary with element references from a config.