core.modules.normalization.normalizer#

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes#

Normalizer

Normalize/denormalize a tensor and optionally add an atom reference offset.

Functions#

create_normalizer(→ Normalizer)

Build a target data normalizer with an optional atom reference

fit_normalizers(→ dict[str, Normalizer])

Estimate mean and rmsd from data to create normalizers

load_normalizers_from_config(→ dict[str, Normalizer])

Create a dictionary with element references from a config.

Module Contents#

class core.modules.normalization.normalizer.Normalizer(mean: float | torch.Tensor = 0.0, rmsd: float | torch.Tensor = 1.0)#

Bases: torch.nn.Module

Normalize/denormalize a tensor and optionally add an atom reference offset.

norm(tensor: torch.Tensor) torch.Tensor#
denorm(normed_tensor: torch.Tensor) torch.Tensor#
forward(normed_tensor: torch.Tensor) torch.Tensor#
load_state_dict(state_dict: collections.abc.Mapping[str, Any], strict: bool = True, assign: bool = False)#

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When False, the properties of the tensors in the current module are preserved; when True, the properties of the tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False

Returns:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
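As a minimal sketch of the normalization math this class implies (an assumption: norm subtracts the mean and divides by the rmsd, and denorm inverts that; the real Normalizer is an nn.Module operating on torch tensors):

```python
class SimpleNormalizer:
    """Illustrative stand-in for Normalizer using plain Python floats."""

    def __init__(self, mean=0.0, rmsd=1.0):
        self.mean = mean
        self.rmsd = rmsd

    def norm(self, values):
        # Map raw target values to zero-mean, unit-rmsd space.
        return [(v - self.mean) / self.rmsd for v in values]

    def denorm(self, normed_values):
        # Invert norm(): scale back up and re-add the mean offset.
        return [v * self.rmsd + self.mean for v in normed_values]
```

In the real module, forward() delegates to denorm so that a trained model's normalized predictions are returned in physical units.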

core.modules.normalization.normalizer.create_normalizer(file: str | pathlib.Path | None = None, state_dict: dict | None = None, tensor: torch.Tensor | None = None, mean: float | torch.Tensor | None = None, rmsd: float | torch.Tensor | None = None, stdev: float | torch.Tensor | None = None) Normalizer#

Build a target data normalizer with an optional atom reference

Only one of file, state_dict, tensor, or (mean and rmsd) is used to create a normalizer. If more than one set of inputs is given, priority follows the order in which they are listed above.

Parameters:
  • file (str or Path) – path to pt or npz file.

  • state_dict (dict) – a state dict for Normalizer module

  • tensor (Tensor) – a tensor with target values used to compute mean and std

  • mean (float | Tensor) – mean of target data

  • rmsd (float | Tensor) – rmsd of target data. The rmsd about the mean equals the stdev; the rmsd about 0 equals the rms.

  • stdev – standard deviation (deprecated, use rmsd instead)

Returns:

Normalizer
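The documented priority order can be sketched as a small resolver (a hypothetical helper, not part of the library; it only decides which input set wins):

```python
def resolve_normalizer_inputs(file=None, state_dict=None, tensor=None,
                              mean=None, rmsd=None, stdev=None):
    """Return which input set would be used, following the documented priority:
    file > state_dict > tensor > (mean and rmsd)."""
    if stdev is not None and rmsd is None:
        rmsd = stdev  # stdev is deprecated in favor of rmsd
    if file is not None:
        return "file"
    if state_dict is not None:
        return "state_dict"
    if tensor is not None:
        return "tensor"
    if mean is not None and rmsd is not None:
        return "mean_rmsd"
    raise ValueError("No valid inputs provided to create a normalizer")
```

For example, passing both a file and explicit statistics resolves to the file, since it is listed first.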

core.modules.normalization.normalizer.fit_normalizers(targets: list[str], dataset: torch.utils.data.Dataset, batch_size: int, override_values: dict[str, dict[str, float]] | None = None, rmsd_correction: int | None = None, element_references: dict | None = None, num_batches: int | None = None, num_workers: int = 0, shuffle: bool = True, seed: int = 0) dict[str, Normalizer]#

Estimate mean and rmsd from data to create normalizers

Parameters:
  • targets – list of target names

  • dataset – dataset used to estimate the mean and rmsd

  • batch_size – size of batch

  • override_values – dictionary with target names and values to override, e.g. {"forces": {"mean": 0.0}} sets the forces mean to zero.

  • rmsd_correction – degrees-of-freedom correction to use when computing the std/rmsd. See the docs for torch.std. If not given, 0 is used when mean == 0, and 1 otherwise.

  • element_references – optional dictionary of element references to remove from target values before fitting the normalizers

  • num_batches – number of batches to use in fit. If not given will use all batches

  • num_workers – number of workers to use in the data loader. Note: setting num_workers > 1 can lead to finicky multiprocessing issues when using this function in distributed mode; the issue has to do with pickling the functions in load_normalizers_from_config (see below).

  • shuffle – whether to shuffle when loading the dataset

  • seed – random seed used to shuffle the sampler if shuffle=True

Returns:

dict of normalizer objects
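The fitting rule described above can be sketched in pure Python (a hypothetical helper: the library works on torch tensors in batches, but the statistics are the same; note how rmsd_correction defaults to 0 when the mean is zero and 1 otherwise):

```python
import math

def fit_mean_rmsd(values, override=None, rmsd_correction=None):
    """Estimate (mean, rmsd) for a list of target values.

    override may supply {"mean": ...} to pin the mean (e.g. 0.0 for forces),
    mirroring the override_values parameter documented above.
    """
    override = override or {}
    mean = override.get("mean", sum(values) / len(values))
    if rmsd_correction is None:
        # Default correction: 0 when the mean is pinned/equal to zero
        # (rmsd about 0 = rms), 1 otherwise (rmsd about the mean = stdev).
        rmsd_correction = 0 if mean == 0 else 1
    sum_sq = sum((v - mean) ** 2 for v in values)
    rmsd = math.sqrt(sum_sq / (len(values) - rmsd_correction))
    return mean, rmsd
```

With the mean overridden to zero, the result is the rms of the raw values rather than their standard deviation.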

core.modules.normalization.normalizer.load_normalizers_from_config(config: dict[str, Any], dataset: torch.utils.data.Dataset, seed: int = 0, checkpoint_dir: str | pathlib.Path | None = None, element_references: dict[str, fairchem.core.modules.normalization.element_references.LinearReferences] | None = None) dict[str, Normalizer]#

Create a dictionary of target normalizers from a config.
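A per-target config for this function might look like the following (an illustrative shape only; the key names are assumptions, not the exact schema). Each target either supplies statistics directly or points to a saved file, matching the inputs create_normalizer accepts:

```python
# Hypothetical per-target normalizer config.
norm_config = {
    "energy": {"mean": -0.7, "rmsd": 2.3},     # explicit statistics
    "forces": {"file": "forces_norm.pt"},      # load from a saved .pt file
}

def pick_source(target_cfg):
    """Decide whether a target's normalizer comes from a file or inline stats."""
    return "file" if "file" in target_cfg else "stats"
```

Targets without usable entries would fall back to fitting statistics from the dataset, which is why this function also takes a dataset and seed.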