lightning-thunder
Implement _set_grad_enabled of torch._C
triage team: looking to understand if this is high-effort or low-effort (honestly, looking for that on all NeMo things, this one's just particularly out of my depth).
triage review — implementing this operation is probably out of scope for thunder, and the model may need to be revised to target thunder
torch.autograd.grad_mode.set_grad_enabled is a PyTorch autograd concept and is out of scope for Thunder. One option for "supporting" it is to make Thunder's JIT interpreter ignore it and not raise an error (if there is an error today).
model may need to be revised to target thunder
I'm not sure "revise the model" is going to be a reasonable solution in the general case; we've already seen this pattern in GPT, CLIPEncoder, and a vision transformer from NeMo (a sketch of the pattern follows).
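For context, the pattern that trips this up is a forward that toggles grad mode inline. A minimal sketch of what that looks like (a hypothetical module for illustration only, not code from any of the models above):

import torch

class Encoder(torch.nn.Module):
    # Hypothetical module, only to illustrate the grad-mode pattern;
    # not taken from GPT, CLIPEncoder, or the NeMo vision transformer.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.freeze_proj = True

    def forward(self, x):
        # Entering this context calls torch._C._set_grad_enabled under the hood,
        # which is the operation Thunder's JIT interpreter trips over.
        with torch.set_grad_enabled(not self.freeze_proj):
            x = self.proj(x)
        return x + 1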
make Thunder's JIT interpreter ignore it and not raise an error (if there's an error today).
There is an error today. "Ignore it" seems like a good approach to at least make progress on identifying issues... will we still produce the same computation as torch if we do that?
"Ignore it" seems like a good approach to at least make progress on identifying issues... will we still produce the same computation as torch if we do that?
Using no_grad/enable_grad now works without an error on the main branch. We will still produce the same forward computation, but we will hold onto intermediates (and be able to compute gradients) if one of the inputs has requires_grad=True, even if no_grad was enabled. I think we should add a warning that these operations will be ignored when used with thunder.jit.
import torch
import thunder

def foo(x):
    # with torch.enable_grad():
    #     x = x + 1
    with torch.no_grad():
        return x + 1

jfoo = thunder.jit(foo)

x = torch.randn(3, requires_grad=True)
o = jfoo(x)
# no_grad is ignored by thunder.jit, so backward succeeds and x.grad is populated.
o.sum().backward()
print(x.grad)
# print(thunder.last_traces(jfoo))
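For comparison, eager PyTorch honors no_grad here: the output of foo does not require grad, so calling backward raises a RuntimeError instead of populating x.grad. A quick check against the foo defined above:

# Eager PyTorch for comparison: no_grad is honored, so backward fails.
x_eager = torch.randn(3, requires_grad=True)
o_eager = foo(x_eager)
print(o_eager.requires_grad)  # False: no graph was recorded under no_grad
try:
    o_eager.sum().backward()
except RuntimeError as err:
    # e.g. "element 0 of tensors does not require grad and does not have a grad_fn"
    print(err)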