
Proxy renaming in general jit sometimes is skipped

Open · IvanYashchuk opened this issue 1 year ago · 3 comments

🐛 Bug

Proxy renaming in the initial trace sometimes doesn't work. Let's check what the initial trace looks like for the following example (taken from test_core.py::test_cse):

import thunder
import torch

from thunder import clang

def func(x, y, device):
    # Redundant computations that common subexpression elimination should collapse
    a = x * y
    b = y / x
    c = x * y
    d = y / x
    z = a * b
    w = c * d
    m = w * 1
    a = clang.uniform(w.shape, device=device, dtype=thunder.float16)
    return z, w, m, a

x = torch.randn(3, 4, device='cuda:0')
y = torch.randn(3, 4, device='cuda:0')

# Trace the function, then turn the trace back into a plain Python callable
trace = thunder.trace()(func, x, y, 'cuda:0')
print(trace)
func = trace.python_callable()

# Re-jit the traced callable; its variable names should be preserved
jfunc = thunder.jit(func, executors=["torch"])
out = jfunc(x, y, device='cuda:0')
print(thunder.last_traces(jfunc)[0])

The initial trace in thunder.jit is

def computation(x, y):
  # x: "cuda:0 f32[3, 4]"
  # y: "cuda:0 f32[3, 4]"

  # thunder.func_39:15: 	  t0 = ltorch.mul(x, y)  # t0: "cuda:0 f32[3, 4]"
  t0 = ltorch.mul(x, y)  # t0: "cuda:0 f32[3, 4]"
    # t0 = prims.mul(x, y)  # t0: "cuda:0 f32[3, 4]"

  # thunder.func_39:16: 	  t1 = ltorch.true_divide(y, x)  # t1: "cuda:0 f32[3, 4]"
  t1 = ltorch.true_divide(y, x)  # t1: "cuda:0 f32[3, 4]"
    # t1 = prims.div(y, x)  # t1: "cuda:0 f32[3, 4]"

  # thunder.func_39:17: 	  t2 = ltorch.mul(x, y)  # t2: "cuda:0 f32[3, 4]"
  t2 = ltorch.mul(x, y)  # t2: "cuda:0 f32[3, 4]"
    # t2 = prims.mul(x, y)  # t2: "cuda:0 f32[3, 4]"

  # thunder.func_39:18: 	  t3 = ltorch.true_divide(y, x)  # t3: "cuda:0 f32[3, 4]"
  t3 = ltorch.true_divide(y, x)  # t3: "cuda:0 f32[3, 4]"
    # t3 = prims.div(y, x)  # t3: "cuda:0 f32[3, 4]"

  # thunder.func_39:19: 	  t4 = ltorch.mul(t0, t1)  # t4: "cuda:0 f32[3, 4]"
  t4 = ltorch.mul(t0, t1)  # t4: "cuda:0 f32[3, 4]"
    # t4 = prims.mul(t0, t1)  # t4: "cuda:0 f32[3, 4]"

  # thunder.func_39:20: 	  t5 = ltorch.mul(t2, t3)  # t5: "cuda:0 f32[3, 4]"
  t5 = ltorch.mul(t2, t3)  # t5: "cuda:0 f32[3, 4]"
    # t5 = prims.mul(t2, t3)  # t5: "cuda:0 f32[3, 4]"

  # thunder.func_39:21: 	  t6 = ltorch.mul(t5, 1)  # t6: "cuda:0 f32[3, 4]"
  t6 = ltorch.mul(t5, 1)  # t6: "cuda:0 f32[3, 4]"
    # _ = prims.convert_element_type(1, float)
    # t6 = prims.mul(t5, 1.0)  # t6: "cuda:0 f32[3, 4]"

  # thunder.func_39:22: 	  t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float16)  # t7: "cuda:0 f16[3, 4]"
  t14 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float16)  # t14: "cuda:0 f16[3, 4]"

  # /home/iyashchuk/dev/pytorch/main/torch/autograd/grad_mode.py:186: 	        torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t14)

Why is t14 not renamed to t7 while all the other variables are renamed? The renaming happens at https://github.com/Lightning-AI/lightning-thunder/blob/9f6e5b14e7a0fc6c96cca254540666d899df60b2/thunder/core/jit_ext.py#L1822-L1823
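
For context, the renaming pass conceptually maps auto-generated proxy names back to the names used in the traced source. A minimal sketch of that idea (rename_proxies below is a hypothetical helper, not the code at the link above): a rename is applied only when the target name is still free, so a collision silently leaves the generic name in place.

# Hypothetical sketch of collision-aware renaming; not the actual
# implementation in thunder/core/jit_ext.py.
def rename_proxies(proposed: dict[str, str], taken: set[str]) -> dict[str, str]:
    """Map each auto-generated name to its user-facing name, unless
    that user-facing name is already taken (then keep the generic one)."""
    result = {}
    for generic, desired in proposed.items():
        if desired not in taken:
            taken.add(desired)
            result[generic] = desired
        else:
            # Collision: the rename is skipped and the generic
            # name (e.g. "t14") survives into the printed trace.
            result[generic] = generic
    return result

# If "t7" was already claimed by something else, t14 keeps its name:
print(rename_proxies({"t14": "t7"}, taken={"t7"}))  # {'t14': 't14'}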

IvanYashchuk avatar Aug 09 '24 09:08 IvanYashchuk

The goal of the renaming has been to use more user variable names, not to get consecutive generic names. Is that what you want? (Also, with any luck, #954 will change the trace above.)

t-vi avatar Aug 12 '24 12:08 t-vi

I understand the goal, but in the example above the renaming doesn't happen when it should: the trace object returned by thunder.trace() plays the role of a user script, so its variable names should be preserved.

IvanYashchuk avatar Aug 12 '24 13:08 IvanYashchuk

I think what often happens is that it tries to rename something to an internal name ("tos" etc.) and then finds that the name is already taken for whatever reason (a sketch of this suspected mechanism follows the trace below). With #954 (which is not the end of the story), you'll get t7:

@no_autocast
def computation(x, y):
  # x: "cpu f32[3, 4]"
  # y: "cpu f32[3, 4]"

  # thunder.func_0:15: 	  t0 = ltorch.mul(x, y)  # t0: "cpu f32[3, 4]"
  t0 = ltorch.mul(x, y)  # t0: "cpu f32[3, 4]"
    # t0 = prims.mul(x, y)  # t0: "cpu f32[3, 4]"

  # thunder.func_0:16: 	  t1 = ltorch.true_divide(y, x)  # t1: "cpu f32[3, 4]"
  t1 = ltorch.true_divide(y, x)  # t1: "cpu f32[3, 4]"
    # t1 = prims.div(y, x)  # t1: "cpu f32[3, 4]"

  # thunder.func_0:17: 	  t2 = ltorch.mul(x, y)  # t2: "cpu f32[3, 4]"
  t2 = ltorch.mul(x, y)  # t2: "cpu f32[3, 4]"
    # t2 = prims.mul(x, y)  # t2: "cpu f32[3, 4]"

  # thunder.func_0:18: 	  t3 = ltorch.true_divide(y, x)  # t3: "cpu f32[3, 4]"
  t3 = ltorch.true_divide(y, x)  # t3: "cpu f32[3, 4]"
    # t3 = prims.div(y, x)  # t3: "cpu f32[3, 4]"

  # thunder.func_0:19: 	  t4 = ltorch.mul(t0, t1)  # t4: "cpu f32[3, 4]"
  t4 = ltorch.mul(t0, t1)  # t4: "cpu f32[3, 4]"
    # t4 = prims.mul(t0, t1)  # t4: "cpu f32[3, 4]"

  # thunder.func_0:20: 	  t5 = ltorch.mul(t2, t3)  # t5: "cpu f32[3, 4]"
  t5 = ltorch.mul(t2, t3)  # t5: "cpu f32[3, 4]"
    # t5 = prims.mul(t2, t3)  # t5: "cpu f32[3, 4]"

  # thunder.func_0:21: 	  t6 = ltorch.mul(t5, 1)  # t6: "cpu f32[3, 4]"
  t6 = ltorch.mul(t5, 1)  # t6: "cpu f32[3, 4]"
    # _ = prims.convert_element_type(1, float)
    # t6 = prims.mul(t5, 1.0)  # t6: "cpu f32[3, 4]"

  # thunder.func_0:22: 	  t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16)  # t7: "cpu f16[3, 4]"
  t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16)  # t7: "cpu f16[3, 4]"

  # /usr/local/lib/python3.12/dist-packages/torch/autograd/grad_mode.py:186: 	        torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t7)
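
To make that concrete, here is a minimal sketch of the suspected mechanism (fresh_name and try_rename are hypothetical helpers, not the interpreter's actual bookkeeping): every rename that collides with an already-taken name falls back to a fresh counter-based name, so the counter runs ahead of the numbering a reader would expect.

# Hypothetical sketch; not thunder's real renaming logic.
import itertools

_counter = itertools.count()
_taken: set[str] = set()

def fresh_name() -> str:
    # Next unused counter-based name: t0, t1, t2, ...
    while (name := f"t{next(_counter)}") in _taken:
        pass
    _taken.add(name)
    return name

def try_rename(desired: str) -> str:
    # Prefer the desired name; on a collision, fall back to a
    # fresh counter-based name, advancing the counter.
    if desired not in _taken:
        _taken.add(desired)
        return desired
    return fresh_name()

# t0 .. t6 are handed out normally, as in the trace above:
for _ in range(7):
    fresh_name()

# Seven failed renames to an already-taken internal name
# (e.g. "tos") silently consume t7 .. t13:
_taken.add("tos")
for _ in range(7):
    try_rename("tos")

# The next proxy then gets "t14":
print(fresh_name())  # -> t14

Seven such collisions would account for t7 through t13 being skipped in the trace at the top of this issue.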

t-vi avatar Aug 12 '24 13:08 t-vi