automl
More in-place ops for the PyTorch Lion implementation
https://github.com/google/automl/blob/master/lion/lion_pytorch.py#L79: Now:
update = exp_avg * beta1 + grad * (1 - beta1)
p.add_(torch.sign(update), alpha=-group['lr'])
# Decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
Can be:
update = torch.lerp(grad, exp_avg, beta1)
p.add_(update.sign_(), alpha=-group['lr'])
# Decay the momentum running average coefficient
exp_avg.lerp_(grad, 1 - beta2)
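A quick sanity check of the equivalence, as a minimal sketch rather than a test from the repo. It relies on `torch.lerp(start, end, weight)` computing `start + weight * (end - start)`. The tensor shapes, seed, and hyperparameter values are made up for illustration, and float64 is used so that rounding differences between the two formulations stay far away from the sign boundary.

```python
import torch

torch.manual_seed(0)

# Hypothetical hyperparameters and tensors, chosen only for this check.
lr, beta1, beta2 = 1e-4, 0.9, 0.99
p = torch.randn(16, dtype=torch.float64)
grad = torch.randn(16, dtype=torch.float64)
exp_avg = torch.randn(16, dtype=torch.float64)

# Current formulation.
p_old, exp_avg_old = p.clone(), exp_avg.clone()
update_old = exp_avg_old * beta1 + grad * (1 - beta1)
p_old.add_(torch.sign(update_old), alpha=-lr)
exp_avg_old.mul_(beta2).add_(grad, alpha=1 - beta2)

# Proposed formulation using lerp and in-place sign.
p_new, exp_avg_new = p.clone(), exp_avg.clone()
update_new = torch.lerp(grad, exp_avg_new, beta1)
# Compare the interpolated update before sign_() overwrites it.
update_matches = torch.allclose(update_old, update_new)
p_new.add_(update_new.sign_(), alpha=-lr)
exp_avg_new.lerp_(grad, 1 - beta2)

params_match = torch.allclose(p_old, p_new)
momentum_matches = torch.allclose(exp_avg_old, exp_avg_new)
print(update_matches, params_match, momentum_matches)
```

The two versions agree up to floating-point rounding. An element of `update` that lands within rounding error of zero could in principle get a different sign, which is why the check compares the interpolated update before the sign is taken as well as the final parameters.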
cc @crazydonkey200