lightning-thunder
Prologue error: check_slice_value fails with function wrappers and closures
🐛 Bug
import torch
import thunder
from thunder.tests.test_grad import _make_differentiable_wrapper
from thunder.core.utils import flatten_func
x = torch.randn((5, 5))
s1 = slice(1, 3, 1)
s2 = slice(2, 4, 2)
func = thunder.torch.getitem
args = (x, (s1, s2))
flat_op, flat_args, spec = flatten_func(func, args, {})
filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
jf = thunder.jit(filtered_op, disable_torch_autograd=True)
jf(*filtered_args)
RuntimeError Traceback (most recent call last)
Cell In[1], line 20
18 filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
19 jf = thunder.jit(filtered_op, disable_torch_autograd=True)
---> 20 jf(*filtered_args)
File ~/dev/lightning-thunder/thunder/__init__.py:686, in jit.<locals>.fn_(*args, **kwargs)
683 cs.last_trace_host_start = time.perf_counter_ns()
684 cs.calls += 1
--> 686 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
687 cs.last_trace_host_execution_start = time.perf_counter_ns()
689 result = cache_entry.computation_fn(*inps)
File ~/dev/lightning-thunder/thunder/__init__.py:225, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
223 tok = _cache_info_ctx.set({})
224 try:
--> 225 res = fn(*args, **kwargs)
226 finally:
227 _cache_info_ctx.reset(tok)
File ~/dev/lightning-thunder/thunder/__init__.py:563, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
561 cs.last_prologue_execution_start = time.perf_counter_ns()
562 if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
--> 563 inps, pro_to_epi = pro(*args, **kwargs)
564 else:
565 inps = pro(*args, **kwargs)
File ~/dev/pytorch/main/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/dev/pytorch/main/torch/amp/autocast_mode.py:28, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
25 @functools.wraps(func)
26 def decorate_autocast(*args, **kwargs):
27 with autocast_instance:
---> 28 return func(*args, **kwargs)
File ~/dev/pytorch/main/torch/amp/autocast_mode.py:28, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
25 @functools.wraps(func)
26 def decorate_autocast(*args, **kwargs):
27 with autocast_instance:
---> 28 return func(*args, **kwargs)
File thunder.prologue_0:18, in prologue(*args, **kwargs)
16 res: "cpu f32[5, 5]" = p0[0]
17 check_tensor_metadata(res, (5, 5), 'cpu', torch.float32, False)
---> 18 check_slice_value(p0, slice(1, 3, 1))
19 p1: "<class 'slice'>" = p0[2]
20 check_slice_value(p1, slice(2, 4, 2))
File ~/dev/lightning-thunder/thunder/executors/pythonex.py:133, in _check_slice_value_impl(s, value)
132 def _check_slice_value_impl(s: slice, value: slice) -> None:
--> 133 utils.check(s == value, lambda: f"Expected '{s} to be equal to '{value}")
File ~/dev/lightning-thunder/thunder/core/baseutils.py:103, in check(cond, s, exception_type)
98 """Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
99
100 s is a callable producing a string to avoid string construction if the error check is passed.
101 """
102 if not cond:
--> 103 raise exception_type(s())
RuntimeError: Expected '[tensor([[ 0.5565, 1.2948, 0.4167, 1.2958, 0.7327],
[-0.3139, 2.3312, -0.7532, 0.6989, 0.2032],
[ 1.1990, -0.1274, 0.4933, 0.6365, -0.3264],
[-0.1949, 0.0442, -1.5518, 0.4407, 0.5539],
[-0.4287, 0.4858, -0.5317, -1.9094, -0.9781]]), slice(1, 3, 1), slice(2, 4, 2)] to be equal to 'slice(1, 3, 1)
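For context, _make_differentiable_wrapper from thunder.tests.test_grad presumably keeps only the differentiable (requires_grad) tensor arguments as call-time inputs and closes over everything else. In this repro x has requires_grad=False, so filtered_args appears to end up empty and the prologue has to fetch the tensor and both slices from a closure cell, which matches the check_len(args, 0) and fn.__closure__ lines below. A rough, hypothetical reconstruction of the helper (not the actual implementation):

import torch

def make_differentiable_wrapper_sketch(flat_op, flat_args):
    # hypothetical sketch: only tensors that require grad stay as call-time
    # inputs; everything else (here the whole (tensor, slice, slice) tuple,
    # since x has requires_grad=False) is captured in the wrapper's closure
    diff_idx = [
        i for i, a in enumerate(flat_args)
        if isinstance(a, torch.Tensor) and a.requires_grad
    ]

    def wrapper(*diff_args):
        full_args = list(flat_args)
        for i, a in zip(diff_idx, diff_args):
            full_args[i] = a
        return flat_op(*full_args)

    return wrapper, tuple(flat_args[i] for i in diff_idx)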
Extracting the prologue from the traceback (I'll send a PR so the prologue is assigned to last_prologues before it is run), this looks like a naming / proxy-id collision to me: it seems the proxy for the first slice was also meant to be named p0, so the subscript of the flattened-args tuple is never emitted and check_slice_value receives the whole tuple instead of the slice:
@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 0)
  # kwargs: "Any"
  check_len(kwargs, 0)
  fn: "Any" = globals()['__function_obj']
  obj: "Any" = fn.__closure__
  subscr: "Any" = obj[0]
  p0: "Any" = subscr.cell_contents
  res: "cpu f32[5, 5]" = p0[0]
  check_tensor_metadata(res, (5, 5), 'cpu', torch.float32, False)
  check_slice_value(p0, slice(1, 3, 1))
  p1: "<class 'slice'>" = p0[2]
  check_slice_value(p1, slice(2, 4, 2))
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
  return ((res,), ())
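If I'm reading it right, the intended output would subscript the flattened tuple before the check, along these lines (a hand-written sketch of the expected prologue lines, not actual generated code; the proxy name is made up):

p_slice: "<class 'slice'>" = p0[1]
check_slice_value(p_slice, slice(1, 3, 1))
p1: "<class 'slice'>" = p0[2]
check_slice_value(p1, slice(2, 4, 2))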
Seems to be working for me now.
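If it is indeed fixed, a small regression test along these lines might be worth adding (a sketch, assuming the wrapper returns the getitem result directly):

import torch
import thunder
from thunder.tests.test_grad import _make_differentiable_wrapper
from thunder.core.utils import flatten_func

def test_getitem_prologue_with_slice_closure():
    x = torch.randn((5, 5))
    args = (x, (slice(1, 3, 1), slice(2, 4, 2)))
    flat_op, flat_args, _ = flatten_func(thunder.torch.getitem, args, {})
    filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
    jf = thunder.jit(filtered_op, disable_torch_autograd=True)
    # compare against eager PyTorch indexing with the same slices
    torch.testing.assert_close(jf(*filtered_args), x[1:3:1, 2:4:2])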