lmflow.optim.lars#
Classes#
- LARS: Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.
Module Contents#
- class lmflow.optim.lars.LARS(params, lr: float = 0.01, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, trust_coefficient: float = 0.01, eps: float = 1e-08)[source]#
Bases:
torch.optim.optimizer.Optimizer
Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.

Note

The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion:

$$
\begin{aligned}
g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}
$$

where \(p\), \(g\), \(v\), \(\mu\) and \(\beta\) denote the parameters, gradient, velocity, momentum, and weight decay respectively. The \(\text{lars\_lr}\) is defined by Eq. 6 in the paper. The Nesterov version is analogously modified.
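The per-parameter scaling of Eq. 6 is not spelled out above. The following is a minimal sketch of how the trust ratio is typically computed, following the lightning-bolts reference implementation this class credits; the function name and structure are illustrative, not the exact lmflow code:

```python
import torch

def lars_trust_ratio(p: torch.Tensor, grad: torch.Tensor,
                     weight_decay: float, trust_coefficient: float = 0.01,
                     eps: float = 1e-8) -> float:
    # Illustrative sketch of Eq. 6: the local (layer-wise) LR multiplier is
    # trust_coefficient * ||p|| / (||g|| + weight_decay * ||p|| + eps).
    p_norm = torch.norm(p)
    g_norm = torch.norm(grad)
    if p_norm == 0 or g_norm == 0:
        # Degenerate tensors fall back to plain SGD scaling (multiplier of 1).
        return 1.0
    ratio = trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)
    return ratio.item()
```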
Warning
Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
- Note: Reference code: PyTorchLightning/lightning-bolts
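A minimal usage sketch, assuming a toy linear model; LARS inherits the standard torch.optim.Optimizer interface (zero_grad / step), so it drops in where SGD would be used:

```python
import torch
from lmflow.optim.lars import LARS

model = torch.nn.Linear(10, 2)
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
                 weight_decay=1e-4, trust_coefficient=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

inputs = torch.randn(8, 10)           # toy batch
targets = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()
optimizer.step()                      # applies the LARS-scaled SGD update
```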