jit error: unpacking from nonconstant opaque function
🐛 Bug
import thunder
import torch
@thunder.jit
def func(shape):
return torch.rand(tuple(shape), device="cuda:0", dtype=torch.float32)
torch.manual_seed(12345)
shape = [1, 4, 40, 84, 84]
t1_1 = func(shape)
t2_1 = func(shape)
print(thunder.last_traces(func)[-1])
traceback:
NotImplementedError Traceback (most recent call last)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1422, in unpack_inputs.<locals>.unpack(v)
1421 try:
-> 1422 from_provenance(p.history)
1423 except Exception as e:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1405, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
1404 raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1405 res = unpack_fn(provenance, new_output=new_output)
1407 if provenance.ext_flag & EXT_FLAG_IS_MODULE:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1358, in unpack_inputs.<locals>.unpack.<locals>.from_opaque(provenance, new_output)
1357 if fn.inst != PseudoInst.CONSTANT:
-> 1358 raise NotImplementedError(f"unpacking from nonconstant opaque function")
1359 if fn.value.__name__ == "__getitem__":
NotImplementedError: unpacking from nonconstant opaque function
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
Cell In[5], line 10
8 torch.manual_seed(12345)
9 shape = [1, 4, 40, 84, 84]
---> 10 t1_1 = func(shape)
11 t2_1 = func(shape)
13 print(thunder.last_traces(func)[-1])
File ~/dev/lightning-thunder/thunder/__init__.py:660, in jit.<locals>.fn_(*args, **kwargs)
657 cs.last_trace_host_start = time.time_ns()
658 cs.calls += 1
--> 660 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
661 cs.last_trace_host_execution_start = time.time_ns()
663 result = cache_entry.computation_fn(*inps)
File ~/dev/lightning-thunder/thunder/__init__.py:217, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
215 tok = _cache_info_ctx.set({})
216 try:
--> 217 res = fn(*args, **kwargs)
218 finally:
219 _cache_info_ctx.reset(tok)
File ~/dev/lightning-thunder/thunder/__init__.py:496, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
494 prologue_trc: TraceCtx
495 computation_trc: TraceCtx
--> 496 jit_results: TraceResults = interpreter(
497 fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
498 )
499 prologue_trc = jit_results.prologue_trace
500 computation_trc = jit_results.computation_trace
File ~/dev/lightning-thunder/thunder/__init__.py:205, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
196 def _general_frontend(
197 fn: Callable,
198 args: tuple[Any, ...],
(...)
203 sharp_edges: SHARP_EDGES_OPTIONS,
204 ) -> TraceResults:
--> 205 return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1638, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
1635 else:
1636 epilogue_trace = None
-> 1638 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
1639 ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
1640 )
1642 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
1643 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1444, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, has_epilogue)
1441 pro_kwargs_proxy = output
1443 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1444 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
1446 with tracectx(prologue_trace):
1447 for prim, *args in ctx._constraints:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1444, in <genexpr>(.0)
1441 pro_kwargs_proxy = output
1443 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1444 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
1446 with tracectx(prologue_trace):
1447 for prim, *args in ctx._constraints:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1424, in unpack_inputs.<locals>.unpack(v)
1422 from_provenance(p.history)
1423 except Exception as e:
-> 1424 raise NotImplementedError(f"Exception occured unpacking object from {p.history}") from e
1426 already_unpacked[id(p)] = p
1428 return p
NotImplementedError: Exception occured unpacking object from ProvenanceRecord(
i1 = INPUT_FN()
i2 = LOAD_ATTR(i1, '__globals__')
i3 = BINARY_SUBSCR(i2, 'torch')
i4 = LOAD_ATTR(i3, 'rand')
i5 = INPUT_ARGS()
i6 = BINARY_SUBSCR(i5, 0)
i7 = BINARY_SUBSCR(i6, 0)
i8 = BINARY_SUBSCR(i6, 1)
i9 = BINARY_SUBSCR(i6, 2)
i10 = BINARY_SUBSCR(i6, 3)
i11 = BINARY_SUBSCR(i6, 4)
i12 = BUILD_TUPLE(i7, i8, i9, i10, i11)
i13 = BUILD_TUPLE(i12)
i14 = LOAD_ATTR(i3, 'float32')
i15 = BUILD_DICT('device', 'cuda:0', 'dtype', i14)
i16 = OPAQUE(i4, i13, i15)
)
I looked into it a bit with the debugger, and I think this is related to us considering the `rand` function opaque (i.e. we don't really know whether it has side effects; the state of the RNG would be my best guess). What explodes here is the use of a non-constant as an input to the opaque function. My guess would be that we decided we don't trust the opaque function not to modify its inputs in this case.
With the debugger we clearly hit this line: https://github.com/Lightning-AI/lightning-thunder/blob/849cc2ed3a2296ae4b08f3b124d39ab409da47cc/thunder/core/jit_ext.py#L1357-L1358
with fn.inst defined as LOAD_ATTR.
We could maybe improve the error reporting here and actually include the message from `e`: https://github.com/Lightning-AI/lightning-thunder/blob/849cc2ed3a2296ae4b08f3b124d39ab409da47cc/thunder/core/jit_ext.py#L1423-L1424
and perhaps print `e` itself as well.
cc. @t-vi please don't hesitate to correct me if I'm wrong with the initial assumptions
@riccardofelluga Yes, that seems right. I think what goes wrong is that we decide the proxy for the result of `rand` should be unpacked in the prologue, while we should instead put `rand` in the computation trace (maybe we need a special case for things produced by factory functions).
Yes, factory functions such as `rand` should be put into the computation trace. nvFuser's handling of `rand` makes it possible to create random tensors that are portable and reproducible across GPUs, while PyTorch does not have this property: https://discuss.pytorch.org/t/even-with-the-same-seed-different-random-numbers-are-generated-by-different-devices/154054/2
triage review —
- @t-vi is available to consult/help on the design of this fix
Here's another reproducer for the same error (from https://github.com/Lightning-AI/lightning-thunder/issues/1344):
In [1]: import torch
In [2]: import thunder
In [3]: inputs = [torch.randn(1, 1) for _ in range(2)]
In [4]: def func(inputs): return sum(inputs)
In [5]: func(inputs)
Out[5]: tensor([[-2.0805]])
In [6]: jfunc = thunder.jit(func)
In [7]: jfunc(inputs)