import logging
import torch
from torch.nn.utils import clip_grad_norm_
from mlbench_core.aggregation.pytorch.centralized import (
AVG_CUSTOM,
AVG_WORLD,
AllReduceAggregation,
AllReduceAggregationHVD,
)
logger = logging.getLogger("mlbench")
class DynamicLossScaler:
    """Dynamic loss scaler for mixed precision training.

    Divides the loss scale by `scale_factor` on overflow, and multiplies it
    by `scale_factor` every `scale_window` overflow-free iterations
    (optionally capped at `max_scale`).
    """
def __init__(
self, init_scale=2.0 ** 15, scale_factor=2.0, scale_window=2000, max_scale=None
):
self.loss_scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self._iter = 0
self._last_overflow_iter = -1
self.max_scale = max_scale
def update_scale(self, overflow):
if overflow:
self.loss_scale /= self.scale_factor
self._last_overflow_iter = self._iter
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.loss_scale *= self.scale_factor
if self.max_scale is not None:
self.loss_scale = min(self.loss_scale, self.max_scale)
self._iter += 1
@staticmethod
def has_overflow(grad_norm):
        # An overflow manifests as an inf or NaN gradient norm (NaN != NaN)
        return grad_norm == float("inf") or grad_norm != grad_norm
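# A minimal usage sketch of the scaler's backoff (hypothetical values):
#
#     scaler = DynamicLossScaler(init_scale=1024, scale_factor=2, scale_window=4)
#     scaler.update_scale(overflow=False)      # no change: scale stays 1024
#     scaler.update_scale(overflow=True)       # overflow: scale backs off to 512
#     for _ in range(4):
#         scaler.update_scale(overflow=False)  # 4 clean steps: scale grows to 1024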
class FP16Optimizer:
"""
Mixed precision optimizer with dynamic loss scaling and backoff.
https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#scalefactor
Args:
        fp16_model (:obj:`torch.nn.Module`): model (already cast to half precision)
world_size (int): Distributed world size
use_cuda (bool): Use cuda tensors for aggregation
use_horovod (bool): Use Horovod for aggregation
by_layer (bool): Aggregate by layer
grad_clip (float): coefficient for gradient clipping, max L2 norm of the gradients
init_scale (int): initial loss scale
        scale_factor (float): Factor for upscaling/downscaling the loss scale
        scale_window (int): Interval (in overflow-free iterations) between
            loss scale increases
        max_scale (Optional[float]): Upper bound for the loss scale (default: `None`)
        min_scale (float): Lower bound for the loss scale; backing off below it
            raises an exception (default: `1e-4`)
average_world (bool): Average the gradients by world size
average_custom (bool): Divide gradients by given denominator at each step, instead
of `world_size`
divide_before (bool): Divide gradients before reduction (default: False)
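
    Example (a minimal sketch; assumes an initialized distributed backend,
    and `model` and `world_size` defined by the caller)::

        model = model.half()
        fp_optimizer = FP16Optimizer(
            model, world_size, use_cuda=True, average_world=True
        )
        # The wrapped optimizer must update the flat fp32 master weights
        optimizer = torch.optim.SGD([fp_optimizer.fp32_params], lr=0.1)
        fp_optimizer.set_optimizer(optimizer)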
"""
def __init__(
self,
fp16_model,
world_size,
use_cuda=False,
use_horovod=False,
by_layer=False,
grad_clip=float("inf"),
init_scale=1024,
scale_factor=2,
scale_window=128,
max_scale=None,
min_scale=1e-4,
average_world=False,
average_custom=False,
divide_before=False,
):
self.use_cuda = use_cuda
self.fp16_model = fp16_model
self.fp32_params = self.initialize_flat_fp32_weight()
self.loss_scaler = DynamicLossScaler(
init_scale=init_scale,
scale_factor=scale_factor,
scale_window=scale_window,
max_scale=max_scale,
)
self.min_scale = min_scale
self.grad_clip = grad_clip
self.optimizer = None
self.world_size = world_size
if use_horovod:
self.agg = AllReduceAggregationHVD(
world_size=world_size, use_cuda=use_cuda, divide_before=divide_before
).agg_grad(by_layer=by_layer)
else:
self.agg = AllReduceAggregation(
world_size=world_size, use_cuda=use_cuda, divide_before=divide_before
).agg_grad(by_layer=by_layer)
if average_world:
self.agg_mode = AVG_WORLD
elif average_custom:
self.agg_mode = AVG_CUSTOM
else:
            raise NotImplementedError(
                "Only average_world and average_custom aggregation are supported."
            )
def set_optimizer(self, optimizer):
self.optimizer = optimizer
# Flattening master weight
    def initialize_flat_fp32_weight(self):
"""Initializes the model's parameters in fp32
Returns:
(:obj:`torch.Tensor`): The Parameters in fp32
"""
# Get all params that require gradient
params = [p for p in self.fp16_model.parameters() if p.requires_grad]
total_param_size = sum(p.data.numel() for p in params)
# Create new fp32 params
        fp32_params = torch.empty(
            total_param_size, dtype=torch.float32, device=params[0].device
        )
offset = 0
for p in params:
numel = p.data.numel()
fp32_params[offset : offset + numel].copy_(p.data.view(-1))
offset += numel
fp32_params = torch.nn.Parameter(fp32_params, requires_grad=True)
        # Pre-allocate the flat gradient buffer; it is overwritten on every step
        fp32_params.grad = torch.zeros_like(fp32_params.data)
return fp32_params
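
    # Layout sketch (illustrative): for parameters of shapes (2, 3) and (4,),
    # `fp32_params` is a flat tensor of 10 elements; elements 0-5 hold the
    # first parameter in row-major order and elements 6-9 the second, matching
    # the offsets accumulated above.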
@staticmethod
    def fp16_to_fp32_flat_grad(fp32_params, fp16_model):
"""Copies the parameters in `fp16_model` into `fp32_params` in-place
Args:
fp32_params (torch.Tensor): Parameters in fp32
fp16_model (torch.nn.Module): Model in fp16
"""
flat_grads = torch.cat(
[p.grad.data.view(-1) for p in fp16_model.parameters() if p.requires_grad]
)
fp32_params.grad = flat_grads.to(torch.float32)
@staticmethod
    def fp32_to_fp16_weights(fp16_model, fp32_params):
"""Copies the parameters in `fp32_params` into `fp16_model` in-place
Args:
fp16_model (torch.nn.Module): Model in fp16
fp32_params (torch.Tensor): Parameters in fp32
"""
with torch.no_grad():
pointer = 0
for p in fp16_model.parameters():
if not p.requires_grad:
continue
nelem = p.numel()
p.data.copy_(
fp32_params.data[pointer : pointer + nelem].view_as(p.data)
)
pointer += nelem
    def backward_loss(self, loss):
"""Scales and performs backward on the given loss
Args:
            loss (:obj:`torch.Tensor`): The loss tensor
"""
        # Scale the loss in place so fp16 gradients stay within representable range
        loss *= self.loss_scaler.loss_scale
loss.backward()
    def step(self, closure=None, tracker=None, multiplier=1, denom=None):
"""
Performs one step of the optimizer.
Applies loss scaling, computes gradients in fp16, converts gradients to
fp32, inverts scaling and applies optional gradient norm clipping.
If gradients are finite, it applies update to fp32 master weights and
copies updated parameters to fp16 model for the next iteration. If
gradients are not finite, it skips the batch and adjusts scaling factor
for the next iteration.
Args:
closure (callable, optional): A closure that reevaluates the model and returns the loss.
tracker (:obj:`mlbench_core.utils.Tracker`, optional) The current tracker
multiplier (float): Multiplier for gradient scaling. Gradient will be scaled using
                `scaled_grad = reduced_grad / (loss_scale * multiplier)`
denom (Optional[:obj:`torch.Tensor`]): Custom denominator to average by
                Use with `average_custom`. (default: `None`)
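
        Example (sketch of one iteration; `fp_optimizer`, `model`, `criterion`,
        `batch` and `target` are assumed from the surrounding training loop)::

            fp_optimizer.zero_grad()
            loss = criterion(model(batch), target)
            fp_optimizer.backward_loss(loss)
            updated = fp_optimizer.step()  # False if the batch was skipped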
"""
scaling_factor = self.loss_scaler.loss_scale * multiplier
if self.world_size > 1:
# Aggregate gradients
self.agg(self.fp16_model, self.agg_mode, denom=denom)
if tracker:
tracker.record_batch_agg()
# Cast fp16 grads to fp32 for optimizer
self.fp16_to_fp32_flat_grad(self.fp32_params, self.fp16_model)
# UnScale gradients
if scaling_factor != 1.0:
self.fp32_params.grad.data.div_(scaling_factor)
# Clip and compute gradient norm
norm = clip_grad_norm_([self.fp32_params], self.grad_clip)
updated = False
overflow = self.loss_scaler.has_overflow(norm)
self.loss_scaler.update_scale(overflow)
if not overflow:
self.optimizer.step(closure=closure)
self.fp32_to_fp16_weights(self.fp16_model, self.fp32_params)
updated = True
else:
if self.loss_scaler.loss_scale <= self.min_scale:
raise Exception(
"Minimum loss scale ({}) reached".format(self.min_scale)
)
logger.info(f"Skipped batch, new scale: {self.loss_scaler.loss_scale}")
if tracker:
tracker.record_batch_opt_step()
return updated
    def zero_grad(self):
"""Resets the gradients of the optimizer and fp16_model"""
self.optimizer.zero_grad()
self.fp16_model.zero_grad()