Traces with bsyms of `torch.autograd.Function` are not printable
Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.
🐛 Bug
If a trace includes one or more BoundSymbols produced from a `torch.autograd.Function`, it is not printable.
The lookaside is https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L655.
When the lookaside registers the symbol, it doesn't specify a module: https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L688-L691, which later trips the assertion at https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/symbol.py#L604-L616
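For illustration, here is a minimal, self-contained sketch of the failure mode. The classes below are simplified stand-ins, not thunder's actual Symbol/BoundSymbol; they only mimic the assertion seen in the traceback further down.

# Simplified stand-ins (not thunder's actual classes) showing why a symbol
# registered without a module makes trace printing fail.
class FakeSymbol:
    def __init__(self, name, module=None):
        self.name = name
        self.module = module  # the autograd.Function lookaside leaves this as None


class FakeBoundSymbol:
    def __init__(self, sym):
        self.sym = sym

    def import_ctx(self):
        # Printing a trace gathers the modules each bound symbol needs to import;
        # a module-less symbol trips this assertion, as in the traceback below.
        assert self.sym.module is not None
        return {self.sym.module.__name__: self.sym.module}


bsym = FakeBoundSymbol(FakeSymbol("MyLinear"))  # analogous to the autograd.Function bsym
bsym.import_ctx()  # raises AssertionError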
To Reproduce
Code sample
import torch

import thunder


class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, shape: tuple[int, int]) -> torch.Tensor:
        ctx.shape = shape
        ctx.save_for_backward(x, weight)
        ctx.pretty_attr = 100
        ctx.scaler = 1.0
        return torch.matmul(x, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        (x, weight) = ctx.saved_tensors
        assert weight.shape == ctx.shape  # really bogus, just to use ctx.shape
        scaler2 = ctx.shape[0] / ctx.shape[1]
        return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2, None


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(2, 2, bias=False)

    def forward(self, x):
        return MyLinear.apply(x, self.l1.weight, self.l1.weight.shape)


if __name__ == "__main__":
    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
    model = Model().to(dtype=torch.float64)
    jitted = thunder.jit(model)
    jitted(x)

    print(thunder.last_traces(jitted)[-1])
    print(thunder.last_traces(jitted)[0])
The output:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, weight):
  # x: "cpu f64[2, 2]"
  # weight: "cpu f64[2, 2]"
  t21 = torch.permute(weight, (1, 0))  # t21: "cpu f64[2, 2]"
    # t21 = ltorch.permute(weight, (1, 0))  # t21: "cpu f64[2, 2]"
      # t21 = prims.transpose(weight, (1, 0))  # t21: "cpu f64[2, 2]"
  t22 = torch.matmul(x, t21)  # t22: "cpu f64[2, 2]"
    # t22 = ltorch.matmul(x, t21)  # t22: "cpu f64[2, 2]"
      # t22 = prims.matmul(x, t21)  # t22: "cpu f64[2, 2]"
  del t21
  t11 = shallow_copy(t22)  # t11: "cpu f64[2, 2]"
  del t22
  return {'output': t11, 'flat_args': [x, weight], 'flat_output': (t11,)}, ((weight, x), ())
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py", line 39, in <module>
    print(thunder.last_traces(jitted)[0])
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 510, in __repr__
    return self.python(print_depth=-1)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 363, in python
    import_ctx, call_ctx, object_ctx = self._gather_ctxs()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 322, in _gather_ctxs
    bsym_import_ctx, bsym_call_ctx, bsym_object_ctx = bsym.gather_ctxs()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/symbol.py", line 648, in gather_ctxs
    return self.import_ctx(), self._get_call_ctx(), self.object_ctx()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/symbol.py", line 608, in import_ctx
    assert self.sym.module is not None  # TODO: Is this a valid assumption?
AssertionError