lightning-thunder

Implement _set_grad_enabled of torch._C

Open · athitten opened this issue 1 year ago · 2 comments

🚀 Feature

Implement _set_grad_enabled of torch._C

Motivation

NeMo Vision Transformer

cc @tfogal

athitten · May 01 '24 15:05

triage team: looking to understand if this is high-effort or low-effort (honestly, looking for that on all NeMo things, this one's just particularly out of my depth).

tfogal · May 01 '24 16:05

triage review — implementing this operation is probably out of scope for thunder, and the model may need to be revised to target thunder

mruberry · May 06 '24 19:05

torch.autograd.grad_mode.set_grad_enabled is a PyTorch Autograd concept and is out of scope for Thunder. One option for "supporting" it is to make Thunder's JIT interpreter ignore it and not raise an error (if there's an error today).
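
For context, here is a minimal eager-PyTorch sketch (plain torch APIs only, not Thunder behavior) of what this grad-mode flag controls:

import torch

x = torch.randn(3, requires_grad=True)

# torch.no_grad, torch.enable_grad, and torch.set_grad_enabled all toggle the
# same thread-local grad-mode flag (via the torch._C._set_grad_enabled binding).
with torch.set_grad_enabled(False):
    y = x + 1
print(y.requires_grad)  # False: no autograd graph was recorded inside the context

y = x + 1
print(y.requires_grad)  # True: grad mode is back on outside the context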

IvanYashchuk · Jun 05 '24 09:06

model may need to be revised to target thunder

I'm not sure "revise the model" is going to be a reasonable solution in the general case; we've already seen this pattern in GPT, CLIPEncoder, and a vision transformer from NeMo.

make Thunder's JIT interpreter ignore it and not raise an error (if there's an error today).

There is an error today. "Ignore it" seems like a good approach to at least make progress on identifying issues... will we still produce the same computation as torch if we do that?

tfogal · Jun 05 '24 17:06

"Ignore it" seems like a good approach to at least make progress on identifying issues... will we still produce the same computation as torch if we do that?

Using no_grad/enable_grad now works without an error on the main branch. We will still produce the same forward computation, but we will hold onto intermediates (and be able to compute gradients) if one of the inputs has requires_grad=True, even when a no_grad context was active. I think we should add a warning that these operations will be ignored when used with thunder.jit.

import torch
import thunder

def foo(x):
    # with torch.enable_grad():
    #     x = x + 1
    with torch.no_grad():
        return x + 1

jfoo = thunder.jit(foo)
x = torch.randn(3, requires_grad=True)
o = jfoo(x)
# thunder.jit ignores the no_grad context, so backward still works and
# x.grad is populated even though the output was computed under no_grad.
o.sum().backward()
print(x.grad)

# print(thunder.last_traces(jfoo))
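
For comparison, a rough sketch of what plain eager PyTorch (no thunder.jit) does with the same foo: the no_grad context suppresses graph recording, so backward fails instead of populating x.grad.

import torch

def foo(x):
    with torch.no_grad():
        return x + 1

x = torch.randn(3, requires_grad=True)
o = foo(x)              # eager execution honors no_grad
print(o.requires_grad)  # False: no grad_fn was recorded for o
# o.sum().backward() would raise a RuntimeError here, since o does not
# require grad and has no grad_fn.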

kshitij12345 · Jun 13 '24 08:06