lightning-thunder
`disable_torch_autograd_support` should consider `no_grad` and `inference_mode`
🚀 Feature
Per title
Motivation
I'm running a benchmark of my own decorated with `@torch.inference_mode()` (I also tried `@torch.no_grad()`).
`thunder.compile` fails with `NotImplementedError: VJP for PrimIDs.RECIPROCAL is not implemented`, even though the benchmark never runs a backward pass.
Pitch
```python
import thunder
import torch

model = thunder.jit(model, disable_torch_autograd=True)
# is equivalent to
with torch.inference_mode():
    model = thunder.jit(model)
# and
with torch.no_grad():
    model = thunder.jit(model)
```
This needs careful consideration when the user also passes a conflicting `disable_torch_autograd=...` value.
cc @carmocca @borda
I think we can capture whether we're in a no-grad context when compilation begins. However, we should document clearly that this state is captured only at compile time: if a practitioner later runs the compiled function under a different grad setting, the change will not be picked up.
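A minimal sketch of the compile-time capture and conflict handling discussed above, assuming the flag is resolved once when compilation begins. The helper names `resolve_disable_torch_autograd` and `grad_enabled_now` are hypothetical and not part of thunder's API. The policy of letting an explicit argument win, with a warning, is also an assumption, since the issue leaves the conflict rule open. It relies on `torch.is_grad_enabled()` returning `False` under both `torch.no_grad()` and `torch.inference_mode()`.

```python
import warnings
from typing import Optional


def resolve_disable_torch_autograd(explicit: Optional[bool], grad_enabled: bool) -> bool:
    """Hypothetical helper: decide whether to disable torch.autograd support.

    explicit: the value the user passed, or None if they passed nothing.
    grad_enabled: grad mode observed once, when compilation begins.
    """
    context_disables = not grad_enabled
    if explicit is None:
        # No explicit choice: follow the ambient no_grad/inference_mode context.
        return context_disables
    if explicit is False and context_disables:
        # Assumed policy: the explicit argument wins, but flag the conflict.
        warnings.warn(
            "disable_torch_autograd=False was passed inside a no_grad/inference_mode "
            "context; keeping autograd support as requested.",
            stacklevel=2,
        )
    return explicit


def grad_enabled_now() -> bool:
    # Imported lazily so the resolution logic above can be exercised without torch.
    import torch

    # False under both torch.no_grad() and torch.inference_mode().
    return torch.is_grad_enabled()
```

Under these assumptions, the value would be computed once, e.g. `resolve_disable_torch_autograd(None, grad_enabled_now())`, and handed to `thunder.jit`. Because it is read only at that moment, later calls made under a different grad mode would not change it.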