core.modules.normalization.normalizer#
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#

| `Normalizer` | Normalize/denormalize a tensor and optionally add an atom reference offset. |

Functions#

| `create_normalizer` | Build a target data normalizer with optional atom references. |
| `fit_normalizers` | Estimate mean and rmsd from data to create normalizers. |
| `load_normalizers_from_config` | Create a dictionary of normalizers from a config. |
Module Contents#
- class core.modules.normalization.normalizer.Normalizer(mean: float | torch.Tensor = 0.0, rmsd: float | torch.Tensor = 1.0)#
Bases:
torch.nn.Module
Normalize/denormalize a tensor and optionally add an atom reference offset.
- norm(tensor: torch.Tensor) torch.Tensor #
- denorm(normed_tensor: torch.Tensor) torch.Tensor #
- forward(normed_tensor: torch.Tensor) torch.Tensor #
- load_state_dict(state_dict: collections.abc.Mapping[str, Any], strict: bool = True, assign: bool = False)#
Copy parameters and buffers from `state_dict` into this module and its descendants. If `strict` is `True`, then the keys of `state_dict` must exactly match the keys returned by this module's `state_dict()` function.
Warning
If `assign` is `True`, the optimizer must be created after the call to `load_state_dict` unless `get_swap_module_params_on_conversion()` is `True`.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `state_dict()` function. Default: `True`
assign (bool, optional) – When `False`, the properties of the tensors in the current module are preserved, while when `True`, the properties of the tensors in the state dict are preserved. The only exception is the `requires_grad` field, for which the value from the module is preserved. Default: `False`
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided `state_dict`.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided `state_dict`.
- Return type:
`NamedTuple` with `missing_keys` and `unexpected_keys` fields
Note
If a parameter or buffer is registered as `None` and its corresponding key exists in `state_dict`, `load_state_dict()` will raise a `RuntimeError`.
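The norm/denorm contract above can be sketched with a minimal stand-in module, assuming the standard affine transform `(x - mean) / rmsd` (a sketch, not the library's actual implementation). Because it subclasses `torch.nn.Module`, the inherited `load_state_dict` returns the named tuple with `missing_keys` and `unexpected_keys` documented above:

```python
from __future__ import annotations

import torch
from torch import nn


class Normalizer(nn.Module):
    """Minimal sketch: norm as (x - mean) / rmsd, denorm as the inverse."""

    def __init__(self, mean: float | torch.Tensor = 0.0, rmsd: float | torch.Tensor = 1.0):
        super().__init__()
        # buffers (not parameters) so they travel with state_dict but are not trained
        self.register_buffer("mean", torch.as_tensor(mean))
        self.register_buffer("rmsd", torch.as_tensor(rmsd))

    def norm(self, tensor: torch.Tensor) -> torch.Tensor:
        return (tensor - self.mean) / self.rmsd

    def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
        return normed_tensor * self.rmsd + self.mean


n = Normalizer(mean=2.0, rmsd=0.5)
x = torch.tensor([1.0, 2.0, 3.0])
roundtrip = n.denorm(n.norm(x))             # recovers the original values
keys = n.load_state_dict({}, strict=False)  # inherited from nn.Module
# keys.missing_keys lists the buffers absent from the empty state dict
```

With `strict=False`, the empty state dict simply reports both buffers as missing instead of raising.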
- core.modules.normalization.normalizer.create_normalizer(file: str | pathlib.Path | None = None, state_dict: dict | None = None, tensor: torch.Tensor | None = None, mean: float | torch.Tensor | None = None, rmsd: float | torch.Tensor | None = None, stdev: float | torch.Tensor | None = None) Normalizer #
Build a target data normalizer with optional atom references.
Only one of file, state_dict, tensor, or (mean and rmsd) is used to create a normalizer. If more than one set of inputs is given, priority follows the order in which they are listed above.
- Parameters:
file (str or Path) – path to pt or npz file.
state_dict (dict) – a state dict for Normalizer module
tensor (Tensor) – a tensor with target values used to compute mean and std
mean (float | Tensor) – mean of target data
rmsd (float | Tensor) – rmsd of target data; the rmsd about the mean equals the stdev, and the rmsd about zero equals the RMS.
stdev – standard deviation (deprecated, use rmsd instead)
- Returns:
Normalizer
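The input-priority rule can be illustrated with a small hypothetical helper (`pick_source` is not part of the library; it only mirrors the documented order file > state_dict > tensor > (mean, rmsd)):

```python
def pick_source(file=None, state_dict=None, tensor=None,
                mean=None, rmsd=None, stdev=None):
    """Hypothetical helper illustrating create_normalizer's priority order."""
    if rmsd is None:
        rmsd = stdev  # `stdev` is a deprecated alias for `rmsd`
    # higher-priority inputs win, in the documented order
    for name, value in (("file", file), ("state_dict", state_dict), ("tensor", tensor)):
        if value is not None:
            return name
    if mean is not None and rmsd is not None:
        return "mean/rmsd"
    raise ValueError("one of file, state_dict, tensor, or (mean and rmsd) is required")
```

For example, passing both a tensor and explicit statistics resolves to the tensor, since it appears earlier in the priority order.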
- core.modules.normalization.normalizer.fit_normalizers(targets: list[str], dataset: torch.utils.data.Dataset, batch_size: int, override_values: dict[str, dict[str, float]] | None = None, rmsd_correction: int | None = None, element_references: dict | None = None, num_batches: int | None = None, num_workers: int = 0, shuffle: bool = True, seed: int = 0) dict[str, Normalizer] #
Estimate mean and rmsd from data to create normalizers
- Parameters:
targets – list of target names
dataset – data set to fit linear references with
batch_size – size of batch
override_values – dictionary with target names and values to override. i.e. {“forces”: {“mean”: 0.0}} will set the forces mean to zero.
rmsd_correction – correction to use when computing mean in std/rmsd. See docs for torch.std. If not given, will always use 0 when mean == 0, and 1 otherwise.
element_references – optional dictionary of element references used to offset target values before fitting normalizers.
num_batches – number of batches to use in fit. If not given will use all batches
num_workers – number of workers to use in the data loader. Note: setting num_workers > 1 can lead to finicky multiprocessing issues when using this function in distributed mode; the issue has to do with pickling the functions in load_normalizers_from_config (see the function below).
shuffle – whether to shuffle when loading the dataset
seed – random seed used to shuffle the sampler if shuffle=True
- Returns:
dict of normalizer objects
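The estimation rule described above (rmsd about the mean is the stdev; rmsd about zero is the RMS; correction 0 when the mean is 0, else 1) can be sketched as follows. This is an assumed illustration of the documented behavior, not the library's exact code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def fit_mean_rmsd(values, override_mean=None, rmsd_correction=None):
    """Sketch: estimate mean and rmsd for one target from collected values."""
    mean = values.mean() if override_mean is None else torch.tensor(override_mean)
    if rmsd_correction is None:
        # documented default: correction 0 when mean == 0, and 1 otherwise
        rmsd_correction = 0 if float(mean) == 0.0 else 1
    rmsd = torch.sqrt(((values - mean) ** 2).sum() / (values.numel() - rmsd_correction))
    return mean, rmsd


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(100) + 3.0)   # toy target values
loader = DataLoader(dataset, batch_size=16, shuffle=False)
vals = torch.cat([batch[0] for batch in loader])  # gather targets batch by batch
mean, rmsd = fit_mean_rmsd(vals)                  # rmsd about the mean = stdev
_, rms = fit_mean_rmsd(vals, override_mean=0.0)   # mean forced to 0 -> rmsd = RMS
```

The `override_mean=0.0` call mirrors the `{"forces": {"mean": 0.0}}` override example above, where forcing a zero mean makes the fitted rmsd the root mean square of the raw values.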
- core.modules.normalization.normalizer.load_normalizers_from_config(config: dict[str, Any], dataset: torch.utils.data.Dataset, seed: int = 0, checkpoint_dir: str | pathlib.Path | None = None, element_references: dict[str, fairchem.core.modules.normalization.element_references.LinearReferences] | None = None) dict[str, Normalizer] #
Create a dictionary of normalizers from a config.