core.modules.exponential_moving_average
Copied (and improved) from: fadel/pytorch_ema (MIT license)
Classes
ExponentialMovingAverage – Maintains (exponential) moving average of a set of parameters.
Module Contents
- class core.modules.exponential_moving_average.ExponentialMovingAverage(parameters: collections.abc.Iterable[torch.nn.Parameter], decay: float, use_num_updates: bool = False)
Maintains (exponential) moving average of a set of parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter (typically from model.parameters()).
decay – The exponential decay rate. Values close to 1 (for example 0.999) make the averages change slowly.
use_num_updates – Whether to use the number of updates when computing averages.
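To illustrate how decay and use_num_updates might interact, here is a minimal sketch. It assumes this module keeps the warm-up rule of the upstream fadel/pytorch_ema project, where the effective decay is capped at (1 + num_updates) / (10 + num_updates). This page does not confirm that rule, and the helper name effective_decay is hypothetical.

```python
from typing import Optional


def effective_decay(decay: float, num_updates: Optional[int]) -> float:
    """Return the decay applied at a given update count.

    Assumption: mirrors the warm-up rule of upstream pytorch_ema; the
    module documented here may behave differently.
    """
    if num_updates is None:  # corresponds to use_num_updates=False
        return decay
    return min(decay, (1 + num_updates) / (10 + num_updates))


# Early updates get a smaller decay, so the average tracks the weights closely.
early = effective_decay(0.999, num_updates=1)
# After many updates the cap no longer binds and the configured decay is used.
late = effective_decay(0.999, num_updates=100_000)
```

Under this assumption, a warm-up of this kind keeps the averages from being dominated by the initial (often random) weights during the first few steps.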
- decay
- num_updates: int | None
- shadow_params
- collected_params: list[torch.nn.Parameter] = []
- _params_refs
- _get_parameters(parameters: collections.abc.Iterable[torch.nn.Parameter] | None) → collections.abc.Iterable[torch.nn.Parameter]
- update(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Update the currently maintained moving averages.
Call this every time the parameters are updated, for example after each optimizer.step() call.
- Parameters:
parameters – Iterable of torch.nn.Parameter; usually the same set of parameters used to initialize this object. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
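The update step can be sketched on plain Python floats. This is an illustration of the standard exponential moving average rule, not the module's actual code: the function name ema_update is hypothetical, and the real class operates on torch.nn.Parameter tensors.

```python
def ema_update(shadow, params, decay):
    """In-place EMA step: shadow <- decay * shadow + (1 - decay) * param."""
    for i, p in enumerate(params):
        # Equivalent form: move each average a fraction (1 - decay) toward p.
        shadow[i] -= (1.0 - decay) * (shadow[i] - p)


shadow = [0.0, 1.0]  # averages, assumed to start as a copy of the parameters
params = [1.0, 1.0]  # parameter values after a hypothetical optimizer step
ema_update(shadow, params, decay=0.9)
```

With decay=0.9, each average moves one tenth of the way toward the new parameter value; an average that already equals its parameter is left unchanged.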
- copy_to(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Copy the current averaged parameters into the given collection of parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored moving averages. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
- store(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Save the current parameters for restoring later.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be temporarily stored. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
- restore(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) → None
Restore the parameters saved by the store method. This is useful for validating the model with EMA parameters without affecting the original optimization process: call store before copy_to, then, after validation (or model saving), call restore to bring back the former parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored parameters. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
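The store, copy_to, restore sequence described above can be sketched with a toy stand-in. ToyEMA is a hypothetical class that works on lists of floats rather than torch.nn.Parameter objects; it only mirrors the method names documented here and is not the module's implementation.

```python
class ToyEMA:
    """Hypothetical stand-in for ExponentialMovingAverage, using plain floats."""

    def __init__(self, params, decay):
        self.decay = decay
        self.shadow_params = list(params)  # running averages
        self.collected_params = []         # backup filled by store()

    def update(self, params):
        for i, p in enumerate(params):
            self.shadow_params[i] -= (1.0 - self.decay) * (self.shadow_params[i] - p)

    def store(self, params):
        self.collected_params = list(params)

    def copy_to(self, params):
        params[:] = self.shadow_params

    def restore(self, params):
        params[:] = self.collected_params


weights = [1.0, 2.0]
ema = ToyEMA(weights, decay=0.5)

weights[:] = [3.0, 4.0]  # pretend an optimizer step changed the weights
ema.update(weights)

ema.store(weights)       # 1. back up the training weights
ema.copy_to(weights)     # 2. swap in the averaged weights
evaluated_with = list(weights)  # validation or model saving would happen here
ema.restore(weights)     # 3. put the training weights back
```

After restore, training continues from the original weights, while validation saw only the averaged ones.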
- state_dict() → dict
Returns the state of the ExponentialMovingAverage as a dict.
- load_state_dict(state_dict: dict) → None
Loads the ExponentialMovingAverage state.
- Parameters:
state_dict (dict) – EMA state. Should be an object returned from a call to state_dict().
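A checkpoint round trip through state_dict() and load_state_dict() might look like the sketch below. ToyEMA is again a hypothetical stand-in over plain floats, and the dictionary keys are an assumption modeled on the documented attributes (decay, num_updates, shadow_params, collected_params); the real layout of the state dict is not specified on this page.

```python
import copy


class ToyEMA:
    """Hypothetical stand-in showing only the state (de)serialization."""

    def __init__(self, params, decay):
        self.decay = decay
        self.num_updates = None
        self.shadow_params = list(params)
        self.collected_params = []

    def state_dict(self):
        # Assumed key names; the documented class may use a different layout.
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "shadow_params": list(self.shadow_params),
            "collected_params": list(self.collected_params),
        }

    def load_state_dict(self, state_dict):
        # Deep-copy so later changes to the caller's dict do not leak in.
        state = copy.deepcopy(state_dict)
        self.decay = state["decay"]
        self.num_updates = state["num_updates"]
        self.shadow_params = state["shadow_params"]
        self.collected_params = state["collected_params"]


saved = ToyEMA([0.25, 0.75], decay=0.99)
checkpoint = saved.state_dict()

resumed = ToyEMA([0.0, 0.0], decay=0.5)  # fresh object, e.g. after a restart
resumed.load_state_dict(checkpoint)
```

In a real training script, the returned dict would typically be saved alongside the model and optimizer state so that the averages survive a restart.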