
If saved_for_backward returns NumberProxy, the value is taken from compile time, not runtime

Open · kiya00 opened this issue · 11 comments

🐛 Bug

When I add an operator that returns numbers, the values in saved_for_backward are the compile-time values defined in Symbol.meta, not the real values computed at runtime.

An example trace looks like this:

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4)  # i6: "int 11"
    # i6 = prims.add(i4, 4)  # i6: "int 11"
  del i4
  t7 = pack_rng_state_prim_impl(i3, i6)  # t7: "cpu ui8[16]"
  del i3, i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  ######### i3, i4 are not passed to backward; instead the constant value 7 from the meta function is saved
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (7, 7))  

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)
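
To make the symptom concrete outside of Thunder, here is a toy sketch in plain Python (the function names and bodies are invented for illustration; this is not Thunder's tracing or registration machinery) of how a meta function's placeholder numbers can end up in saved_for_backward instead of the runtime values:

# Toy illustration only -- not Thunder code; names are hypothetical.

def unpack_rng_state_meta(state):
    # Trace time: the real seed/offset are unknown here, so the meta can only
    # return placeholder numbers (7, matching the trace above).
    return 7, 7

def unpack_rng_state_impl(state):
    # Runtime: the actual (seed, offset) would be read from the packed state.
    seed, offset = state
    return seed, offset

# During tracing only the meta runs ...
seed_placeholder, offset_placeholder = unpack_rng_state_meta(None)

# ... so if the trace records these *values* rather than the proxy names
# (i3, i4) that refer to the runtime results, the backward pass receives the
# compile-time constants:
saved_for_backward = ((), (seed_placeholder, offset_placeholder))
print(saved_for_backward)  # ((), (7, 7)) -- matches the return statement above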

To Reproduce

The reproduction is based on the uniform_rng branch:

import torch
import thunder

def func(a):
    b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
    return b*a

a = torch.randn(2, 2, device="cuda", requires_grad=True)

jfunc = thunder.jit(func)
out = jfunc(a)

print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])

kiya00 · Apr 18 '24 15:04

It happens due to the line below. I'm not sure whether there is any impact of not baking in the compile-time value; the tests in test_grad.py seem to run fine after removing this line.

https://github.com/Lightning-AI/lightning-thunder/blob/6cd19c4d7f23b635513cbc029c7eea6652708f65/thunder/core/transforms.py#L3705
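
Roughly, the behaviour that line produces looks like the following sketch (a paraphrase of the description above, not the actual source in transforms.py; the class and helper names are made up):

# Paraphrased sketch of the "baking in" step -- not the actual Thunder source.

class FakeNumberProxy:
    """Stand-in for a NumberProxy in this sketch."""
    def __init__(self, name, value):
        self.name = name      # e.g. "i3"
        self.value = value    # the trace-time value, e.g. the placeholder 7

def bake_numbers(saved_for_backward):
    # tree_map-like pass (flattened to a tuple here for brevity): every
    # number proxy is replaced by its trace-time Python value.
    return tuple(
        x.value if isinstance(x, FakeNumberProxy) else x
        for x in saved_for_backward
    )

saved = (FakeNumberProxy("i3", 7), FakeNumberProxy("i4", 7))
print(bake_numbers(saved))           # (7, 7)  -> compile-time constants reach backward
print(tuple(p.name for p in saved))  # ('i3', 'i4') -> the names the backward trace
                                     # would need in order to see runtime values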

I think another option could be to return a TupleProxy instead, since this function seems to be returning a tuple of numbers.

cc: @IvanYashchuk

kshitij12345 · Apr 19 '24 07:04

Thank you @kshitij12345! Removing this line works for me.

kiya00 · Apr 19 '24 08:04

sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace

kiya00 · Apr 19 '24 08:04

sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace

That's probably the reason why these numbers were made concrete. We can't solve this particular problem easily; there are a lot of parts of Thunder that rely on concrete numbers.

IvanYashchuk · Apr 22 '24 09:04

Does this cause any problems in your work?

IvanYashchuk · Apr 22 '24 09:04

Does this cause any problems in your work?

If we have an operator that produces NumberProxy results (unpack_rng_state in my case), and these results happen to be passed to the backward pass (like i3, i4 in the trace above), Thunder uses the NumberProxy values from the meta function (7 in the example), not the ones computed at runtime (i3, i4). I hope I registered the operator the right way.
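
For reference, a registration along these lines would set up the situation described (a sketch based on thunder's OperatorExecutor extension mechanism; the executor name, the meta/impl bodies, and the exact keyword arguments are assumptions rather than the code on the uniform_rng branch):

# Hypothetical sketch -- not the code from the uniform_rng branch; the
# thunder.extend usage below is assumed rather than verified.
import torch
from thunder.extend import OperatorExecutor, register_executor

rng_ex = OperatorExecutor("rng_ex", version="0.1")  # hypothetical executor
register_executor(rng_ex)

def unpack_rng_state_meta(state):
    # Trace time: only placeholder numbers are available; whatever is returned
    # here becomes the NumberProxy's compile-time value (the 7 in the trace).
    return 7, 7

def unpack_rng_state_impl(state: torch.Tensor):
    # Runtime: split the packed ui8[16] RNG state into (seed, offset); the
    # exact packing is an assumption for this sketch.
    seed, offset = state.view(torch.int64).tolist()
    return seed, offset

unpack_rng_state = rng_ex.register_operator(
    "unpack_rng_state", meta=unpack_rng_state_meta, fn=unpack_rng_state_impl
)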

kiya00 · Apr 22 '24 09:04

triage review:

  • we would like to pursue the general problem of being able to return a NumberProxy without a value here
  • @jjsjann123 , this will probably be necessary for symbolic value support, if you're interested in taking a look

mruberry · Apr 22 '24 19:04

Does this look about right?

root@c574d9980ec8:/volume# python thunder_issue_231.py
# Constructed by Delete Last Used (took 0 milliseconds)
import operator
import thunder.core.devices as devices
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4)  # i6: "int 4"
    # i6 = prims.add(i4, 4)  # i6: "int 4"
  t7 = pack_rng_state_prim_impl(i3, i6)  # t7: "cpu ui8[16]"
  del i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (i3, i4))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)

I see (i3, i4) is being saved for backward.

This is with your branch + @kshitij12345's suggestion of removing the treemap + #250.

I happened to be playing with this recently, and my hack seems to have plumbed it through for you.

jjsjann123 · Apr 22 '24 21:04

sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace

I don't think my PR helps at all.... since it apparently runs fine without it. :laughing:

Wondering which other cases you are looking at?

jjsjann123 · Apr 22 '24 21:04

Wondering which other cases you are looking at?

When I remove the treemap line and run the dropout case, it fails with NameError: name 'f7' is not defined.

# The repro above, with func swapped for a dropout-based version:
import torch

def func(a):
    b = torch.nn.functional.dropout(a, p=0.5)
    return b * a

In the dropout case the NumberProxy is just a constant number (2.0), while in the uniform case it is i3, i4, numbers produced by an operator, so I tried to hack around it with https://github.com/Lightning-AI/lightning-thunder/pull/244. But we probably need a more general way to deal with it.
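
For what it's worth, the 2.0 in the dropout case is presumably just the inverted-dropout rescale factor, which is fully known at trace time (an inference from the value, not verified against the trace):

# Assumption: the constant 2.0 is dropout's 1/(1-p) rescale factor.
p = 0.5
scale = 1.0 / (1.0 - p)  # kept elements are rescaled by 1/(1-p)
print(scale)             # 2.0 -- computable during tracing, unlike i3/i4,
                         # which only exist at runtime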

kiya00 · Apr 23 '24 07:04

Linking issue #403

jjsjann123 · May 15 '24 22:05

@kiya00, I think this issue was closed automatically. Is it really fixed, or should it be reopened?

IvanYashchuk · May 29 '24 16:05

Ah... I forgot that we had this one open already. I'm linking it to #541 and I'll verify this when I close the other.

jjsjann123 · Jun 11 '24 16:06