Source code for lmflow.optim.radam

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math
import warnings
from typing import Tuple

import torch
from torch.optim.optimizer import Optimizer

class RAdam(Optimizer):
    r"""Implements the RAdam optimization algorithm.

    Note: Deprecated, please use the version provided natively by PyTorch.

    It has been proposed in `On the Variance of the Adaptive Learning Rate
    and Beyond`: https://arxiv.org/abs/1908.03265

    Note:
        Reference code: https://github.com/LiyuanLucasLiu/RAdam
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
    ) -> None:
        warnings.warn(
            "RAdam optimizer is deprecated, since it is included "
            "in PyTorch natively.",
            DeprecationWarning,
            stacklevel=2,
        )
        if lr <= 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )

        # Parameter groups that override the default betas get their own
        # rectification buffer, since the cached step size depends on betas.
        if (
            isinstance(params, (list, tuple))
            and len(params) > 0
            and isinstance(params[0], dict)
        ):
            for param in params:
                if "betas" in param and (
                    param["betas"][0] != betas[0]
                    or param["betas"][1] != betas[1]
                ):
                    param["buffer"] = [[None, None, None] for _ in range(10)]

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)],
        )
        super(RAdam, self).__init__(params, defaults)
    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)
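    # Reference for step() below, in the paper's notation: N_sma_max is the
    # maximum length of the approximated SMA,
    #     rho_inf = 2 / (1 - beta2) - 1,
    # and N_sma is its value at step t,
    #     rho_t = rho_inf - 2 * t * beta2**t / (1 - beta2**t).
    # When rho_t >= 5 (a conservative cutoff, since rho_t is approximate),
    # the adaptive step is scaled by the rectification term
    #     r_t = sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
    #                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)),
    # folded together with the usual Adam bias corrections; otherwise the
    # update falls back to an unrectified momentum step.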
    def step(self, closure=None):
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns
                the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    msg = (
                        "RAdam does not support sparse gradients, "
                        "please consider SparseAdam instead"
                    )
                    raise RuntimeError(msg)

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(
                        p_data_fp32, memory_format=torch.preserve_format
                    )
                    state["exp_avg_sq"] = torch.zeros_like(
                        p_data_fp32, memory_format=torch.preserve_format
                    )
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(
                        p_data_fp32
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                buffered = group["buffer"][int(state["step"] % 10)]

                if state["step"] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state["step"]
                    beta2_t = beta2 ** state["step"]
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (
                        1 - beta2_t
                    )
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = (
                            lr
                            * math.sqrt(
                                (1 - beta2_t)
                                * (N_sma - 4)
                                / (N_sma_max - 4)
                                * (N_sma - 2)
                                / N_sma
                                * N_sma_max
                                / (N_sma_max - 2)
                            )
                            / (1 - beta1 ** state["step"])
                        )
                    else:
                        step_size = lr / (1 - beta1 ** state["step"])
                    buffered[2] = step_size

                if weight_decay != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(eps)
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
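
# A minimal usage sketch, illustrative only: fitting a toy linear-regression
# model with this RAdam implementation. The model, data, and hyperparameters
# below are assumptions chosen for demonstration, not values from lmflow.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-2)

    # Synthetic data: y is a noisy linear function of x.
    x = torch.randn(64, 4)
    y = x @ torch.ones(4, 1) + 0.1 * torch.randn(64, 1)

    for _ in range(100):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()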