If saved_for_backward returns NumberProxy, the value is taken from compile time, not runtime
## 🐛 Bug
When I add an operator that returns numbers, the values in saved_for_backward are the compile-time values defined in Symbol.meta, not the real values computed at runtime.
An example trace looks like this:
```python
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0")) # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a) # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4) # i6: "int 11"
    # i6 = prims.add(i4, 4) # i6: "int 11"
  del i4
  t7 = pack_rng_state_prim_impl(i3, i6) # t7: "cpu ui8[16]"
  del i3, i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  ######### i3, i4 are not passed to backward; instead the constant value 7 from the meta function is
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (7, 7))


@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4) # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2) # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)
```
## To Reproduce
The reproduction is based on the `uniform_rng` branch:
```python
import torch
import thunder


def func(a):
    b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
    return b * a


a = torch.randn(2, 2, device="cuda", requires_grad=True)
jfunc = thunder.jit(func)
out = jfunc(a)

print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])
```
This happens because of the line below. I'm not sure whether there is any impact from not baking in the compile-time value; the tests in `test_grad.py` seem to run fine after removing it:
https://github.com/Lightning-AI/lightning-thunder/blob/6cd19c4d7f23b635513cbc029c7eea6652708f65/thunder/core/transforms.py#L3705
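For reference, here is a minimal, self-contained sketch of the baking behavior I believe that line causes; `FakeNumberProxy` and `bake` below are illustrative stand-ins, not Thunder's real `NumberProxy` or tree-map code:

```python
from dataclasses import dataclass


@dataclass
class FakeNumberProxy:
    name: str   # the symbol produced at trace time, e.g. "i4"
    value: int  # the placeholder value from the meta function, e.g. 7


def bake(saved):
    # Replacing each proxy with its compile-time .value is what turns the
    # saved (i3, i4) into the literal (7, 7) in the generated trace.
    return tuple(x.value if isinstance(x, FakeNumberProxy) else x for x in saved)


saved_for_backward = (FakeNumberProxy("i3", 7), FakeNumberProxy("i4", 7))
print(bake(saved_for_backward))  # (7, 7) -- the runtime i3/i4 are never used
```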
I think another option could be to return a TupleProxy instead, since this function seems to return a tuple of numbers.
cc: @IvanYashchuk
Thank you @kshitij12345! Removing this line works for me.
Sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace.
> Sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace.
That's probably the reason why these numbers were made concrete. We can't solve this particular problem easily; a lot of parts of Thunder rely on concrete numbers.
Does this cause any problems in your work?
> Does this cause any problems in your work?
If we have an operator that produces NumberProxy results (unpack_rng_state in my case), and those results happen to be passed to the backward pass (like i3, i4 in the trace above), Thunder uses the NumberProxy value from the meta function (7 in the example), not the one computed at runtime (i3, i4). I hope I registered the operator the right way.
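To make the mismatch concrete, here is a hedged sketch under my assumptions; `fake_unpack_meta` / `fake_unpack_impl` are made-up names standing in for the meta function and the runtime implementation of an unpack_rng_state-style operator:

```python
def fake_unpack_meta(state):
    # At trace time there is no real RNG state, so the meta function can
    # only return stand-in integers -- this is where the 7s come from.
    return 7, 7


def fake_unpack_impl(state):
    # At runtime the real seed/offset are decoded from the state tensor.
    return int(state[0]), int(state[1])


seed_meta, offset_meta = fake_unpack_meta(None)     # (7, 7)
seed_rt, offset_rt = fake_unpack_impl([12345, 11])  # (12345, 11)

# If saved_for_backward bakes in the meta values, the backward pass re-runs
# uniform_philox with (7, 7) instead of (12345, 11), so it regenerates a
# different random tensor than the one the forward pass actually used.
print((seed_meta, offset_meta), (seed_rt, offset_rt))
```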
Triage review:
- We would like to pursue the general problem of being able to return a NumberProxy without a value here.
- @jjsjann123, this will probably be necessary for symbolic value support, if you're interested in taking a look.
Does this look about right?
```
root@c574d9980ec8:/volume# python thunder_issue_231.py
# Constructed by Delete Last Used (took 0 milliseconds)
import operator
import thunder.core.devices as devices
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0")) # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a) # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4) # i6: "int 4"
    # i6 = prims.add(i4, 4) # i6: "int 4"
  t7 = pack_rng_state_prim_impl(i3, i6) # t7: "cpu ui8[16]"
  del i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (i3, i4))

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4) # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2) # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)
```
I see (i3, i4) is being saved for backward.
This is with your branch + @kshitij12345's suggestion of removing the treemap + #250.
I happened to be playing with this recently, and my hack seems to have plumbed it through for you.
> Sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace.
I don't think my PR helps at all.... since it apparently runs fine without it. :laughing:
Wondering what other cases you are looking at?
> Wondering what other cases you are looking at?
When I remove the treemap line and run the dropout case, it fails with `NameError: name 'f7' is not defined`:
```python
def func(a):
    b = torch.nn.functional.dropout(a, p=0.5)
    return b * a
```
In the dropout case the NumberProxy is just a constant number (2.0), while in the uniform case it is i3, i4, numbers produced by an operator, so I tried to hack around it with https://github.com/Lightning-AI/lightning-thunder/pull/244. But we probably need a more general way to deal with it.
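For what it's worth, the failure mode is easy to reproduce in isolation; this is just a toy `exec` of a trace-like function whose body references a name (`f7`, mirroring the error above) that nothing in the trace ever binds:

```python
src = """
def backward_fn(saved_for_backward, cotangents):
    return (f7,)  # f7 is referenced but never produced by any statement
"""

namespace = {}
exec(src, namespace)

try:
    namespace["backward_fn"]((), ())
except NameError as e:
    print(e)  # name 'f7' is not defined
```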
Linking issue #403
@kiya00, I think this issue was closed automatically. Is it really fixed, or should it be reopened?
ah... forgot that we had this one open already. I'm linking it to #541 and I'll verify this when I close the other.