|
import torch |
|
from decoupled_utils import is_torch_xla_available |
|
def _deepspeed_zero3_disabled() -> bool:
    """Fallback used when DeepSpeed cannot (or should not) be imported: ZeRO-3 is never enabled."""
    return False


# Bind the fallback unconditionally so the name always exists. Previously, when XLA was
# available the `try` body imported nothing and raised nothing, leaving
# `is_deepspeed_zero3_enabled` undefined and making `EMAModel.step` fail with NameError.
is_deepspeed_zero3_enabled = _deepspeed_zero3_disabled

if not is_torch_xla_available():
    try:
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

        import deepspeed
    except Exception:
        # Best-effort probe: importing deepspeed can fail with errors other than ImportError,
        # so catch Exception (but not KeyboardInterrupt/SystemExit as the bare `except:` did).
        # Reset the name in case the transformers import succeeded before deepspeed failed.
        is_deepspeed_zero3_enabled = _deepspeed_zero3_disabled
|
|
|
class ExponentialMovingAverage:
    """
    WARNING: DEPRECATED

    Maintains (exponential) moving average of a set of parameters.

    Only parameters with ``requires_grad=True`` are tracked. ``update`` and ``copy_to``
    apply the same filter, so shadow tensors line up positionally with the trainable
    parameters.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`. Only parameters with `requires_grad=True`
                get a (detached) shadow copy.
            decay: The exponential decay, in [0, 1].
            use_num_updates: Whether to use number of updates when computing
                averages. When enabled, the effective decay after `n` updates is
                `min(decay, (1 + n) / (10 + n))`.

        Raises:
            ValueError: If `decay` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        # None disables the update-count based decay cap.
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        # Filled by `store`, consumed by `restore`.
        self.collected_params = []

    def move_shadow_params_to_device(self, device):
        """Move every shadow tensor to `device` (the list is rebuilt, not modified in place)."""
        self.shadow_params = [shadow.to(device) for shadow in self.shadow_params]

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up cap: small decay early on so the average tracks the weights quickly.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            trainable = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, trainable):
                # shadow <- decay * shadow + (1 - decay) * param
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. Only parameters with
                `requires_grad=True` are written.
        """
        trainable = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, trainable):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored (all of them, not only trainable ones).
        """
        # Detach so the saved copies are plain tensors rather than autograd-tracked clones.
        self.collected_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return the EMA state; `shadow_params` is the live list (not copied)."""
        return dict(decay=self.decay,
                    num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        """Restore state produced by `state_dict`. No validation is performed."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']
|
|
|
|
|
|
|
from diffusers.utils import ( |
|
is_transformers_available, |
|
) |
|
from typing import Iterable, Union, Optional |
|
import contextlib |
|
import transformers |
|
import copy |
|
|
|
|
|
class EMAModel:
    """
    Exponential Moving Average of models weights.

    Keeps one shadow tensor per tracked parameter, positionally aligned with the
    iterable passed to ``__init__``, and blends the live parameters into it on
    every ``step`` call.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track. Every parameter (trainable or
                frozen) gets a detached clone as its shadow copy, on the same device as the parameter.
            decay (float): The upper bound on the decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.

        Use `to()` to move the shadow parameters to another device or dtype.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        # CPU copies saved by `store()` and consumed (then cleared) by `restore()`.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        # Decay used by the most recent `step()` call.
        self.cur_decay_value = None
        self.foreach = foreach

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 while `optimization_step <= update_after_step + 1`, so such a step copies
        the parameters outright. Afterwards it uses either the warmup schedule
        `1 - (1 + step / inv_gamma) ** -power` or `(1 + step) / (10 + step)`. The result is
        capped at `decay` and then floored at `min_decay`.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # The floor is applied last, so it wins if `min_decay > decay`.
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """
        Advance the step counter and blend `parameters` into the shadow copies.

        Trainable parameters (`requires_grad=True`) are updated as
        `shadow -= (1 - decay) * (shadow - param)`. Frozen parameters are copied verbatim.
        When DeepSpeed ZeRO-3 is enabled, the parameters are gathered before being read.

        Args:
            parameters: The same parameters, in the same order, that this EMA was built from.
        """
        parameters = list(parameters)

        self.optimization_step += 1

        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        # Evaluated once instead of once per parameter.
        zero3_enabled = is_transformers_available() and is_deepspeed_zero3_enabled()

        if self.foreach:
            # Use context-manager *instances*: GatheredParameters(...) already is the context
            # manager, so calling it again (as the old `with context_manager():` did) fails.
            gather = (
                deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
                if zero3_enabled
                else contextlib.nullcontext()
            )
            with gather:
                pairs = list(zip(self.shadow_params, parameters))
                trainable_shadow = [s_param for s_param, param in pairs if param.requires_grad]
                trainable = [param for _, param in pairs if param.requires_grad]
                frozen_shadow = [s_param for s_param, param in pairs if not param.requires_grad]
                frozen = [param for _, param in pairs if not param.requires_grad]

                if frozen:
                    torch._foreach_copy_(frozen_shadow, frozen, non_blocking=True)

                # torch._foreach_* ops reject empty tensor lists; skip when nothing is trainable.
                if trainable:
                    torch._foreach_sub_(
                        trainable_shadow,
                        torch._foreach_sub(trainable_shadow, trainable),
                        alpha=one_minus_decay,
                    )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                gather = (
                    deepspeed.zero.GatheredParameters(param, modifier_rank=None)
                    if zero3_enabled
                    else contextlib.nullcontext()
                )
                with gather:
                    if param.requires_grad:
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the
                stored moving averages, in the same order as the tracked parameters. Each shadow
                tensor is moved to its target parameter's device before being copied.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """
        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers.
            non_blocking: like `non_blocking` argument to `torch.Tensor.to`
        """
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict. `shadow_params` is the live list, not a copy.
        """
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters for restoring later. Copies are detached and kept on CPU.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter], raise_error_if_already_restored: bool = True) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method.
        After validation (or model saving), use this to restore the former parameters. The stored copies are
        discarded afterwards.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
            raise_error_if_already_restored: If True, raise when nothing is stored; otherwise return silently.

        Raises:
            RuntimeError: If nothing is stored and `raise_error_if_already_restored` is True.
        """
        if self.temp_stored_params is None:
            if raise_error_if_already_restored:
                raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
            return
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Drop the copies so their memory is released and a second restore is detected.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to
        restore the ema state dict.

        Missing keys keep their current values. Every value is validated before any attribute is
        assigned, so an invalid `state_dict` leaves this object unchanged.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: If an entry has an invalid type or `decay` is outside [0, 1].
        """
        state_dict = copy.deepcopy(state_dict)

        decay = state_dict.get("decay", self.decay)
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        min_decay = state_dict.get("min_decay", self.min_decay)
        # Accept ints (e.g. `min_decay=0`) like the `inv_gamma`/`power` checks do.
        if not isinstance(min_decay, (float, int)):
            raise ValueError("Invalid min_decay")

        optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(optimization_step, int):
            raise ValueError("Invalid optimization_step")

        update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(update_after_step, int):
            raise ValueError("Invalid update_after_step")

        use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        power = state_dict.get("power", self.power)
        if not isinstance(power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            if not isinstance(shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in shadow_params):
                raise ValueError("shadow_params must all be Tensors")

        # All checks passed: commit.
        self.decay = decay
        self.min_decay = min_decay
        self.optimization_step = optimization_step
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        if shadow_params is not None:
            self.shadow_params = shadow_params
|
|