returnn
Torch multiple simultaneous gradient_checkpoint_scope
There will only be one saved_tensors_hooks active, specifically the one for the most recent gradient_checkpoint_scope. So any of the earlier pack hooks will not be used when there are multiple simultaneous gradient_checkpoint_scopes.
Example code:

    import torch
    from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope  # RETURNN Torch util

    var1 = torch.nn.Parameter(torch.randn(10))
    var2 = torch.nn.Parameter(torch.randn(10))

    def get_var1():
        with gradient_checkpoint_scope():
            return var1 + torch.randn_like(var1)

    def get_var2():
        with gradient_checkpoint_scope():
            return var2 + torch.randn_like(var2)

    # Only the pack hook of the second (most recent) scope is active here.
    x = get_var1() * get_var2()
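
For background, this matches the documented behavior of torch.autograd.graph.saved_tensors_hooks: when several hook pairs are registered, only the inner-most (most recently entered) pair is applied to tensors saved for backward. A minimal standalone sketch, independent of RETURNN, with names only for the demo:

    import torch

    def make_hooks(name):
        def pack(tensor):
            print(f"pack hook of {name} called")  # shows which hook pair is used
            return tensor
        def unpack(tensor):
            return tensor
        return pack, unpack

    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)

    with torch.autograd.graph.saved_tensors_hooks(*make_hooks("outer")):
        with torch.autograd.graph.saved_tensors_hooks(*make_hooks("inner")):
            y = a * b  # a and b are saved for backward; only the "inner" pack hook is called

In the example above, by the time x = get_var1() * get_var2() saves its inputs for backward, only the second scope's pack hook is active, so the tensor coming from the first scope is not handled by its own scope.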
A solution is to keep a global weak tensor key dictionary for all registered tensors of any gradient_checkpoint_scope, and in the pack hook, check that global dictionary instead of the scope-local one.
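
A rough sketch of that idea, assuming torch.utils.weak.WeakTensorKeyDictionary as the global registry; the class and method names here are hypothetical and not the actual RETURNN implementation:

    import torch
    from torch.utils.weak import WeakTensorKeyDictionary

    # Hypothetical global registry shared by all scopes: tensor -> owning scope.
    _registered_tensors = WeakTensorKeyDictionary()

    class _GradCheckpointScopeSketch:
        def register_tensor(self, tensor):
            # Called for every tensor created under this scope.
            _registered_tensors[tensor] = self

        def pack_hook(self, tensor):
            # Look up the global registry, not only this scope's own tensors,
            # so tensors registered by earlier, still-alive scopes are also handled.
            owner = _registered_tensors.get(tensor)
            if owner is not None:
                return owner.make_recompute_placeholder(tensor)
            return tensor  # not from any scope: keep the saved tensor as usual

        def make_recompute_placeholder(self, tensor):
            # Stand-in for the real logic, which would store what is needed
            # to recompute the tensor in backward instead of keeping it alive.
            raise NotImplementedError

Since the dictionary only holds weak references to its tensor keys, such a global registry would not keep the registered tensors alive any longer than the scopes themselves do.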
It's currently maybe not so important, as this is a case we likely do not run into (yet, I guess).