Proxy renaming in the general jit is sometimes skipped
🐛 Bug
Proxy renaming in the initial trace sometimes doesn't work. Let's look at what the initial trace looks like for the following example (taken from test_core.py::test_cse):
import thunder
import torch
from thunder import clang

def func(x, y, device):
    a = x * y
    b = y / x
    c = x * y
    d = y / x
    z = a * b
    w = c * d
    m = w * 1
    a = clang.uniform(w.shape, device=device, dtype=thunder.float16)
    return z, w, m, a

x = torch.randn(3, 4, device='cuda:0')
y = torch.randn(3, 4, device='cuda:0')
trace = thunder.trace()(func, x, y, 'cuda:0')
print(trace)
func = trace.python_callable()
jfunc = thunder.jit(func, executors=["torch"])
out = jfunc(x, y, device='cuda:0')
print(thunder.last_traces(jfunc)[0])
The initial trace in thunder.jit is:
def computation(x, y):
  # x: "cuda:0 f32[3, 4]"
  # y: "cuda:0 f32[3, 4]"
  # thunder.func_39:15: t0 = ltorch.mul(x, y) # t0: "cuda:0 f32[3, 4]"
  t0 = ltorch.mul(x, y) # t0: "cuda:0 f32[3, 4]"
  # t0 = prims.mul(x, y) # t0: "cuda:0 f32[3, 4]"
  # thunder.func_39:16: t1 = ltorch.true_divide(y, x) # t1: "cuda:0 f32[3, 4]"
  t1 = ltorch.true_divide(y, x) # t1: "cuda:0 f32[3, 4]"
  # t1 = prims.div(y, x) # t1: "cuda:0 f32[3, 4]"
  # thunder.func_39:17: t2 = ltorch.mul(x, y) # t2: "cuda:0 f32[3, 4]"
  t2 = ltorch.mul(x, y) # t2: "cuda:0 f32[3, 4]"
  # t2 = prims.mul(x, y) # t2: "cuda:0 f32[3, 4]"
  # thunder.func_39:18: t3 = ltorch.true_divide(y, x) # t3: "cuda:0 f32[3, 4]"
  t3 = ltorch.true_divide(y, x) # t3: "cuda:0 f32[3, 4]"
  # t3 = prims.div(y, x) # t3: "cuda:0 f32[3, 4]"
  # thunder.func_39:19: t4 = ltorch.mul(t0, t1) # t4: "cuda:0 f32[3, 4]"
  t4 = ltorch.mul(t0, t1) # t4: "cuda:0 f32[3, 4]"
  # t4 = prims.mul(t0, t1) # t4: "cuda:0 f32[3, 4]"
  # thunder.func_39:20: t5 = ltorch.mul(t2, t3) # t5: "cuda:0 f32[3, 4]"
  t5 = ltorch.mul(t2, t3) # t5: "cuda:0 f32[3, 4]"
  # t5 = prims.mul(t2, t3) # t5: "cuda:0 f32[3, 4]"
  # thunder.func_39:21: t6 = ltorch.mul(t5, 1) # t6: "cuda:0 f32[3, 4]"
  t6 = ltorch.mul(t5, 1) # t6: "cuda:0 f32[3, 4]"
  # _ = prims.convert_element_type(1, float)
  # t6 = prims.mul(t5, 1.0) # t6: "cuda:0 f32[3, 4]"
  # thunder.func_39:22: t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float16) # t7: "cuda:0 f16[3, 4]"
  t14 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float16) # t14: "cuda:0 f16[3, 4]"
  # /home/iyashchuk/dev/pytorch/main/torch/autograd/grad_mode.py:186: torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t14)
Why is t14 not renamed to t7, while all the other variables are renamed?
The renaming happens at https://github.com/Lightning-AI/lightning-thunder/blob/9f6e5b14e7a0fc6c96cca254540666d899df60b2/thunder/core/jit_ext.py#L1822-L1823
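For reference, here is a minimal, self-contained sketch of the kind of pass being discussed: map each generic tensor name back to the name it had in the user-visible trace, and skip the rename when the target name is already in use. This is a hypothetical illustration only; FakeProxy, rename_to_user_names, and everything else below are made-up names, not the actual jit_ext.py implementation.

from dataclasses import dataclass

# Hypothetical illustration only -- none of the names below are thunder API.
@dataclass
class FakeProxy:
    name: str

def rename_to_user_names(proxies, user_names, taken):
    # Map each generic name produced during jitting back to the name the
    # corresponding variable had in the user-visible trace, unless that name
    # is already in use, in which case this sketch skips the rename.
    for p in proxies:
        wanted = user_names.get(p.name, p.name)
        if wanted == p.name or wanted in taken:
            continue  # nothing to do, or the target name is already claimed
        taken.discard(p.name)
        taken.add(wanted)
        p.name = wanted

proxies = [FakeProxy("t14")]
rename_to_user_names(proxies, user_names={"t14": "t7"}, taken={"t7", "t14"})
print(proxies[0].name)  # "t14" -- the rename is skipped because "t7" is taken

If the proxy that should become t7 hits such a skip branch, the generic name t14 survives into the printed trace, which is what the output above shows.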
The goal of the renaming has been to use the user's variable names more, not to get consecutive generic names. Is that what you want? (Also, with a bit of luck, #954 will change the trace above.)
I understand the goal, but in the example above the renaming doesn't happen when it should. Here the trace object returned by thunder.trace() plays the role of a user script, so its variable names should be preserved.
I think what often happens is that it tries to rename something to an internal name ("tos" etc.) and then finds that that name is already taken for whatever reason. With #954 (which is not the end of the story), you'll get t7:
@no_autocast
def computation(x, y):
  # x: "cpu f32[3, 4]"
  # y: "cpu f32[3, 4]"
  # thunder.func_0:15: t0 = ltorch.mul(x, y) # t0: "cpu f32[3, 4]"
  t0 = ltorch.mul(x, y) # t0: "cpu f32[3, 4]"
  # t0 = prims.mul(x, y) # t0: "cpu f32[3, 4]"
  # thunder.func_0:16: t1 = ltorch.true_divide(y, x) # t1: "cpu f32[3, 4]"
  t1 = ltorch.true_divide(y, x) # t1: "cpu f32[3, 4]"
  # t1 = prims.div(y, x) # t1: "cpu f32[3, 4]"
  # thunder.func_0:17: t2 = ltorch.mul(x, y) # t2: "cpu f32[3, 4]"
  t2 = ltorch.mul(x, y) # t2: "cpu f32[3, 4]"
  # t2 = prims.mul(x, y) # t2: "cpu f32[3, 4]"
  # thunder.func_0:18: t3 = ltorch.true_divide(y, x) # t3: "cpu f32[3, 4]"
  t3 = ltorch.true_divide(y, x) # t3: "cpu f32[3, 4]"
  # t3 = prims.div(y, x) # t3: "cpu f32[3, 4]"
  # thunder.func_0:19: t4 = ltorch.mul(t0, t1) # t4: "cpu f32[3, 4]"
  t4 = ltorch.mul(t0, t1) # t4: "cpu f32[3, 4]"
  # t4 = prims.mul(t0, t1) # t4: "cpu f32[3, 4]"
  # thunder.func_0:20: t5 = ltorch.mul(t2, t3) # t5: "cpu f32[3, 4]"
  t5 = ltorch.mul(t2, t3) # t5: "cpu f32[3, 4]"
  # t5 = prims.mul(t2, t3) # t5: "cpu f32[3, 4]"
  # thunder.func_0:21: t6 = ltorch.mul(t5, 1) # t6: "cpu f32[3, 4]"
  t6 = ltorch.mul(t5, 1) # t6: "cpu f32[3, 4]"
  # _ = prims.convert_element_type(1, float)
  # t6 = prims.mul(t5, 1.0) # t6: "cpu f32[3, 4]"
  # thunder.func_0:22: t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16) # t7: "cpu f16[3, 4]"
  t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16) # t7: "cpu f16[3, 4]"
  # /usr/local/lib/python3.12/dist-packages/torch/autograd/grad_mode.py:186: torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t7)
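Assuming the repro script above (with trace, jfunc, and thunder still in scope), whether the renaming preserved the user-visible names can also be checked without touching any internal API, by diffing the t<N> names that appear in the two printed traces:

import re

# Collect the generic tensor names (t0, t1, ...) from the user-visible trace and
# from the initial trace produced by thunder.jit; both come from the printed
# representations, so no internal thunder API is needed.
original_names = set(re.findall(r"\bt\d+\b", str(trace)))
initial_names = set(re.findall(r"\bt\d+\b", str(thunder.last_traces(jfunc)[0])))

# Names that the jit trace introduced on top of the user-visible ones.
# With the buggy trace above this yields {'t14'}; after the fix it should be empty.
print(initial_names - original_names)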