InterpreterError: Encountered exception TypeError: missing a required argument: 'value' while tracing
🐛 Bug
The minimal repro from the fixed issue (https://github.com/Lightning-AI/lightning-thunder/issues/461#issuecomment-2178023346) no longer works; it now fails inside Thunder's interpreter:
import transformers
import torch
import thunder
def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)
jfn = thunder.jit(fn)
x = torch.randn(5, 5)
print(jfn(x))
TypeError: missing a required argument: 'value'
The above exception was the direct cause of the following exception:
InterpreterError Traceback (most recent call last)
Cell In[1], line 12
8 jfn = thunder.jit(fn)
10 x = torch.randn(5, 5)
---> 12 print(jfn(x))
File ~/dev/lightning-thunder/thunder/__init__.py:669, in jit.<locals>.fn_(*args, **kwargs)
666 cs.last_trace_host_start = time.perf_counter_ns()
667 cs.calls += 1
--> 669 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
670 cs.last_trace_host_execution_start = time.perf_counter_ns()
672 result = cache_entry.computation_fn(*inps)
File ~/dev/lightning-thunder/thunder/__init__.py:223, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
221 tok = _cache_info_ctx.set({})
222 try:
--> 223 res = fn(*args, **kwargs)
224 finally:
225 _cache_info_ctx.reset(tok)
File ~/dev/lightning-thunder/thunder/__init__.py:503, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
501 prologue_trc: TraceCtx
502 computation_trc: TraceCtx
--> 503 jit_results: TraceResults = interpreter(
504 fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
505 )
506 prologue_trc = jit_results.prologue_trace
507 computation_trc = jit_results.computation_trace
File ~/dev/lightning-thunder/thunder/__init__.py:211, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
202 def _general_frontend(
203 fn: Callable,
204 args: tuple[Any, ...],
(...)
209 sharp_edges: SHARP_EDGES_OPTIONS,
210 ) -> TraceResults:
--> 211 return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1743, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
1741 with general_jit_ctx(ctx):
1742 with tracectx(computation_trace):
-> 1743 result = jfn(*args, **kwargs)
1744 prims.python_return(result)
1745 computation_trace.set_current_source_location(None, None)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6686, in interpret.<locals>.fn_(*args, **kwargs)
6682 traceback_str = os.linesep.join(f.format_with_source() for f in runtimectx.frame_stack)
6683 msg = (
6684 f"Encountered exception {type(e).__name__}: {e} while tracing {fn}:{os.linesep}" f"{traceback_str}"
6685 )
-> 6686 raise InterpreterError(msg) from e
6688 # NOTE: Wrapped functions are valid to assign new attributes to.
6689 fn_._last_interpreter_log = runtimectx.interp_log # type: ignore
InterpreterError: Encountered exception TypeError: missing a required argument: 'value' while tracing <function fn at 0x7f2e5c9be170>:
I used transformers-4.35.0.
Note that #461 was about avoiding the infinite recursion.
The failing assignment comes from the dataclass decorator(?) setting __dataclass_params__ on the BaseModelOutput class (to _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)).
The trouble very likely stems either from _setattr_lookaside not properly handling assignment to classes and erroneously calling __setattr__ (see the dance that _getattr_lookaside does), or from the unbinding in _call_dispatch going wrong. The "missing a required argument" message could well mean that we are missing "self", the object being assigned to.
Note also that the creation of the BaseModelOutput class happens inside the interpreter: transformers loads its submodules lazily, so the import of the modeling_outputs module is executed by the interpreter.
One option might be to deliberately not trace through the lazy import and make it opaque instead.
Here is an even more minimal repro (the error message differs because a different setattr method is involved):
import thunder
class A:
    pass
def fn(x):
    A.x = x
fn(1)  # works as expected
print(A.x)
jfn = thunder.jit(fn)
jfn(2)  # fails because it calls the `__setattr__` intended for A-objects on the A-class.
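The suspected mismatch can be illustrated in plain Python, without Thunder. The sketch below is only an illustration of the hypothesis and not Thunder's actual code: `naive_setattr` and `type_based_setattr` are hypothetical helpers, and `A` defines its own `__setattr__` purely so that the resulting error names the missing `value` parameter.

```python
class A:
    # A Python-level __setattr__ meant for *instances* of A.
    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)


def naive_setattr(obj, name, value):
    # Hypothetical helper: fetches __setattr__ from obj itself.
    # For an instance this yields a bound method, so the call works.
    # For the class A it yields the plain function intended for instances,
    # so `name` lands in `self`, `value` lands in `name`, and `value` is missing.
    return obj.__setattr__(name, value)


def type_based_setattr(obj, name, value):
    # Hypothetical helper: looks the method up on type(obj), the way the
    # interpreter resolves special methods, and passes obj explicitly.
    return type(obj).__setattr__(obj, name, value)


a = A()
naive_setattr(a, "x", 1)        # fine: bound method on an instance
type_based_setattr(A, "x", 2)   # fine: dispatches to type.__setattr__(A, "x", 2)

try:
    naive_setattr(A, "x", 3)    # calls the instance-level function on the class
except TypeError as e:
    error_message = str(e)
    print(error_message)
```

The last call fails with a `TypeError` complaining about the missing `value` argument, which matches the symptom in the traceback: the object being assigned to was never passed as `self`.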