lightning-thunder icon indicating copy to clipboard operation
lightning-thunder copied to clipboard

Prologue error: check_slice_value fails with function wrappers and closures

Open IvanYashchuk opened this issue 1 year ago • 1 comment

🐛 Bug

import torch
import thunder

from thunder.tests.test_grad import _make_differentiable_wrapper
from thunder.core.utils import flatten_func

x = torch.randn((5, 5))

s1 = slice(1, 3, 1)
s2 = slice(2, 4, 2)

func = thunder.torch.getitem
args = (x, (s1, s2))

flat_op, flat_args, spec = flatten_func(thunder.torch.getitem, args, {})

filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
jf = thunder.jit(filtered_op, disable_torch_autograd=True)
jf(*filtered_args)
Running the script above fails while executing the prologue, with the following traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 20
     18 filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
     19 jf = thunder.jit(filtered_op, disable_torch_autograd=True)
---> 20 jf(*filtered_args)

File ~/dev/lightning-thunder/thunder/__init__.py:686, in jit.<locals>.fn_(*args, **kwargs)
    683 cs.last_trace_host_start = time.perf_counter_ns()
    684 cs.calls += 1
--> 686 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    687 cs.last_trace_host_execution_start = time.perf_counter_ns()
    689 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:225, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    223 tok = _cache_info_ctx.set({})
    224 try:
--> 225     res = fn(*args, **kwargs)
    226 finally:
    227     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:563, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    561 cs.last_prologue_execution_start = time.perf_counter_ns()
    562 if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
--> 563     inps, pro_to_epi = pro(*args, **kwargs)
    564 else:
    565     inps = pro(*args, **kwargs)

File ~/dev/pytorch/main/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/dev/pytorch/main/torch/amp/autocast_mode.py:28, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_autocast(*args, **kwargs):
     27     with autocast_instance:
---> 28         return func(*args, **kwargs)

File ~/dev/pytorch/main/torch/amp/autocast_mode.py:28, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_autocast(*args, **kwargs):
     27     with autocast_instance:
---> 28         return func(*args, **kwargs)

File thunder.prologue_0:18, in prologue(*args, **kwargs)
     16 res: "cpu f32[5, 5]" = p0[0]
     17 check_tensor_metadata(res, (5, 5), 'cpu', torch.float32, False)
---> 18 check_slice_value(p0, slice(1, 3, 1))
     19 p1: "<class 'slice'>" = p0[2]
     20 check_slice_value(p1, slice(2, 4, 2))

File ~/dev/lightning-thunder/thunder/executors/pythonex.py:133, in _check_slice_value_impl(s, value)
    132 def _check_slice_value_impl(s: slice, value: slice) -> None:
--> 133     utils.check(s == value, lambda: f"Expected '{s} to be equal to '{value}")

File ~/dev/lightning-thunder/thunder/core/baseutils.py:103, in check(cond, s, exception_type)
     98 """Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
     99 
    100 s is a callable producing a string to avoid string construction if the error check is passed.
    101 """
    102 if not cond:
--> 103     raise exception_type(s())

RuntimeError: Expected '[tensor([[ 0.5565,  1.2948,  0.4167,  1.2958,  0.7327],
        [-0.3139,  2.3312, -0.7532,  0.6989,  0.2032],
        [ 1.1990, -0.1274,  0.4933,  0.6365, -0.3264],
        [-0.1949,  0.0442, -1.5518,  0.4407,  0.5539],
        [-0.4287,  0.4858, -0.5317, -1.9094, -0.9781]]), slice(1, 3, 1), slice(2, 4, 2)] to be equal to 'slice(1, 3, 1)

IvanYashchuk avatar Jul 23 '24 13:07 IvanYashchuk

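I extracted the prologue from the traceback (I'll send a PR to assign to last_prologues before running). To me it looks a bit like a naming / proxy id collision somewhere: we apparently wanted to call the first slice p0 again, but p0 is already the name bound to the closure cell's contents, so check_slice_value is called on that whole object instead of on the first slice: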

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 0)
  # kwargs: "Any"
  check_len(kwargs, 0)
  fn: "Any" = globals()['__function_obj']
  obj: "Any" = fn.__closure__
  subscr: "Any" = obj[0]
  p0: "Any" = subscr.cell_contents
  res: "cpu f32[5, 5]" = p0[0]
  check_tensor_metadata(res, (5, 5), 'cpu', torch.float32, False)
  check_slice_value(p0, slice(1, 3, 1))
  p1: "<class 'slice'>" = p0[2]
  check_slice_value(p1, slice(2, 4, 2))
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
  return ((res,), ())

t-vi avatar Jul 23 '24 18:07 t-vi

Seems to be working for me now.

t-vi avatar Mar 25 '25 22:03 t-vi