lmflow.optim.lars#

Classes#

LARS

Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.

Module Contents#

class lmflow.optim.lars.LARS(params, lr: float = 0.01, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, trust_coefficient: float = 0.01, eps: float = 1e-08)[source]#

Bases: torch.optim.optimizer.Optimizer

Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.

Note

The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion.

.. math::
    \begin{aligned}
        g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
        v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
        p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
    \end{aligned}

where :math:`p`, :math:`g`, :math:`v`, :math:`\mu`, and :math:`\beta`
denote the parameters, gradient, velocity, momentum, and weight decay
respectively. The :math:`\text{lars_lr}` is defined by Eq. 6 in the paper.
The Nesterov version is analogously modified.
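A minimal sketch of the update above for a single parameter tensor, assuming the trust ratio of Eq. 6: the layer-wise `lars_lr` is computed from the parameter and gradient norms and then folded into the modified momentum update. The helper name `lars_update` and its arguments are illustrative and not part of the class API.

```python
import torch

def lars_update(p, grad, velocity, lr, momentum, weight_decay,
                trust_coefficient=0.01, eps=1e-8):
    """Illustrative per-parameter LARS-SGD update following the equations above."""
    p_norm = torch.norm(p)
    g_norm = torch.norm(grad)

    # Eq. 6: layer-wise trust ratio (lars_lr), applied only when both norms are non-zero.
    if p_norm != 0 and g_norm != 0:
        lars_lr = trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)
    else:
        lars_lr = 1.0

    # g_{t+1} = lars_lr * (beta * p_t + g_{t+1})  -- weight decay folded into the gradient
    grad = lars_lr * (weight_decay * p + grad)

    # v_{t+1} = mu * v_t + g_{t+1}
    velocity.mul_(momentum).add_(grad)

    # p_{t+1} = p_t - lr * v_{t+1}
    p.add_(velocity, alpha=-lr)
    return p, velocity
```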

Warning

Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
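For example, the parameter-group split below (placeholder model and hyperparameters) keeps weight decay, and hence layer-wise scaling, on matrix-shaped weights only, while biases and normalization parameters use weight_decay=0 and are excluded, as in SimCLR/BYOL setups.

```python
import torch.nn as nn
from lmflow.optim.lars import LARS

# Placeholder model; the parameter split below is what matters.
model = nn.Sequential(nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 10))

# 1-D parameters (biases, norm weights) get weight_decay=0 and are therefore
# excluded from layer-wise LR scaling; matrix-shaped weights keep weight decay.
decay = [p for p in model.parameters() if p.ndim > 1]
no_decay = [p for p in model.parameters() if p.ndim <= 1]

optimizer = LARS(
    [
        {"params": decay, "weight_decay": 1e-6},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1,
    momentum=0.9,
)
```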

Note:

Reference code: PyTorchLightning/lightning-bolts

__setstate__(state) → None[source]#
step(closure=None)[source]#

Performs a single optimization step.

Arguments:

closure (callable, optional): A closure that reevaluates the model and returns the loss.
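A minimal usage sketch, with a placeholder model, loss, and data: step() can be called after backward(), or given a closure that reevaluates the model and returns the loss.

```python
import torch
import torch.nn as nn
from lmflow.optim.lars import LARS

model = nn.Linear(16, 4)
criterion = nn.MSELoss()
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-6)

inputs, targets = torch.randn(8, 16), torch.randn(8, 4)

# Standard usage: backward() then step().
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()

# Alternatively, pass a closure that reevaluates the model and returns the loss.
def closure():
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

optimizer.step(closure)
```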