Source code for lmflow.optim.novograd

#!/usr/bin/env python

import torch
import torch.optim as optim


class NovoGrad(optim.Optimizer):
    def __init__(
        self,
        params,
        lr=0.01,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        grad_averaging=False,
        amsgrad=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            amsgrad=amsgrad,
        )
        super().__init__(params, defaults)
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("NovoGrad does not support sparse gradients")
                amsgrad = group["amsgrad"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # The second moment is a per-layer scalar, not a per-element tensor.
                    state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
                    if amsgrad:
                        state["max_exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Squared L2 norm of this parameter's gradient.
                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                # Normalize the gradient by the per-layer second moment,
                # then apply weight decay and optional gradient averaging.
                grad.div_(denom)
                if group["weight_decay"] != 0:
                    grad.add_(p.data, alpha=group["weight_decay"])
                if group["grad_averaging"]:
                    grad.mul_(1 - beta1)

                # First-moment (momentum) update and parameter step.
                exp_avg.mul_(beta1).add_(grad)
                p.data.add_(exp_avg, alpha=-group["lr"])
        return loss
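A minimal usage sketch: NovoGrad drops into a standard PyTorch training loop like any other `torch.optim` optimizer. The `model`, `criterion`, and synthetic data below are illustrative placeholders, and the hyperparameter values are assumptions rather than recommendations from this module.

#!/usr/bin/env python

import torch
import torch.nn as nn

from lmflow.optim.novograd import NovoGrad

# Toy model and data; stand-ins for a real training setup.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = NovoGrad(model.parameters(), lr=0.01, betas=(0.95, 0.98), weight_decay=0.001)

inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()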