
Implement Grokfast into GradCache

Open ben-walczak opened this issue 1 year ago • 2 comments

I would like to implement the Grokfast algorithm (an exponentially weighted mean of past gradients added to the current gradients) together with GradCache. I've been able to use it without GradCache, but I'm not sure where it would hook in, as I'm still learning GradCache's underlying mechanisms. Any direction on how this might be done? I'm also curious whether this would be an appropriate feature for this library.
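For reference, using the parameter names from the Grokfast reference implementation (decay alpha, amplification lamb), the filter updates each parameter's gradient roughly as:

ema = alpha * ema + (1 - alpha) * grad
grad = grad + lamb * ema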

ben-walczak avatar Jul 09 '24 12:07 ben-walczak

> an exponentially weighted mean of past gradients added to the current gradients

shouldn't this be tracked in the optimizer (state) and applied as a gradient transformation?
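For example, something along these lines could live entirely outside GradCache and be applied right before optimizer.step(). This is only a rough sketch; GrokfastFilter and its interface are hypothetical, with alpha/lamb set to the usual Grokfast defaults:

import torch

class GrokfastFilter:
    # Hypothetical helper: keep the gradient EMA as state alongside the optimizer
    # and apply it as a gradient transformation before optimizer.step().
    def __init__(self, params, alpha: float = 0.99, lamb: float = 5.0):
        self.params = [p for p in params if p.requires_grad]
        self.alpha = alpha
        self.lamb = lamb
        self.ema = None  # lazily initialized from the first set of gradients

    @torch.no_grad()
    def apply_(self):
        # Call after the backward pass (or after the GradCache step) and before optimizer.step().
        if self.ema is None:
            self.ema = [p.grad.detach().clone() for p in self.params]
        for p, g in zip(self.params, self.ema):
            g.mul_(self.alpha).add_(p.grad, alpha=1 - self.alpha)  # update the EMA
            p.grad.add_(g, alpha=self.lamb)                        # amplify the slow component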

luyug avatar Jul 09 '24 18:07 luyug

I'm not entirely sure. Just spitballing, but I think it would be implemented somewhere within these lines of code: https://github.com/luyug/GradCache/blob/main/src/grad_cache/grad_cache.py#L193C9-L211

using logic similar to the following:

from typing import Dict, Optional

import torch
from torch import nn


def gradfilter_ema(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.99,
    lamb: float = 5.0,
) -> Dict[str, torch.Tensor]:
    # Initialize the EMA state from the current gradients on the first call.
    if grads is None:
        grads = {n: p.grad.data.detach() for n, p in m.named_parameters() if p.requires_grad}

    for n, p in m.named_parameters():
        if p.requires_grad:
            # Update the exponential moving average of past gradients.
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            # Amplify the slow-varying component by adding lamb * EMA to the current gradient.
            p.grad.data = p.grad.data + grads[n] * lamb

    return grads
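If that works, usage might look roughly like this in a training loop. Just a sketch: I'm assuming gc is a GradCache instance called as in the README, and that the filter runs on the accumulated gradients right before the optimizer step:

grads = None
for x, y in loader:
    optimizer.zero_grad()
    # GradCache runs the chunked forward/backward and populates param.grad
    loss = gc(x, y)
    # apply the Grokfast EMA filter to the accumulated gradients
    # (repeat per encoder if GradCache wraps more than one model)
    grads = gradfilter_ema(model, grads)
    optimizer.step()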

I'll test this out later when I get a chance

ben-walczak avatar Jul 15 '24 15:07 ben-walczak