# Source code for lmflow.optim.novograd
#!/usr/bin/env python
import torch
import torch.optim as optim
class NovoGrad(optim.Optimizer):
    """NovoGrad optimizer with a per-tensor (scalar) second moment.

    For every parameter tensor the optimizer keeps:

    * ``exp_avg`` -- a first-moment buffer with the same shape as the parameter;
    * ``exp_avg_sq`` -- a single 0-d tensor holding an exponential moving
      average of the squared L2 norm of that tensor's gradient;
    * ``max_exp_avg_sq`` -- (only when ``amsgrad``) the running maximum of
      ``exp_avg_sq``.

    Each step normalizes the gradient by ``sqrt(second_moment) + eps``, adds
    ``weight_decay * p`` to the normalized gradient, optionally scales it by
    ``1 - beta1`` (``grad_averaging``), accumulates it into ``exp_avg`` with
    momentum ``beta1``, and moves the parameter by ``-lr * exp_avg``.

    Args:
        params: Iterable of parameters or dicts defining parameter groups.
        lr: Learning rate. Must be >= 0.
        betas: ``(beta1, beta2)``: coefficients of the first-moment and
            second-moment running averages. Each must lie in ``[0, 1)``.
        eps: Added to the denominator for numerical stability. Must be >= 0.
        weight_decay: Coefficient of ``p`` added to the normalized gradient.
            Must be >= 0.
        grad_averaging: If True, scale the normalized (and decayed) gradient
            by ``1 - beta1`` before accumulating it into ``exp_avg``.
        amsgrad: If True, normalize by the running maximum of the second
            moment instead of its current value.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """

    def __init__(
        self, params, lr=0.01, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, grad_averaging=False, amsgrad=False
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``amsgrad`` to False for groups saved without it."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        The parameters' ``.grad`` tensors are read but never modified.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with autograd enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd for backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            amsgrad = group["amsgrad"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("NovoGrad does not support sparse gradients")

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    # The second moment is a single scalar per parameter tensor.
                    state["exp_avg_sq"] = torch.zeros([], device=p.device)
                    if amsgrad:
                        state["max_exp_avg_sq"] = torch.zeros([], device=p.device)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1

                # Squared L2 norm of the whole gradient tensor.
                grad_sq_norm = torch.sum(torch.pow(grad, 2))
                if exp_avg_sq == 0:
                    # First step (or every gradient so far was zero): seed the
                    # running average with the current value instead of decaying from 0.
                    exp_avg_sq.copy_(grad_sq_norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(grad_sq_norm, alpha=1 - beta2)

                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                    # Maintain the maximum of all 2nd-moment running averages so far
                    # and normalize by it.
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                # Out-of-place division: the caller's p.grad must survive the step.
                # `update` is a fresh tensor, so the following in-place ops are safe.
                update = grad.div(denom)
                if group["weight_decay"] != 0:
                    update.add_(p, alpha=group["weight_decay"])
                if group["grad_averaging"]:
                    update.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(update)
                p.add_(exp_avg, alpha=-group["lr"])
        return loss