core.modules.exponential_moving_average
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Copied (and improved) from: fadel/pytorch_ema (MIT license)
Classes
ExponentialMovingAverage – Maintains an (exponential) moving average of a set of parameters.
Module Contents
- class core.modules.exponential_moving_average.ExponentialMovingAverage(parameters: collections.abc.Iterable[torch.nn.Parameter], decay: float, use_num_updates: bool = False)
Maintains an (exponential) moving average of a set of parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter (typically from model.parameters()).
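decay – The exponential decay rate; values closer to 1 make the moving average change more slowly.
use_num_updates – Whether to take the number of updates performed so far into account when computing averages.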
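For reference, a standard exponential moving average with decay $d$ moves each shadow value $s$ toward its parameter $p$ as shown below. The warm-up case applied when `use_num_updates` is set is an assumption taken from the upstream fadel/pytorch_ema project (which follows TensorFlow's `tf.train.ExponentialMovingAverage` convention); $n_t$ denotes the number of updates performed so far, counting the current one. Check this module's source before relying on the exact warm-up rule.

```latex
s_t = d_t \, s_{t-1} + (1 - d_t) \, p_t,
\qquad
d_t =
\begin{cases}
\min\!\left(d, \dfrac{1 + n_t}{10 + n_t}\right) & \text{if use\_num\_updates,} \\
d & \text{otherwise.}
\end{cases}
```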
- decay
- num_updates: int | None
- shadow_params
- collected_params: list[torch.nn.Parameter] = []
- _params_refs
- _get_parameters(parameters: collections.abc.Iterable[torch.nn.Parameter] | None) → collections.abc.Iterable[torch.nn.Parameter]
- update(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Update the currently maintained moving averages.
Call this every time the parameters are updated, for example after each optimizer.step() call.
- Parameters:
parameters – Iterable of torch.nn.Parameter; usually the same set of parameters used to initialize this object. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
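The arithmetic behind `update` can be sketched in plain Python as follows. This is a hypothetical illustration, not this module's code: lists of floats stand in for the parameter tensors, the helper name `ema_update` is made up, and the `use_num_updates` warm-up rule is assumed from the upstream fadel/pytorch_ema project.

```python
# Hypothetical, self-contained sketch: plain Python floats stand in for
# torch.nn.Parameter tensors, and `ema_update` is not part of this module.


def ema_update(shadow, params, decay, num_updates=None):
    """Move each value in `shadow` toward the matching value in `params`.

    Returns the new update count (None when counting is disabled). The
    warm-up min(decay, (1 + n) / (10 + n)) is an assumption borrowed from
    upstream pytorch_ema.
    """
    if num_updates is not None:
        num_updates += 1
        decay = min(decay, (1 + num_updates) / (10 + num_updates))
    for i, p in enumerate(params):
        # Equivalent to s <- decay * s + (1 - decay) * p.
        shadow[i] -= (1.0 - decay) * (shadow[i] - p)
    return num_updates


# Usage: call once after every optimizer step.
shadow = [0.0]
for _ in range(3):
    # Pretend each optimizer.step() leaves the parameter at 1.0.
    ema_update(shadow, [1.0], decay=0.5)
print(shadow[0])  # 0.875, i.e. 1 - 0.5**3
```

Writing the step as `s -= (1 - decay) * (s - p)` rather than recomputing `decay * s + (1 - decay) * p` mirrors an in-place tensor update, which avoids allocating a new buffer per parameter when applied to real tensors.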
- copy_to(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Copy the current moving averages into the given collection of parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored moving averages. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
- store(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Save the current parameters for restoring later.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be temporarily stored. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
- restore(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Restore the parameters saved with the store method. This is useful for validating the model with the EMA parameters without affecting the original optimization process: call store before copy_to, then, after validation (or model saving), call restore to bring back the original parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored parameters. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
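The store, copy_to, restore sequence described above can be illustrated with a dependency-free stand-in. `Param`, `TinyEMA`, and `evaluate` are hypothetical names invented for this sketch; with the real class you would pass model.parameters() and run an actual validation loop. Initializing the shadow values from the current parameters is also an assumption.

```python
class Param:
    """Hypothetical stand-in for torch.nn.Parameter (holds a float in .data)."""

    def __init__(self, data: float) -> None:
        self.data = data


class TinyEMA:
    """Hypothetical minimal EMA mirroring update/copy_to/store/restore."""

    def __init__(self, params: list, decay: float) -> None:
        self.params = params
        self.decay = decay
        # Assumption: shadow values start as copies of the current parameters.
        self.shadow = [p.data for p in params]
        self.collected = []

    def update(self) -> None:
        for i, p in enumerate(self.params):
            self.shadow[i] -= (1.0 - self.decay) * (self.shadow[i] - p.data)

    def copy_to(self) -> None:
        # Overwrite the live parameters with the moving averages.
        for s, p in zip(self.shadow, self.params):
            p.data = s

    def store(self) -> None:
        # Stash the live values so that copy_to can be undone later.
        self.collected = [p.data for p in self.params]

    def restore(self) -> None:
        # Put the stashed live values back.
        for c, p in zip(self.collected, self.params):
            p.data = c


def evaluate(params: list) -> float:
    """Hypothetical validation metric: here, just the sum of parameter values."""
    return sum(p.data for p in params)


params = [Param(0.0)]
ema = TinyEMA(params, decay=0.5)
params[0].data = 1.0  # pretend optimizer.step() changed the weight
ema.update()          # shadow moves halfway: 0.5

ema.store()                   # 1. save the live weights (1.0)
ema.copy_to()                 # 2. load the EMA weights (0.5)
ema_score = evaluate(params)  # validation runs on the EMA weights
ema.restore()                 # 3. training resumes from the live weights
```

Skipping the store call before copy_to would lose the live weights, since copy_to overwrites them in place.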
- state_dict() → dict
Returns the state of the ExponentialMovingAverage as a dict.
- load_state_dict(state_dict: dict) → None
Loads the ExponentialMovingAverage state.
- Parameters:
state_dict (dict) – EMA state. Should be an object returned from a call to state_dict().
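A common pattern is to keep the EMA state in the training checkpoint, next to the model and optimizer state, and reload it when resuming. The stand-in below is hypothetical: `TinyEMAState` is invented for illustration, and its keys merely mirror the attributes listed on this page, so they may not match the real state dict. With the real class you would typically place ema.state_dict() in the dict passed to torch.save.

```python
import copy


class TinyEMAState:
    """Hypothetical stand-in showing a state_dict()/load_state_dict() round trip.

    Keys mirror the attributes listed above (decay, num_updates,
    shadow_params, collected_params); the real module's keys may differ.
    """

    def __init__(self, shadow_params, decay, use_num_updates=False):
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = list(shadow_params)
        self.collected_params = []

    def state_dict(self) -> dict:
        # Deep-copy so later updates do not mutate a saved checkpoint.
        return copy.deepcopy({
            "decay": self.decay,
            "num_updates": self.num_updates,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params,
        })

    def load_state_dict(self, state_dict: dict) -> None:
        # Deep-copy so the loaded object does not share lists with the checkpoint.
        state = copy.deepcopy(state_dict)
        self.decay = state["decay"]
        self.num_updates = state["num_updates"]
        self.shadow_params = state["shadow_params"]
        self.collected_params = state["collected_params"]


saved = TinyEMAState([0.25, 0.75], decay=0.99, use_num_updates=True)
saved.num_updates = 42
checkpoint = {"ema": saved.state_dict()}  # stored alongside model/optimizer state

resumed = TinyEMAState([0.0, 0.0], decay=0.5)
resumed.load_state_dict(checkpoint["ema"])  # decay, counter, and averages restored
```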