core.modules.normalization.normalizer#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#
- Normalizer – Normalize/denormalize a tensor and optionally add an atom reference offset.
Functions#
- create_normalizer – Build a target data normalizer with an optional atom reference.
- fit_normalizers – Estimate mean and rmsd from data to create normalizers.
- load_normalizers_from_config – Create a dictionary of target normalizers from a config.
Module Contents#
- class core.modules.normalization.normalizer.Normalizer(mean: float | torch.Tensor = 0.0, rmsd: float | torch.Tensor = 1.0)#
Bases: torch.nn.Module
Normalize/denormalize a tensor and optionally add an atom reference offset.
- norm(tensor: torch.Tensor) torch.Tensor#
- denorm(normed_tensor: torch.Tensor) torch.Tensor#
- forward(normed_tensor: torch.Tensor) torch.Tensor#
- load_state_dict(state_dict: collections.abc.Mapping[str, Any], strict: bool = True, assign: bool = False)#
Copy parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
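The norm/denorm methods above amount to a shift-and-scale transform and its inverse. A minimal pure-Python sketch of the same math (an illustration only, not the torch-based library implementation):

```python
def norm(x: float, mean: float = 0.0, rmsd: float = 1.0) -> float:
    # Normalize: subtract the mean, then divide by the rmsd.
    return (x - mean) / rmsd

def denorm(x_normed: float, mean: float = 0.0, rmsd: float = 1.0) -> float:
    # Denormalize: invert the transform above.
    return x_normed * rmsd + mean

value = norm(5.0, mean=3.0, rmsd=2.0)        # (5 - 3) / 2 = 1.0
restored = denorm(value, mean=3.0, rmsd=2.0)  # back to 5.0
```

Round-tripping a value through norm and then denorm returns the original, which is what forward relies on when converting model outputs back to target units.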
- core.modules.normalization.normalizer.create_normalizer(file: str | pathlib.Path | None = None, state_dict: dict | None = None, tensor: torch.Tensor | None = None, mean: float | torch.Tensor | None = None, rmsd: float | torch.Tensor | None = None, stdev: float | torch.Tensor | None = None) Normalizer#
Build a target data normalizer with an optional atom reference.
Only one of file, state_dict, tensor, or (mean and rmsd) is used to create a normalizer. If more than one set of inputs is given, priority follows the order in which they are listed above.
- Parameters:
file (str or Path) – path to a .pt or .npz file.
state_dict (dict) – a state dict for Normalizer module
tensor (Tensor) – a tensor with target values used to compute mean and std
mean (float | Tensor) – mean of target data
rmsd (float | Tensor) – rmsd of target data. The rmsd about the mean equals the standard deviation; the rmsd about zero equals the root mean square.
stdev – standard deviation (deprecated, use rmsd instead)
- Returns:
Normalizer
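The input precedence described above can be sketched in plain Python. This is a simplified illustration covering only the tensor / mean / rmsd / stdev inputs; pick_stats is a hypothetical helper, not part of the library, and the file and state_dict paths are omitted:

```python
import math

def pick_stats(tensor=None, mean=None, rmsd=None, stdev=None):
    """Simplified precedence: a tensor of target values wins over an
    explicit (mean, rmsd) pair; stdev is accepted as a deprecated
    alias for rmsd."""
    if rmsd is None and stdev is not None:
        rmsd = stdev  # deprecated alias
    if tensor is not None:
        # Compute mean and rmsd about the mean from the data itself.
        m = sum(tensor) / len(tensor)
        r = math.sqrt(sum((x - m) ** 2 for x in tensor) / len(tensor))
        return m, r
    if mean is not None and rmsd is not None:
        return mean, rmsd
    raise ValueError("one of tensor or (mean and rmsd) must be given")

pick_stats(mean=1.0, stdev=2.0)  # -> (1.0, 2.0)
```

Passing both a tensor and explicit statistics silently prefers the tensor, mirroring the documented priority order.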
- core.modules.normalization.normalizer.fit_normalizers(targets: list[str], dataset: torch.utils.data.Dataset, batch_size: int, override_values: dict[str, dict[str, float]] | None = None, rmsd_correction: int | None = None, element_references: dict | None = None, num_batches: int | None = None, num_workers: int = 0, shuffle: bool = True, seed: int = 0) dict[str, Normalizer]#
Estimate mean and rmsd from data to create normalizers
- Parameters:
targets – list of target names
dataset – dataset to fit normalizers with
batch_size – size of batch
override_values – dictionary with target names and values to override, e.g. {"forces": {"mean": 0.0}} will set the forces mean to zero.
rmsd_correction – correction to use when computing mean in std/rmsd. See docs for torch.std. If not given, will always use 0 when mean == 0, and 1 otherwise.
element_references
num_batches – number of batches to use in fit. If not given will use all batches
num_workers – number of workers to use in the data loader. Note: setting num_workers > 1 can lead to multiprocessing issues when using this function in distributed mode, because of pickling the functions in load_normalizers_from_config (see the function below).
shuffle – whether to shuffle when loading the dataset
seed – random seed used to shuffle the sampler if shuffle=True
- Returns:
dict of normalizer objects
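The estimation this function performs can be sketched as follows. This is a pure-Python illustration of the documented mean/rmsd defaults; estimate_mean_rmsd is a hypothetical helper rather than the library code, and batching, element references, and per-target handling are omitted:

```python
import math

def estimate_mean_rmsd(values, override_mean=None, rmsd_correction=None):
    """Illustrative estimate of the (mean, rmsd) pair a normalizer needs.

    Mirrors the documented default for rmsd_correction: use 0 when the
    mean is zero (plain root mean square), and 1 otherwise (sample
    standard deviation with Bessel's correction).
    """
    mean = override_mean if override_mean is not None else sum(values) / len(values)
    if rmsd_correction is None:
        rmsd_correction = 0 if mean == 0 else 1
    denom = len(values) - rmsd_correction
    rmsd = math.sqrt(sum((v - mean) ** 2 for v in values) / denom)
    return mean, rmsd
```

For example, overriding the mean to zero (as is typical for forces) switches the rmsd to a root mean square about zero rather than a standard deviation about the sample mean.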
- core.modules.normalization.normalizer.load_normalizers_from_config(config: dict[str, Any], dataset: torch.utils.data.Dataset, seed: int = 0, checkpoint_dir: str | pathlib.Path | None = None, element_references: dict[str, fairchem.core.modules.normalization.element_references.LinearReferences] | None = None) dict[str, Normalizer]#
Create a dictionary of target normalizers from a config.