
Torch multiple simultaneous gradient_checkpoint_scope

Open · albertz opened this issue 7 months ago · 0 comments

Only one saved_tensors_hooks is active at a time, namely the one of the most recent gradient_checkpoint_scope. So when there are multiple simultaneous gradient_checkpoint_scopes, the pack hooks of the earlier scopes are not used.

Example code:


import torch
# Import path assuming RETURNN's Torch util module.
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

var1 = torch.nn.Parameter(torch.randn(10))  # example parameters
var2 = torch.nn.Parameter(torch.randn(10))

def get_var1():
    with gradient_checkpoint_scope():
        return var1 + torch.randn_like(var1)

def get_var2():
    # Entered while the scope from get_var1 still needs its pack hook later.
    with gradient_checkpoint_scope():
        return var2 + torch.randn_like(var2)

x = get_var1() * get_var2()  # two gradient_checkpoint_scopes alive at the same time
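
For reference, the underlying PyTorch behavior can be seen with plain saved_tensors_hooks, independent of RETURNN: registrations form a stack, and only the innermost (most recently entered) pair of hooks is used to pack newly saved tensors. A minimal standalone snippet (illustration only, not part of the RETURNN code):

import torch
from torch.autograd.graph import saved_tensors_hooks

def make_hooks(tag):
    def pack(t):
        print("pack hook", tag)
        return t
    def unpack(t):
        return t
    return saved_tensors_hooks(pack, unpack)

a = torch.randn(3, requires_grad=True)
with make_hooks("outer"):
    with make_hooks("inner"):
        y = (a * a).sum()  # only the "inner" pack hook runs; "outer" is shadowed
y.backward()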

A solution would be to keep one global weak tensor key dictionary for the registered tensors of all gradient_checkpoint_scope instances, and to check that global dictionary in the pack hook instead of the per-scope one.
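
A rough sketch of that idea follows; the class and helper names are made up for illustration and are not the actual RETURNN internals:

import torch
from torch.utils.weak import WeakTensorKeyDictionary

# One dictionary shared by all scopes: registered tensor -> owning scope.
_registered_tensors_global = WeakTensorKeyDictionary()

class _GradCheckpointScopeSketch:
    """Hypothetical, simplified stand-in for gradient_checkpoint_scope."""

    def register_tensor(self, tensor: torch.Tensor) -> None:
        # Register in the shared global dict instead of a scope-local one.
        _registered_tensors_global[tensor] = self

    # pack_hook/unpack_hook would be installed via
    # torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook).

    def pack_hook(self, tensor: torch.Tensor):
        # Even if only the most recent scope's hooks are active, tensors
        # registered by any other live scope are still recognized here.
        scope = _registered_tensors_global.get(tensor)
        if scope is not None:
            # Placeholder: the real code would store what is needed to
            # recompute the tensor in backward instead of keeping it.
            return ("recompute", scope, tensor)
        return ("stored", None, tensor)

    def unpack_hook(self, packed):
        kind, scope, tensor = packed
        # Placeholder: the real code would recompute for the "recompute" case.
        return tensor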

This is currently maybe not so important, as it is a case we likely do not run into (yet, I guess).

albertz · Jul 15 '24, 13:07