# Source code for lmflow.optim.adamw_schedule_free

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# 
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.optim
import math

class AdamWScheduleFree(torch.optim.Optimizer):
    r"""
    Schedule-Free AdamW
    As the name suggests, no scheduler is needed with this optimizer.
    To add warmup, rather than using a learning rate schedule you can just
    set the warmup_steps parameter.

    This optimizer requires that .train() and .eval() be called before the
    beginning of training and evaluation respectively. The optimizer should
    also be placed in eval mode when saving checkpoints.

    Per parameter, two tensors are kept in ``self.state[p]`` once the
    parameter first receives a gradient:

    * ``'z'`` -- the iterate that receives the raw (preconditioned) steps.
    * ``'exp_avg_sq'`` -- running average of squared gradients.

    In train mode ``p.data`` holds ``y = (1 - beta1) * z + beta1 * x``, the
    point at which gradients are evaluated. In eval mode it holds ``x``,
    the running weighted average of the ``z`` iterates. ``x`` is never
    stored explicitly; it is recovered from ``y`` and ``z`` by ``eval()``.

    Arguments:
        params (iterable): parameters or dicts defining parameter groups.
        lr (float): base learning rate (default 0.0025).
        betas (Tuple[float, float]): ``beta1`` interpolates between ``x`` and
            ``z``; ``beta2`` is the decay of the squared-gradient average
            (default (0.9, 0.999)).
        eps (float): term added to the denominator (default 1e-8).
        weight_decay (float): weight decay, evaluated at ``y`` (default 0).
        warmup_steps (int): number of linear warmup steps (default 0).
        r (float): exponent on the step count ``k + 1`` in the averaging
            weight (default 0.0).
        weight_lr_power (float): exponent on the running maximum learning
            rate in the averaging weight (default 2.0).
        foreach (bool): use the ``torch._foreach_*`` multi-tensor kernels
            (default: True when torch provides them).
    """

    def __init__(self, params,
                 lr=0.0025,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 warmup_steps=0,
                 r=0.0,
                 weight_lr_power=2.0,
                 foreach=hasattr(torch, "_foreach_mul_")
                 ):
        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        r=r,
                        k=0,
                        warmup_steps=warmup_steps,
                        train_mode=True,
                        weight_sum=0.0,
                        lr_max=-1.0,
                        weight_lr_power=weight_lr_power,
                        weight_decay=weight_decay,
                        foreach=foreach)
        super().__init__(params, defaults)

    def eval(self):
        """Switch every group to eval mode by rewriting ``p.data`` from y to x.

        Groups already in eval mode are left alone. Parameters that have no
        ``'z'`` state yet (never stepped) are not modified. Requires
        ``beta1 != 0``, since x is recovered by dividing by ``beta1``.
        """
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1, _ = group['betas']
            if train_mode:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # Set p.data to x: x = y / beta1 - (1 - beta1) / beta1 * z
                        p.data.lerp_(end=state['z'], weight=1 - 1 / beta1)
                group['train_mode'] = False

    def train(self):
        """Switch every group to train mode by rewriting ``p.data`` from x to y.

        Groups already in train mode are left alone. Parameters that have no
        ``'z'`` state yet are not modified.
        """
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1, _ = group['betas']
            if not train_mode:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # Set p.data to y: y = (1 - beta1) * z + beta1 * x
                        p.data.lerp_(end=state['z'], weight=1 - beta1)
                group['train_mode'] = True

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: if any parameter group is in eval mode. The check
                runs before the closure and before any optimizer state is
                modified, so the optimizer can be switched back with
                ``train()`` and stepped again safely.
        """
        # Validate first: failing midway would leave lr_max/weight_sum (and any
        # earlier param groups) already advanced, corrupting the averaging.
        if not all(group['train_mode'] for group in self.param_groups):
            raise RuntimeError("Not in train mode!")

        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd for backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            beta1, beta2 = group['betas']
            decay = group['weight_decay']
            k = group['k']
            r = group['r']
            warmup_steps = group['warmup_steps']
            weight_lr_power = group['weight_lr_power']

            # Linear warmup over the first warmup_steps steps, then constant.
            sched = (k + 1) / warmup_steps if k < warmup_steps else 1.0

            bias_correction2 = 1 - beta2 ** (k + 1)
            lr = group['lr'] * sched * math.sqrt(bias_correction2)

            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            weight = ((k + 1) ** r) * (lr_max ** weight_lr_power)
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            # Averaging coefficient for this step; 0 when every weight so far
            # is zero (e.g. lr == 0), matching the old ZeroDivisionError path.
            ckp1 = weight / weight_sum if weight_sum != 0 else 0

            active_p = [p for p in group['params'] if p.grad is not None]
            for p in active_p:
                state = self.state[p]
                if 'z' not in state:
                    state['z'] = torch.clone(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

            if group['foreach'] and active_p:
                self._update_foreach(active_p, lr, ckp1, beta1, beta2, eps, decay)
            else:
                self._update_per_param(active_p, lr, ckp1, beta1, beta2, eps, decay)

            group['k'] = k + 1

        return loss

    def _update_foreach(self, params, lr, ckp1, beta1, beta2, eps, decay):
        """Apply one update to all of ``params`` with the multi-tensor kernels."""
        y, grad, exp_avg_sq, z = zip(*[
            (p.data, p.grad, self.state[p]['exp_avg_sq'], self.state[p]['z'])
            for p in params
        ])

        # Update the running average of squared gradients.
        torch._foreach_mul_(exp_avg_sq, beta2)
        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
        denom = torch._foreach_sqrt(exp_avg_sq)
        torch._foreach_add_(denom, eps)

        # Normalize grad in-place for memory efficiency (this overwrites p.grad).
        torch._foreach_div_(grad, denom)

        # Weight decay calculated at y.
        if decay != 0:
            torch._foreach_add_(grad, y, alpha=decay)

        # These operations update y in-place, without computing x explicitly.
        torch._foreach_lerp_(y, z, weight=ckp1)
        torch._foreach_add_(y, grad, alpha=lr * (beta1 * (1 - ckp1) - 1))

        # z step.
        torch._foreach_sub_(z, grad, alpha=lr)

    def _update_per_param(self, params, lr, ckp1, beta1, beta2, eps, decay):
        """Apply one update to each tensor in ``params`` individually."""
        for p in params:
            y = p.data  # Notation to match theory
            grad = p.grad.data
            state = self.state[p]
            z = state['z']
            exp_avg_sq = state['exp_avg_sq']

            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            denom = exp_avg_sq.sqrt().add_(eps)

            # Reuse grad buffer for memory efficiency (this overwrites p.grad).
            grad_normalized = grad.div_(denom)

            # Weight decay calculated at y.
            if decay != 0:
                grad_normalized.add_(y, alpha=decay)

            # These operations update y in-place, without computing x explicitly.
            y.lerp_(end=z, weight=ckp1)
            y.add_(grad_normalized, alpha=lr * (beta1 * (1 - ckp1) - 1))

            # z step.
            z.sub_(grad_normalized, alpha=lr)