core.modules.exponential_moving_average

Copied (and improved) from: fadel/pytorch_ema (MIT license)

Classes

ExponentialMovingAverage

Maintains an exponential moving average of a set of parameters.

Module Contents

class core.modules.exponential_moving_average.ExponentialMovingAverage(parameters: collections.abc.Iterable[torch.nn.Parameter], decay: float, use_num_updates: bool = False)

Maintains an exponential moving average of a set of parameters.

Parameters:
  • parameters – Iterable of torch.nn.Parameter (typically from model.parameters()).

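  • decay – The exponential decay rate; values closer to 1 make the average change more slowly.

  • use_num_updates – Whether to take the number of updates performed so far into account when computing the averages, which lowers the effective decay early in training.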
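For reference, the update rule that decay and use_num_updates control can be written out as follows. This is a sketch that assumes the class keeps the behaviour of the upstream fadel/pytorch_ema code, which is not confirmed here. With β = decay, t the update count (including the current update), s a shadow parameter and θ the matching live parameter, each call to update applies:

```latex
\beta_t =
\begin{cases}
  \min\!\left(\beta,\ \dfrac{1 + t}{10 + t}\right) & \text{if use\_num\_updates} \\
  \beta & \text{otherwise}
\end{cases}
\qquad
s \leftarrow s - (1 - \beta_t)(s - \theta) = \beta_t\, s + (1 - \beta_t)\, \theta
```

Under this assumption, the first update with use_num_updates enabled uses β_1 = 2/11 ≈ 0.18, so the average follows the live weights closely at first and only approaches the configured decay after many updates.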

decay
num_updates: int | None
shadow_params
collected_params: list[torch.nn.Parameter] = []
_params_refs
_get_parameters(parameters: collections.abc.Iterable[torch.nn.Parameter] | None) -> collections.abc.Iterable[torch.nn.Parameter]
update(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) -> None

Update the currently maintained moving averages.

Call this every time the parameters are updated, for example after each optimizer.step() call.

Parameters:

parameters – Iterable of torch.nn.Parameter; usually the same set of parameters used to initialize this object. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
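The step performed by update can be sketched on plain tensors as below. ema_update is a hypothetical helper written for illustration (it is not part of this module), and the warmup formula used when a counter is passed is assumed from upstream pytorch_ema.

```python
import torch


def ema_update(shadow_params, params, decay, num_updates=None):
    """Hypothetical helper: apply one EMA step to shadow_params in place.

    Returns the incremented update counter, or None when no counter is used.
    """
    if num_updates is not None:
        # Assumed warmup (as in upstream pytorch_ema): lower effective decay early on.
        num_updates += 1
        decay = min(decay, (1 + num_updates) / (10 + num_updates))
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
        for s, p in zip(shadow_params, params):
            # s <- s - (1 - decay) * (s - p), i.e. decay * s + (1 - decay) * p
            s.sub_(one_minus_decay * (s - p))
    return num_updates


param = torch.nn.Parameter(torch.tensor([1.0]))
shadow = [param.detach().clone()]  # the average starts at the initial weights
with torch.no_grad():
    param.fill_(2.0)  # stand-in for the change made by optimizer.step()
ema_update(shadow, [param], decay=0.9)  # shadow[0] is now about 1.1
```

Because the shadow tensors are modified in place under torch.no_grad(), the averaging step adds nothing to the autograd graph.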

copy_to(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) -> None

Copy the current moving averages into the given collection of parameters.

Parameters:

parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored moving averages. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
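A minimal sketch of what copy_to does, assuming it simply overwrites each live parameter with its shadow copy in place. The helper name ema_copy_to is hypothetical and does not belong to this module.

```python
import torch


def ema_copy_to(shadow_params, params):
    """Hypothetical helper: overwrite live parameters with their moving averages."""
    with torch.no_grad():  # in-place writes to leaf Parameters require no_grad
        for s, p in zip(shadow_params, params):
            p.copy_(s)


model = torch.nn.Linear(2, 1)
# Stand-in averages: the current weights shifted by 1.0.
shadow = [p.detach().clone() + 1.0 for p in model.parameters()]
ema_copy_to(shadow, model.parameters())
```

Copying in place keeps the same Parameter objects, so existing references to them (for example, those held by an optimizer) remain valid; the original values, however, are lost unless they were saved first.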

store(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) -> None

Save the current parameters for restoring later.

Parameters:

parameters – Iterable of torch.nn.Parameter; the parameters to be temporarily stored. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.

restore(parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None) -> None

Restore the parameters saved with the store method. This makes it possible to validate the model with the EMA parameters without affecting the original optimization process: call store before copy_to, and after validation (or model saving), call restore to bring back the former parameters.

Parameters:

parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored parameters. If None, the parameters with which this ExponentialMovingAverage was initialized will be used.
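With the class itself, the validation pattern described above would read ema.store(), then ema.copy_to(), then evaluation, then ema.restore(), where ema is an ExponentialMovingAverage instance. The self-contained sketch below mimics those three steps with hypothetical helpers (store_params, restore_params) on a toy model, so it runs without this module; the EMA weights are a made-up stand-in.

```python
import torch


def store_params(params):
    """Hypothetical helper mirroring store(): keep detached copies of the weights."""
    return [p.detach().clone() for p in params]


def restore_params(saved, params):
    """Hypothetical helper mirroring restore(): write the saved weights back."""
    with torch.no_grad():
        for s, p in zip(saved, params):
            p.copy_(s)


model = torch.nn.Linear(2, 1)
ema_shadow = [p.detach().clone() * 0.5 for p in model.parameters()]  # stand-in EMA weights

saved = store_params(model.parameters())   # 1. store(): remember training weights
with torch.no_grad():
    for s, p in zip(ema_shadow, model.parameters()):
        p.copy_(s)                         # 2. copy_to(): load EMA weights
    val_output = model(torch.ones(1, 2))   # validate with EMA weights
restore_params(saved, model.parameters())  # 3. restore(): back to training weights
```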

state_dict() -> dict

Returns the state of the ExponentialMovingAverage as a dict.

load_state_dict(state_dict: dict) -> None

Loads the ExponentialMovingAverage state.

Parameters:

state_dict (dict) – EMA state. Should be an object returned from a call to state_dict().
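Since state_dict returns a plain dict, the EMA state can be checkpointed next to the model's own state. The sketch below uses a hand-built stand-in dict; its key names (decay, num_updates, shadow_params, collected_params) are an assumption based on upstream pytorch_ema and are not confirmed for this module.

```python
import io

import torch

model = torch.nn.Linear(3, 1)
# Hypothetical stand-in for ema.state_dict(); key names are assumed, not confirmed.
ema_state = {
    "decay": 0.999,
    "num_updates": 10,
    "shadow_params": [p.detach().clone() for p in model.parameters()],
    "collected_params": [],
}

# Save model and EMA state in one checkpoint (in memory here; a file path works too).
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "ema": ema_state}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

model.load_state_dict(checkpoint["model"])
# With the real class, the EMA state would then be restored via:
# ema.load_state_dict(checkpoint["ema"])
```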