
jit error: unpacking from nonconstant opaque function

Open IvanYashchuk opened this issue 1 year ago • 5 comments

🐛 Bug

```python
import thunder
import torch

@thunder.jit
def func(shape):
    return torch.rand(tuple(shape), device="cuda:0", dtype=torch.float32)

torch.manual_seed(12345)
shape = [1, 4, 40, 84, 84]
t1_1 = func(shape)  # fails on this first call with the NotImplementedError below
t2_1 = func(shape)

print(thunder.last_traces(func)[-1])
```

Traceback:

NotImplementedError                       Traceback (most recent call last)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1422, in unpack_inputs.<locals>.unpack(v)
   1421 try:
-> 1422     from_provenance(p.history)
   1423 except Exception as e:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1405, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1404     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1405 res = unpack_fn(provenance, new_output=new_output)
   1407 if provenance.ext_flag & EXT_FLAG_IS_MODULE:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1358, in unpack_inputs.<locals>.unpack.<locals>.from_opaque(provenance, new_output)
   1357 if fn.inst != PseudoInst.CONSTANT:
-> 1358     raise NotImplementedError(f"unpacking from nonconstant opaque function")
   1359 if fn.value.__name__ == "__getitem__":

NotImplementedError: unpacking from nonconstant opaque function

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[5], line 10
      8 torch.manual_seed(12345)
      9 shape = [1, 4, 40, 84, 84]
---> 10 t1_1 = func(shape)
     11 t2_1 = func(shape)
     13 print(thunder.last_traces(func)[-1])

File ~/dev/lightning-thunder/thunder/__init__.py:660, in jit.<locals>.fn_(*args, **kwargs)
    657 cs.last_trace_host_start = time.time_ns()
    658 cs.calls += 1
--> 660 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    661 cs.last_trace_host_execution_start = time.time_ns()
    663 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:217, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    215 tok = _cache_info_ctx.set({})
    216 try:
--> 217     res = fn(*args, **kwargs)
    218 finally:
    219     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:496, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    494 prologue_trc: TraceCtx
    495 computation_trc: TraceCtx
--> 496 jit_results: TraceResults = interpreter(
    497     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    498 )
    499 prologue_trc = jit_results.prologue_trace
    500 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/__init__.py:205, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
    196 def _general_frontend(
    197     fn: Callable,
    198     args: tuple[Any, ...],
   (...)
    203     sharp_edges: SHARP_EDGES_OPTIONS,
    204 ) -> TraceResults:
--> 205     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1638, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1635 else:
   1636     epilogue_trace = None
-> 1638 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
   1639     ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
   1640 )
   1642 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
   1643 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1444, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, has_epilogue)
   1441             pro_kwargs_proxy = output
   1443 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1444 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1446 with tracectx(prologue_trace):
   1447     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1444, in <genexpr>(.0)
   1441             pro_kwargs_proxy = output
   1443 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1444 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1446 with tracectx(prologue_trace):
   1447     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1424, in unpack_inputs.<locals>.unpack(v)
   1422         from_provenance(p.history)
   1423     except Exception as e:
-> 1424         raise NotImplementedError(f"Exception occured unpacking object from {p.history}") from e
   1426 already_unpacked[id(p)] = p
   1428 return p

NotImplementedError: Exception occured unpacking object from ProvenanceRecord(
  i1 = INPUT_FN()
  i2 = LOAD_ATTR(i1, '__globals__')
  i3 = BINARY_SUBSCR(i2, 'torch')
  i4 = LOAD_ATTR(i3, 'rand')
  i5 = INPUT_ARGS()
  i6 = BINARY_SUBSCR(i5, 0)
  i7 = BINARY_SUBSCR(i6, 0)
  i8 = BINARY_SUBSCR(i6, 1)
  i9 = BINARY_SUBSCR(i6, 2)
  i10 = BINARY_SUBSCR(i6, 3)
  i11 = BINARY_SUBSCR(i6, 4)
  i12 = BUILD_TUPLE(i7, i8, i9, i10, i11)
  i13 = BUILD_TUPLE(i12)
  i14 = LOAD_ATTR(i3, 'float32')
  i15 = BUILD_DICT('device', 'cuda:0', 'dtype', i14)
  i16 = OPAQUE(i4, i13, i15)
)

IvanYashchuk, Jun 17 '24

I looked into it a bit with the debugger, and I think this is related to us treating the rand function as opaque (i.e. we don't know whether it has side effects; the state of the RNG would be my best guess). What fails here is that a non-constant (the function itself, i4 = LOAD_ATTR(i3, 'rand')) is passed as input to the OPAQUE record. My guess is that in this case we decided not to trust the opaque fn not to modify its inputs.

With the debugger we clearly hit this line: https://github.com/Lightning-AI/lightning-thunder/blob/849cc2ed3a2296ae4b08f3b124d39ab409da47cc/thunder/core/jit_ext.py#L1357-L1358

with fn.inst equal to LOAD_ATTR.
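To make the failing guard concrete, here is a toy model of a provenance record and the check in from_opaque. The names Inst, Record and from_opaque_toy are hypothetical and do not mirror thunder's real classes; only the condition "the called function must be a CONSTANT" and the error message are taken from the traceback above.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any


class Inst(Enum):
    # Hypothetical subset of the provenance instructions seen in the record above.
    CONSTANT = auto()
    LOAD_ATTR = auto()
    OPAQUE = auto()


@dataclass
class Record:
    inst: Inst
    inputs: list["Record"] = field(default_factory=list)
    value: Any = None


def from_opaque_toy(record: Record) -> str:
    """Mimic the guard in from_opaque: the called function must be a constant."""
    assert record.inst is Inst.OPAQUE
    fn = record.inputs[0]  # first input of OPAQUE is the function being called
    if fn.inst is not Inst.CONSTANT:
        raise NotImplementedError("unpacking from nonconstant opaque function")
    return f"replay {fn.value}"


# torch.rand reached as an attribute of the torch module is a LOAD_ATTR, not a CONSTANT.
rand_via_attr = Record(Inst.LOAD_ATTR, value="rand")
failing = Record(Inst.OPAQUE, inputs=[rand_via_attr])

# The same call with a constant function object passes the guard.
rand_const = Record(Inst.CONSTANT, value="rand")
passing = Record(Inst.OPAQUE, inputs=[rand_const])
```

Calling from_opaque_toy(passing) returns "replay rand", while from_opaque_toy(failing) raises the same NotImplementedError message as in the report.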

We could also improve the error reporting here by including the message of the original exception e in the error we re-raise: https://github.com/Lightning-AI/lightning-thunder/blob/849cc2ed3a2296ae4b08f3b124d39ab409da47cc/thunder/core/jit_ext.py#L1423-L1424
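A minimal sketch of that suggestion in plain Python, independent of thunder: unpack_with_context and failing_unpack are hypothetical stand-ins for the unpack/from_provenance pair in jit_ext.py. The idea is simply to put type(e) and str(e) into the outer message while keeping the "from e" chaining.

```python
def unpack_with_context(history, unpack_fn):
    """Hypothetical wrapper: re-raise with the original error message included."""
    try:
        return unpack_fn(history)
    except Exception as e:
        # Put the inner error into the final line of the traceback,
        # not only into the chained "direct cause" section.
        raise NotImplementedError(
            f"Exception occurred unpacking object from {history}: "
            f"{type(e).__name__}: {e}"
        ) from e


def failing_unpack(history):
    # Stand-in for from_provenance hitting the non-constant opaque guard.
    raise NotImplementedError("unpacking from nonconstant opaque function")
```

With this, the last line of the error would read "... unpacking object from ProvenanceRecord(...): NotImplementedError: unpacking from nonconstant opaque function", so the root cause is visible without scrolling up to the chained exception.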

cc @t-vi, please don't hesitate to correct me if my initial assumptions are wrong.

riccardofelluga, Jun 17 '24

@riccardofelluga Yes, that seems right. I think what goes wrong is that we decide the proxy for the result of rand should be unpacked in the prologue, when rand should instead be put in the computation trace (maybe we need a special case for values produced by factory functions).
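One way to picture the "special case for factory functions" idea is a placement rule keyed on the called function. This is a hypothetical sketch of the design idea only: FACTORY_FUNCTIONS, choose_trace and split_by_trace are invented names, the set of factories is illustrative, and nothing here reflects how thunder actually decides placement.

```python
# Hypothetical, illustrative set of tensor factory function names.
FACTORY_FUNCTIONS = frozenset({"rand", "randn", "zeros", "ones", "empty", "full", "arange"})


def choose_trace(called_fn_name: str) -> str:
    """Decide where a value produced by an opaque call should come from.

    Factory functions create fresh tensors (and rand/randn also consume RNG
    state), so their result is not treated as an input to unpack in the
    prologue; the call is recorded in the computation trace instead.
    """
    if called_fn_name in FACTORY_FUNCTIONS:
        return "computation"
    return "prologue"


def split_by_trace(called_fn_names):
    """Group called-function names by the trace their results would go to."""
    placement = {"prologue": [], "computation": []}
    for name in called_fn_names:
        placement[choose_trace(name)].append(name)
    return placement
```

For example, split_by_trace(["__getitem__", "rand"]) keeps the __getitem__ unpacking in the prologue and moves rand to the computation trace.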

t-vi, Jun 17 '24

Yes, factory functions and rand should be put into the computation trace. nvFuser's handling of rand makes it possible to create random tensors that are portable and reproducible across GPUs, while PyTorch doesn't have this property: https://discuss.pytorch.org/t/even-with-the-same-seed-different-random-numbers-are-generated-by-different-devices/154054/2

IvanYashchuk, Jun 17 '24

Triage review:

  • @t-vi is available to consult/help on the design of this fix

mruberry, Jun 17 '24

Here's another reproducer for the same error (from https://github.com/Lightning-AI/lightning-thunder/issues/1344):

```
In [1]: import torch

In [2]: import thunder

In [3]: inputs = [torch.randn(1, 1) for _ in range(2)]

In [4]: def func(inputs): return sum(inputs)

In [5]: func(inputs)
Out[5]: tensor([[-2.0805]])

In [6]: jfunc = thunder.jit(func)

In [7]: jfunc(inputs)
```
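For this second reproducer it may help to spell out what the builtin sum does, since func never calls a torch operator directly: sum(inputs) evaluates 0 + inputs[0] + inputs[1], starting from the integer 0, so the first addition goes through the tensor's __radd__. The Recorder class below is a hypothetical stand-in for a tensor that only logs the additions and does not involve thunder. That thunder treats the builtin sum as an opaque call here is an assumption; the thread only reports that the error is the same.

```python
class Recorder:
    """Tiny stand-in for a tensor that logs every addition it takes part in."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def __add__(self, other):
        other_name = other.name if isinstance(other, Recorder) else repr(other)
        self.log.append(f"{self.name} + {other_name}")
        return Recorder(f"({self.name} + {other_name})", self.log)

    def __radd__(self, other):
        # Reached for `0 + Recorder`, because int.__add__ returns NotImplemented.
        self.log.append(f"{other!r} + {self.name}")
        return Recorder(f"({other!r} + {self.name})", self.log)


log = []
# Builtin sum uses start=0 by default, so this is (0 + a) + b.
result = sum([Recorder("a", log), Recorder("b", log)])
```

After running it, log holds the two additions in order, "0 + a" and then "(0 + a) + b", which is the sequence of operations the jitted func has to reproduce.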

IvanYashchuk, Oct 28 '24