
`disable_torch_autograd_support` should consider `no_grad` and `inference_mode`

Open · carmocca opened this issue · 1 comment

🚀 Feature

Per title

Motivation

I'm running a benchmark of my own under `@torch.inference_mode()` (I also tried `@torch.no_grad`).

`thunder.compile` fails with `NotImplementedError: VJP for PrimIDs.RECIPROCAL is not implemented`, even though the benchmark never runs backward.

Pitch

```python
model = thunder.jit(model, disable_torch_autograd=True)

# is equivalent to

with torch.inference_mode():
    model = thunder.jit(model)

# and

with torch.no_grad():
    model = thunder.jit(model)
```

This needs careful consideration when the user also passes a conflicting `disable_torch_autograd=...`.

cc @carmocca @borda

carmocca, Sep 12 '23

I think we can capture whether we're in a no-grad context when compilation begins. We should be careful to document that we only capture it at compile time, however: if a practitioner later runs the compiled function under a different setting, it will not pick up the change.

mruberry, Sep 12 '23
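A minimal sketch of the "capture at compile time" idea discussed above, assuming nothing about thunder's internals. `capture_autograd_disabled` and `jit_sketch` are hypothetical names, and the conflict policy (an explicit `disable_torch_autograd=False` under a disabled grad mode wins, with a warning) is an assumption for illustration, not a decided behavior. It relies only on `torch.is_grad_enabled()` and `torch.is_inference_mode_enabled()`.

```python
import warnings

import torch


def capture_autograd_disabled(disable_torch_autograd=None):
    """Decide once, when "compilation" begins, whether autograd support is disabled.

    Hypothetical helper: `None` means the user did not pass the flag.
    """
    # inference_mode also turns grad mode off, so the first check already
    # covers it; the second check is kept to make the intent explicit.
    ambient_disabled = (not torch.is_grad_enabled()) or torch.is_inference_mode_enabled()
    if disable_torch_autograd is None:
        return ambient_disabled
    if not disable_torch_autograd and ambient_disabled:
        # Assumed policy: the explicit argument wins, but flag the conflict.
        warnings.warn(
            "disable_torch_autograd=False conflicts with the ambient no_grad/"
            "inference_mode context; using the explicit argument."
        )
    return bool(disable_torch_autograd)


def jit_sketch(fn, disable_torch_autograd=None):
    """Hypothetical stand-in for a compile entry point such as thunder.jit."""
    disabled = capture_autograd_disabled(disable_torch_autograd)

    def compiled(*args, **kwargs):
        # A real compiler would skip building the backward (VJP) trace when
        # `disabled` is True; this sketch only calls fn and keeps the decision.
        return fn(*args, **kwargs)

    # The decision is frozen here and is NOT re-read on later calls.
    compiled.autograd_disabled = disabled
    return compiled


if __name__ == "__main__":
    def f(x):
        return torch.reciprocal(x)

    with torch.inference_mode():
        g = jit_sketch(f)

    # Called outside inference_mode, g still uses the captured decision.
    print(g.autograd_disabled, g(torch.ones(2)))
```

Freezing the flag on the returned function mirrors the caveat in the comment: changing the grad mode after compilation has no effect on an already-compiled function.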