
Traces with bsyms of `torch.autograd.Function` are not printable

Open crcrpar opened this issue 11 months ago • 2 comments

Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.

🐛 Bug

If a trace includes one or more BoundSymbols of a `torch.autograd.Function`, the trace is not printable.

The lookaside is https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L655.

When the lookaside registers the symbol, it doesn't specify a module: https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L688-L691. Printing the trace then reaches https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/symbol.py#L604-L616, where the assertion on `sym.module` fails.
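
For context, here is a minimal, self-contained sketch of the failure mode suggested by the traceback below. The `FakeSymbol`/`FakeBoundSymbol` classes are hypothetical stand-ins, not thunder's real ones: the trace printer gathers an import context from every bound symbol, and a symbol whose module is `None` trips the assertion before any text is produced.

import torch

# Hypothetical stand-ins that mirror the shape of the failure, not thunder's real classes.
class FakeSymbol:
    def __init__(self, name, module=None):
        self.name = name
        self.module = module  # symbols registered by the lookaside end up with module=None


class FakeBoundSymbol:
    def __init__(self, sym):
        self.sym = sym

    def import_ctx(self):
        # Printing a trace needs to know which module each symbol is imported from,
        # so a missing module fails before any text is emitted.
        assert self.sym.module is not None
        return {self.sym.module.__name__: self.sym.module}


ok = FakeBoundSymbol(FakeSymbol("matmul", module=torch))
print(ok.import_ctx())  # {'torch': <module 'torch' ...>}

broken = FakeBoundSymbol(FakeSymbol("MyLinear.apply"))  # no module, like the lookaside symbol
try:
    broken.import_ctx()
except AssertionError:
    print("printing fails: the bound symbol's sym.module is None")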

To Reproduce

Code sample

import torch

import thunder


class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, shape: tuple[int, int]) -> torch.Tensor:
        ctx.shape = shape
        ctx.save_for_backward(x, weight)
        ctx.pretty_attr = 100
        ctx.scaler = 1.0
        return torch.matmul(x, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        (x, weight) = ctx.saved_tensors
        assert weight.shape == ctx.shape  # really bogus, just to use ctx.shape
        scaler2 = ctx.shape[0] / ctx.shape[1]
        return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2, None


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(2, 2, bias=False)

    def forward(self, x):
        return MyLinear.apply(x, self.l1.weight, self.l1.weight.shape)


if __name__ == "__main__":
    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
    model = Model().to(dtype=torch.float64)
    jitted = thunder.jit(model)

    jitted(x)
    print(thunder.last_traces(jitted)[-1])
    print(thunder.last_traces(jitted)[0])

The output (the last, execution trace prints fine, but printing the initial trace raises an AssertionError):

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, weight):
  # x: "cpu f64[2, 2]"
  # weight: "cpu f64[2, 2]"
  t21 = torch.permute(weight, (1, 0))  # t21: "cpu f64[2, 2]"
    # t21 = ltorch.permute(weight, (1, 0))  # t21: "cpu f64[2, 2]"
      # t21 = prims.transpose(weight, (1, 0))  # t21: "cpu f64[2, 2]"
  t22 = torch.matmul(x, t21)  # t22: "cpu f64[2, 2]"
    # t22 = ltorch.matmul(x, t21)  # t22: "cpu f64[2, 2]"
      # t22 = prims.matmul(x, t21)  # t22: "cpu f64[2, 2]"
  del t21
  t11 = shallow_copy(t22)  # t11: "cpu f64[2, 2]"
  del t22
  return {'output': t11, 'flat_args': [x, weight], 'flat_output': (t11,)}, ((weight, x), ())
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py", line 39, in <module>
    print(thunder.last_traces(jitted)[0])
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 510, in __repr__
    return self.python(print_depth=-1)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 363, in python
    import_ctx, call_ctx, object_ctx = self._gather_ctxs()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 322, in _gather_ctxs
    bsym_import_ctx, bsym_call_ctx, bsym_object_ctx = bsym.gather_ctxs()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/symbol.py", line 648, in gather_ctxs
    return self.import_ctx(), self._get_call_ctx(), self.object_ctx()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/symbol.py", line 608, in import_ctx
    assert self.sym.module is not None  # TODO: Is this a valid assumption?
AssertionError

Expected behavior

Both traces, including the initial trace that contains the `torch.autograd.Function` bound symbol, should print without raising.

crcrpar · Nov 07 '24