Source code for mlbench_core.optim.pytorch.decentralized

from torch.optim import SGD
from torch.optim.optimizer import required

from mlbench_core.aggregation.pytorch.decentralized import DecentralizedAggregation


class DecentralizedSGD(SGD):
    r"""Implements decentralized stochastic gradient descent (optionally with momentum).

    Args:
        rank (int): rank of the current process in the network
        neighbors (list): list of ranks of the neighbors of the current process
        model (:obj:`nn.Module`): model which contains the parameters for SGD
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        average_world (bool): Whether to average the model over the world size (default: `True`)
        use_cuda (bool): Whether to use cuda tensors for aggregation
        by_layer (bool): Aggregate by layer instead of all layers at once
    """

    def __init__(
        self,
        rank=None,
        neighbors=None,
        model=None,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        average_world=True,
        use_cuda=False,
        by_layer=False,
    ):
        # Rank 0 is a valid rank, so check against None instead of truthiness.
        if rank is None:
            raise ValueError('"rank" not set for optimizer')
        if not neighbors:
            raise ValueError('"neighbors" not set for optimizer')
        if model is None:
            raise ValueError('"model" not set for optimizer')

        super(DecentralizedSGD, self).__init__(
            model.parameters(), lr, momentum, dampening, weight_decay, nesterov
        )

        if average_world:
            self.agg_mode = "avg_world"
        else:
            raise NotImplementedError("Only averaging the model is supported right now.")

        self.model = model
        # Gossip-style aggregation of the model with the given neighbors.
        self.agg = DecentralizedAggregation(
            rank, neighbors, use_cuda=use_cuda
        ).agg_model(by_layer=by_layer)
    def step(self, closure=None, tracker=None):
        """Performs a single optimization step, then aggregates the model with the neighbors.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            tracker (:obj:`mlbench_core.utils.Tracker`, optional): The current tracker
        """
        loss = super(DecentralizedSGD, self).step(closure=closure)
        if tracker:
            tracker.record_batch_opt_step()

        # Average the model with the neighbors after the local gradient update.
        self.agg(self.model, self.agg_mode)

        if tracker:
            tracker.record_batch_agg()

        return loss
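

# A minimal usage sketch for illustration, not part of this module. It assumes a
# torch.distributed process group has already been initialized on every worker;
# the ring topology, model, and data below are hypothetical placeholders.
#
#     import torch
#     import torch.distributed as dist
#
#     from mlbench_core.optim.pytorch.decentralized import DecentralizedSGD
#
#     # Assumes dist.init_process_group(...) was called elsewhere.
#     rank = dist.get_rank()
#     world_size = dist.get_world_size()
#
#     # Hypothetical ring topology: each process averages with its two adjacent ranks.
#     neighbors = [(rank - 1) % world_size, (rank + 1) % world_size]
#
#     model = torch.nn.Linear(10, 1)  # placeholder model
#
#     optimizer = DecentralizedSGD(
#         rank=rank,
#         neighbors=neighbors,
#         model=model,
#         lr=0.1,
#         momentum=0.9,
#     )
#
#     # Standard training step: step() performs the local SGD update and then
#     # averages the model with the neighbors.
#     inputs, targets = torch.randn(4, 10), torch.randn(4, 1)
#     loss = torch.nn.functional.mse_loss(model(inputs), targets)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()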