Support `torch.Tensor.register_hook`
🚀 Feature
Motivation
In Lightning Fabric, we use this in one place to check that the user called backward correctly: https://github.com/Lightning-AI/pytorch-lightning/blob/096b063d6eeb41567409f4a6b9bac6f5af28ed93/src/lightning/fabric/wrappers.py#L232-L233 cc @awaelchli
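For readers unfamiliar with the pattern, here is a rough sketch of how a tensor hook can be used for this kind of check. `BackwardGuard` and its attribute names are hypothetical and made up for illustration; this is not Fabric's actual code, only plain PyTorch under eager mode.

```python
import torch


class BackwardGuard:
    """Hypothetical stand-in for the Fabric check; not Fabric's actual code."""

    def __init__(self):
        self.in_guarded_backward = False

    def hook(self, grad):
        # Runs when autograd computes the gradient for the hooked tensor.
        if not self.in_guarded_backward:
            raise RuntimeError("Use guard.backward(loss) instead of loss.backward()")
        return None  # returning None leaves the gradient unchanged

    def backward(self, tensor):
        # Set the flag only for the duration of the guarded backward call.
        self.in_guarded_backward = True
        try:
            tensor.backward()
        finally:
            self.in_guarded_backward = False


guard = BackwardGuard()
x = torch.tensor([1.0], requires_grad=True)
loss = (x * 2).sum()
loss.register_hook(guard.hook)
guard.backward(loss)  # the hook sees the flag and lets the backward pass through
```

Calling `loss.backward()` directly on a tensor hooked this way would raise from inside the hook, which is the kind of error checking the hook is there for.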
I don't expect us to run this hook properly on backward, but it would be useful to simply ignore it instead of failing, perhaps with a warning.
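To make the "ignore it, maybe with a warning" idea concrete, here is a minimal pure-Python sketch. `TensorProxySketch` and `_NoOpHandle` are hypothetical names invented for this example and do not reflect thunder's actual proxy or language-context classes. The no-op handle is there because `torch.Tensor.register_hook` returns a handle with a `remove()` method, so callers may expect one.

```python
import warnings


class _NoOpHandle:
    """Hypothetical placeholder for the handle that register_hook returns."""

    def remove(self):
        pass


class TensorProxySketch:
    """Hypothetical stand-in for a traced tensor; not thunder's proxy class."""

    _IGNORED_METHODS = frozenset({"register_hook"})

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name in self._IGNORED_METHODS:

            def _ignored(*args, **kwargs):
                warnings.warn(
                    f"{name} is not supported here and was ignored",
                    stacklevel=2,
                )
                return _NoOpHandle()

            return _ignored
        raise AttributeError(f"The torch language context has no method {name}")


proxy = TensorProxySketch()
handle = proxy.register_hook(lambda grad: grad)  # warns instead of raising
handle.remove()
```

Any other unknown method still raises `AttributeError`, so only the explicitly listed names are silently skipped.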
Pitch
import thunder
import torch

def hook(_):
    print("Hello")

def fn(x):
    y = x * 2
    y.register_hook(hook)
    return y

t = torch.tensor([1.0], requires_grad=True)
fn = thunder.jit(fn)
out = fn(t)
out.backward()
print(out)

Currently this fails with:

    y.register_hook(hook)
  File "/home/carmocca/git/lightning-thunder/thunder/core/interpreter.py", line 5862, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/core/interpreter.py", line 1243, in jit_wrapped
    res = ufn(*uargs, **ukwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/core/proxies.py", line 1210, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/carmocca/git/lightning-thunder/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/carmocca/git/lightning-thunder/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method register_hook
Additional context
Do you mean something along the lines of Lightning-AI/lit-thunder-LEGACY#1779?
Oh yes, perfect. I would be happy with just not erroring out, because otherwise we would need to comment this out in Fabric if we want to compile the forward and the loss together.
Thinking about this more, it's not so clear that it is a reasonable implementation, though: since we don't use the autograd engine for the JITed code, the backward hook calls would need to be generated for it.