Calling thunder_general_jit inside lookasides doesn't work
🐛 Bug
I want to convert a Python function that might contain PyTorch calls into a Thunder function inside a lookaside function.
I wasn't successful at using `thunder.core.interpreter.interpret`, so I resorted to `thunder_general_jit`. The inner function `interpreted_fn` does the correct thing. However, something stands in the way of correctly nesting `thunder_general_jit` calls, and I see the following error:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 60
     58 x = torch.randn(3, 4, requires_grad=True)
     59 jf = thunder.jit(f)
---> 60 out = jf(x)

File ~/dev/lightning-thunder/thunder/__init__.py:704, in jit.<locals>.fn_(*args, **kwargs)
    701 cs.last_trace_host_start = time.perf_counter_ns()
    702 cs.calls += 1
--> 704 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    705 cs.last_trace_host_execution_start = time.perf_counter_ns()
    707 if cache_entry.vanilla_tensor_args:

File ~/dev/lightning-thunder/thunder/core/langctxs.py:136, in langctx.__call__.<locals>._fn(*args, **kwargs)
    134 try:
    135     tok = set_langctx(self.langctx)
--> 136     result = fn(*args, **kwargs)
    137     return result
    138 finally:

File ~/dev/lightning-thunder/thunder/__init__.py:213, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    211 tok = _cache_info_ctx.set({})
    212 try:
--> 213     res = fn(*args, **kwargs)
    214 finally:
    215     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:500, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    498 prologue_trc: TraceCtx
    499 computation_trc: TraceCtx
--> 500 jit_results: TraceResults = thunder_general_jit(
    501     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    502 )
    503 prologue_trc = jit_results.prologue_trace
    504 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1562, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1559 else:
   1560     epilogue_trace = None
-> 1562 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs)
   1564 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
   1565 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1367, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs)
   1365 print(f"pro_to_comp_inps: {pro_to_comp_inps}")
   1366 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1367 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1369 with tracectx(prologue_trace):
   1370     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1367, in unpack_inputs.<locals>.<lambda>(x)
   1365 print(f"pro_to_comp_inps: {pro_to_comp_inps}")
   1366 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1367 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1369 with tracectx(prologue_trace):
   1370     for prim, *args in ctx._constraints:

KeyError: 140260996191264
```
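The immediate failure is a plain dictionary miss: `unpack_inputs` sorts the prologue-to-computation inputs by `param_ordering[id(x)]`, and the key that's missing is an object id, which suggests the nested `thunder_general_jit` call produces objects that were never registered in the outer trace's ordering. A minimal, Thunder-free sketch of this failure mode (all names here are hypothetical stand-ins, not Thunder's actual classes):

```python
# Illustration only: sorting by an id()-keyed ordering dict raises
# KeyError for any object that was never registered in that dict.

class Proxy:
    def __init__(self, name):
        self.name = name

outer = [Proxy("a"), Proxy("b")]
# The "outer trace" registers its proxies by identity.
param_ordering = {id(p): i for i, p in enumerate(outer)}

# A proxy created by a nested trace is a distinct object,
# so its id() is absent from the outer ordering dict.
nested = Proxy("a")

try:
    sorted(outer + [nested], key=lambda p: param_ordering[id(p)])
except KeyError as e:
    print(f"KeyError: {e}")  # the missing key is the nested proxy's id()
```

The error message in the traceback (`KeyError: 140260996191264`) is exactly such a bare `id()` value.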
Script to reproduce it:
```python
from thunder.core.jit_ext import (
    compile_data_and_stats,
    CompileData,
    get_compile_data,
    interpreter_needs_wrap,
    SHARP_EDGES_OPTIONS,
    thunder_general_jit,
    unwrap,
    register_general_jit_lookaside,
    TraceResults,
    wrap_const,
)
import torch
import thunder


def my_call(fn, *args, **kwargs):
    return fn(*args, **kwargs)


@register_general_jit_lookaside(my_call)
def _lookaside(fn, *args, **kwargs):
    # Translate a possibly-PyTorch function into a Thunder function
    def interpreted_fn(*args, **kwargs):
        # NOTE: Using thunder.jit with get_computation_and_inputs instead of
        # thunder_general_jit results in the same error
        unwrapped_function = unwrap(fn)
        cd = CompileData(
            fn=unwrapped_function,
            disable_preprocessing=True,
            executor_lookasides=get_compile_data().executor_lookasides,
        )
        with compile_data_and_stats(cd, None):
            jit_results: TraceResults = thunder_general_jit(
                unwrapped_function,
                args,
                kwargs,
                sharp_edges=SHARP_EDGES_OPTIONS.ALLOW,
            )
        # inps, pro_to_epi = jit_results.prologue_trace.python_callable()(*args, **kwargs)
        # result = jit_results.computation_trace.python_callable()(*inps)
        result = jit_results.computation_trace.python_callable()(*args)
        return result

    wrapped_thunder_function = wrap_const(interpreted_fn)
    result = interpreter_needs_wrap(my_call)(wrapped_thunder_function, *args, **kwargs)
    return result


def f(x):
    return my_call(lambda x: torch.sin(torch.cos(x)), x)


x = torch.randn(3, 4, requires_grad=True)
jf = thunder.jit(f)
out = jf(x)
```
In general, support for nested JIT tracing of higher-order operations is discussed in https://github.com/Lightning-AI/lightning-thunder/issues/1134.
cc @t-vi
@IvanYashchuk - where do you hit this issue?
A tidy reproduction script is shared in the description. I confirmed that we can reproduce the error (with a slightly different `KeyError` value, as expected since the key is an object id).
> @IvanYashchuk - where do you hit this issue?
I discovered this bug when trying to support PyTorch's and Dynamo's activation checkpointing implementation in https://github.com/Lightning-AI/lightning-thunder/pull/1127. Currently that PR works only for simple functions that have exactly the same implementation in PyTorch and Thunder (for example, `a.cos() + a.exp()`). Fixing this bug would enable supporting any PyTorch function.
@t-vi will open issues for detecting and raising a meaningful error message first.
Related:
- #1134
Issues for the steps:
- #1222
- #1220
After #1220 is solved, we can use the present issue to track the remainder of the work.
Inside the tracing, it is quite possible that some form of `_interpret_call` can help you; see the `torch.autograd.Function` lookaside in `jit_ext.py` for an advanced example.
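As a toy analogy for why re-entering the running interpreter helps (plain Python, not Thunder's actual API): a re-entrant call shares the outer trace's environment of registered values, while constructing a fresh tracer starts from an empty environment, roughly mirroring how a nested `thunder_general_jit` call produces proxies unknown to the outer trace:

```python
# Toy illustration: re-entrant interpretation vs. starting a fresh tracer.
# All classes and names here are hypothetical, for illustration only.

class Tracer:
    def __init__(self):
        self.env = {}  # value -> proxy registered in *this* trace

    def proxy_for(self, value):
        # Register on first sight; reuse the same proxy afterwards.
        return self.env.setdefault(value, f"proxy_{len(self.env)}")

    def interpret_call(self, fn, *args):
        # Re-entrant: nested calls reuse self.env, so proxies stay registered.
        return fn(self, *args)

def inner(tr, x):
    return tr.proxy_for(x)

def outer_fn(tr, x):
    tr.proxy_for(x)                     # registered in the outer trace
    return tr.interpret_call(inner, x)  # same tracer: the same proxy comes back

tr = Tracer()
print(tr.interpret_call(outer_fn, 42))  # -> proxy_0 (shared with the outer trace)

fresh = Tracer()                        # a fresh tracer, like a nested
print(42 in tr.env, 42 in fresh.env)    # thunder_general_jit call -> True False
```

The fresh tracer's empty `env` is the analogue of the missing `param_ordering` entries in the `KeyError` above.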