# Source code for lmflow.optim.sgd_schedule_free

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# 
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.optim

class SGDScheduleFree(torch.optim.Optimizer):
    r"""
    Schedule-Free SGD.

    As the name suggests, no scheduler is needed with this optimizer.
    To add warmup, rather than using a learning rate schedule you can just
    set the ``warmup_steps`` parameter.

    This optimizer requires that ``.train()`` and ``.eval()`` be called before
    the beginning of training and evaluation respectively. The optimizer
    should also be placed in eval mode when saving checkpoints.

    Bookkeeping: while in train mode each parameter holds the interpolated
    point ``y``; while in eval mode it holds the averaged point ``x``. The
    per-parameter state entry ``'z'`` holds the plain SGD iterate, with
    ``y = momentum * x + (1 - momentum) * z``.

    Arguments:
        params (iterable): Iterable of parameters to optimize or dicts
            defining parameter groups.
        lr (float): Learning rate; must be non-negative (default: 1.0).
        momentum (float): Interpolation coefficient between ``x`` and ``z``
            used to form ``y``; must lie strictly inside (0, 1)
            (default: 0.9).
        weight_decay (float): L2 penalty coefficient, added to the gradient
            evaluated at ``y``; must be non-negative (default: 0).
        warmup_steps (int): Number of initial steps over which the learning
            rate is ramped linearly, scaling ``lr`` by
            ``(k + 1) / warmup_steps`` at step ``k`` (default: 0).
        r (float): Exponent of the step count in the averaging weight
            ``(k + 1) ** r * lr_max ** weight_lr_power`` (default: 0.0).
        weight_lr_power (float): Exponent of the running maximum learning
            rate in the averaging weight (default: 2).
        foreach (bool): Use the multi-tensor ``torch._foreach_*`` kernels
            (default: True when this torch build exposes them).
    """

    def __init__(self,
                 params,
                 lr=1.0,
                 momentum=0.9,
                 weight_decay=0,
                 warmup_steps=0,
                 r=0.0,
                 weight_lr_power=2,
                 foreach=hasattr(torch, "_foreach_mul_"),
                 ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if momentum <= 0 or momentum >= 1:
            raise ValueError("Momentum must be between 0 and 1 exclusive: {}".format(momentum))

        defaults = dict(lr=lr,
                        momentum=momentum,
                        r=r,
                        k=0,  # number of steps this group has taken
                        warmup_steps=warmup_steps,
                        train_mode=True,
                        weight_sum=0.0,  # running sum of averaging weights
                        # Running max of the scheduled lr; since lr >= 0 is
                        # enforced above, the first step always replaces -1.0.
                        lr_max=-1.0,
                        weight_lr_power=weight_lr_power,
                        weight_decay=weight_decay,
                        foreach=foreach)
        super().__init__(params, defaults)

    def eval(self):
        """Switch parameters from the training point ``y`` to the averaged point ``x``.

        Call before evaluation and before saving checkpoints. Groups that are
        already in eval mode are skipped, so repeated calls are harmless.
        Parameters that have never been stepped (no ``'z'`` state) are left
        unchanged.
        """
        for group in self.param_groups:
            if not group['train_mode']:
                continue
            momentum = group['momentum']
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    # Set p.data to x = (y - (1 - momentum) * z) / momentum,
                    # expressed as a lerp with weight 1 - 1/momentum (< 0).
                    p.data.lerp_(end=state['z'], weight=1 - 1 / momentum)
            group['train_mode'] = False

    def train(self):
        """Switch parameters from the averaged point ``x`` back to the training point ``y``.

        Call before resuming training. Groups that are already in train mode
        are skipped, so repeated calls are harmless.
        """
        for group in self.param_groups:
            if group['train_mode']:
                continue
            momentum = group['momentum']
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    # Set p.data to y = momentum * x + (1 - momentum) * z.
                    p.data.lerp_(end=state['z'], weight=1 - momentum)
            group['train_mode'] = True

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It runs with gradient tracking enabled,
                even though the parameter update itself runs under
                ``torch.no_grad()``.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.

        Raises:
            RuntimeError: If any parameter group is in eval mode. This is
                checked before the closure runs and before any optimizer
                state is modified, so a failed call leaves the optimizer
                untouched.

        Note:
            With ``weight_decay != 0`` the decay term is added into
            ``p.grad`` in place, so after the call ``p.grad`` holds
            ``grad + weight_decay * y``.
        """
        # Validate before touching anything. Mutating lr_max / weight_sum
        # first would double-count this step's averaging weight if the caller
        # caught the error, called .train() and retried.
        if not all(group['train_mode'] for group in self.param_groups):
            raise RuntimeError(
                "Not in train mode! Call optimizer.train() before optimizer.step().")

        loss = None
        if closure is not None:
            # The closure usually runs forward + backward, which needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            k = group['k']
            warmup_steps = group['warmup_steps']

            # Linear warmup replaces a learning-rate scheduler.
            sched = (k + 1) / warmup_steps if k < warmup_steps else 1.0
            lr = group['lr'] * sched

            # Weight of this step's iterate in the running average x.
            lr_max = group['lr_max'] = max(lr, group['lr_max'])
            weight = ((k + 1) ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight
            try:
                ckp1 = weight / weight_sum
            except ZeroDivisionError:
                # Happens e.g. when lr == 0, which makes every weight 0.
                ckp1 = 0

            active_p = [p for p in group['params'] if p.grad is not None]

            # Lazily initialise z to the current parameter value.
            for p in active_p:
                if 'z' not in self.state[p]:
                    self.state[p]['z'] = torch.clone(p.data)

            # Both branches below update y in-place without computing x
            # explicitly. Lerping y toward z by ckp1 and then adding
            # lr * (momentum * (1 - ckp1) - 1) * grad yields
            #   y' = momentum * x' + (1 - momentum) * z',
            # where z' = z - lr * grad and x' = (1 - ckp1) * x + ckp1 * z'.
            if group['foreach'] and len(active_p) > 0:
                y, grad, z = zip(*[(p.data, p.grad, self.state[p]['z'])
                                   for p in active_p])

                # Apply weight decay (evaluated at y; modifies p.grad in place).
                if weight_decay != 0:
                    torch._foreach_add_(grad, y, alpha=weight_decay)

                torch._foreach_lerp_(y, z, weight=ckp1)
                torch._foreach_add_(y, grad, alpha=lr * (momentum * (1 - ckp1) - 1))

                # SGD step on z.
                torch._foreach_sub_(z, grad, alpha=lr)
            else:
                for p in active_p:
                    y = p.data  # Notation to match theory
                    grad = p.grad.data
                    z = self.state[p]['z']

                    # Apply weight decay (evaluated at y; modifies p.grad in place).
                    if weight_decay != 0:
                        grad.add_(y, alpha=weight_decay)

                    y.lerp_(end=z, weight=ckp1)
                    y.add_(grad, alpha=lr * (momentum * (1 - ckp1) - 1))

                    # SGD step on z.
                    z.sub_(grad, alpha=lr)

            group['k'] = k + 1
        return loss