lightning-thunder
implement zip lookaside in Python interpreter (enables e.g. thunder.jit with zip from LitGPT LLaMAMoE)
🐛 Bug
Here's a simplified version of LitGPT's LLaMAMoE without data-dependent shapes. It fails somewhere in the general jit with:
NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> ProvenanceRecord(
To reproduce:

```python
import torch
import thunder
from torch import nn


class Test(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.n_expert = 8
        self.n_expert_per_token = 2
        self.C = 2
        self.gate = nn.Linear(self.C, self.n_expert, bias=False)
        self.experts = nn.ModuleList(nn.Linear(2, 2, bias=False) for _ in range(self.n_expert))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        x = x.view(-1, C)  # (B*T, C)
        router = self.gate(x)  # (B*T, n_expert)
        probs, indices = torch.topk(router, self.n_expert_per_token)  # (B*T, n_expert_per_token)
        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
        masks = indices.unsqueeze(-1) == torch.arange(self.n_expert, device=x.device)
        masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
        y = torch.zeros_like(x)  # (B*T, C)
        for mask, expert in zip(masks, self.experts):
            token_idx = torch.arange(B * T, device=x.device)
            # keep expert_idx < n_expert_per_token so probs[token_idx, expert_idx] stays in bounds
            expert_idx = torch.arange(B * T, device=x.device) % self.n_expert_per_token
            pprobs = probs[token_idx, expert_idx]
            pprobs = pprobs.unsqueeze(-1)
            eexpert = expert(x[token_idx])
            y = torch.index_add(y, 0, token_idx, pprobs * eexpert)
        return y.view(B, T, C)


model = Test()
model = thunder.jit(model)
x = torch.randn(2, 3, 2)
y = model(x)
```
raises:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1272, in unpack_inputs.<locals>.unpack(v)
1271 try:
-> 1272 from_provenance(p.history)
1273 except Exception as e:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
1263 raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
1265 provenance.proxy = res
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1196 obj, idx = inputs
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1196 obj, idx = inputs
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
1263 raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
1265 provenance.proxy = res
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1196 obj, idx = inputs
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1196 obj, idx = inputs
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
1263 raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
1265 provenance.proxy = res
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1178, in unpack_inputs.<locals>.unpack.<locals>.from_load_attr(provenance, new_output)
1177 is_pure = False
-> 1178 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1179 if new_output:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1178, in <listcomp>(.0)
1177 is_pure = False
-> 1178 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1179 if new_output:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
1263 raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
1265 provenance.proxy = res
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1196 obj, idx = inputs
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
1196 obj, idx = inputs
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
1263 raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
1265 provenance.proxy = res
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1241, in unpack_inputs.<locals>.unpack.<locals>.from_opaque(provenance, new_output)
1232 return from_provenance(
1233 ProvenanceRecord(
1234 PseudoInst.LOAD_ATTR,
(...)
1239 )
1240 )
-> 1241 raise NotImplementedError(f"unpacking from OPAQUE {fn.value} {provenance}")
NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> ProvenanceRecord(
i1 = INPUT_FN()
i2 = LOAD_ATTR(i1, '__dict__')
i3 = BINARY_SUBSCR(i2, '_modules')
i4 = BINARY_SUBSCR(i3, 'experts')
i5 = INPUT_ARGS()
i6 = BINARY_SUBSCR(i5, 0)
i7 = LOAD_ATTR(i6, '__getattr__')
i8 = LOAD_ATTR(i7, '__func__')
i9 = Instruction(opname='CALL_FUNCTION_KW', opcode=141, arg=2, argval=2, argrepr='', offset=102, starts_line=None, is_jump_target=False)()
i10 = LOAD_ATTR(i1, 'n_expert_per_token')
i11 = BINARY_SUBSCR(i3, 'gate')
i12 = LOAD_ATTR(i11, '__dict__')
i13 = BINARY_SUBSCR(i12, '_parameters')
i14 = BINARY_SUBSCR(i13, 'bias')
i15 = BINARY_SUBSCR(i13, 'weight')
i16 = BUILD_TUPLE('view', i6)
i17 = OPAQUE(i8, i16, CONSTANT({}))
i18 = LOAD_ATTR(i17, 'func')
i19 = Instruction(opname='BINARY_ADD', opcode=23, arg=None, argval=None, argrepr='', offset=10, starts_line=None, is_jump_target=False)()
i20 = BINARY_SUBSCR(i19, 2)
i21 = BINARY_SUBSCR(i19, 1)
i22 = LOAD_ATTR(i17, 'args')
i23 = BINARY_SUBSCR(i22, 0)
i24 = BUILD_TUPLE(i20, i21, i23)
i25 = OPAQUE(i18, i24, CONSTANT({}))
i26 = BUILD_TUPLE(i14, i15, i25)
i27 = OPAQUE(CONSTANT([Symbol name=linear]), i26, CONSTANT({}))
i28 = BUILD_TUPLE(i10, i27)
i29 = OPAQUE(CONSTANT([Symbol name=topk]), i28, CONSTANT({}))
i30 = BINARY_SUBSCR(i29, 1)
i31 = BUILD_TUPLE('unsqueeze', i30)
i32 = OPAQUE(i8, i31, CONSTANT({}))
i33 = LOAD_ATTR(i32, 'func')
i34 = BUILD_TUPLE(i21, i30)
i35 = OPAQUE(i33, i34, CONSTANT({}))
i36 = Instruction(opname='COMPARE_OP', opcode=107, arg=2, argval='==', argrepr='==', offset=104, starts_line=None, is_jump_target=False)(i9, i35)
i37 = BUILD_TUPLE('permute', i36)
i38 = OPAQUE(i8, i37, CONSTANT({}))
i39 = LOAD_ATTR(i38, 'func')
i40 = BINARY_SUBSCR(i19, 3)
i41 = BUILD_TUPLE(i40, i20, i21, i36)
i42 = OPAQUE(i39, i41, CONSTANT({}))
i43 = LOAD_ATTR(i1, 'forward')
i44 = LOAD_ATTR(i43, '__func__')
i45 = LOAD_ATTR(i44, '__globals__')
i46 = BINARY_SUBSCR(i45, '__builtins__')
i47 = LOAD_ATTR(i46, 'zip')
i48 = BUILD_TUPLE(i4, i42, i47)
i49 = OPAQUE(CONSTANT(<built-in method __new__ of type object at 0x55c1c13de340>), i48, CONSTANT({}))
i50 = BUILD_TUPLE(i49)
i51 = OPAQUE(CONSTANT(<slot wrapper '__next__' of 'zip' objects>), i50, CONSTANT({}))
)
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
Cell In[1], line 35
32 model = thunder.jit(model)
34 x = torch.randn(2, 3, 2)
---> 35 y = model(x)
File ~/miniforge3/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/dev/lightning-thunder/thunder/__init__.py:209, in ThunderModule.forward(self, *args, **kwargs)
208 def forward(self, *args, **kwargs):
--> 209 res = self._forward_fn(*args, **kwargs)
210 return res
File ~/dev/lightning-thunder/thunder/__init__.py:661, in jit.<locals>.fn_(*args, **kwargs)
658 cs.last_trace_host_start = time.time_ns()
659 cs.calls += 1
--> 661 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
662 cs.last_trace_host_execution_start = time.time_ns()
664 result = cache_entry.computation_fn(*inps)
File ~/dev/lightning-thunder/thunder/__init__.py:277, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
275 tok = _cache_info_ctx.set({})
276 try:
--> 277 res = fn(*args, **kwargs)
278 finally:
279 _cache_info_ctx.reset(tok)
File ~/dev/lightning-thunder/thunder/__init__.py:538, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
536 prologue_trc: TraceCtx
537 computation_trc: TraceCtx
--> 538 jit_results: TraceResults = interpreter(
539 fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
540 )
541 prologue_trc = jit_results.prologue_trace
542 computation_trc = jit_results.computation_trace
File ~/dev/lightning-thunder/thunder/__init__.py:190, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
181 def _general_frontend(
182 fn: Callable,
183 args: tuple[Any, ...],
(...)
188 sharp_edges: SHARP_EDGES_OPTIONS,
189 ) -> TraceResults:
--> 190 return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1481, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
1478 else:
1479 epilogue_trace = None
-> 1481 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
1482 ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
1483 )
1485 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
1486 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1301, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, has_epilogue)
1298 pro_kwargs_proxy = output
1300 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1301 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
1303 with tracectx(prologue_trace):
1304 for prim, *args in ctx._constraints:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1301, in <genexpr>(.0)
1298 pro_kwargs_proxy = output
1300 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1301 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
1303 with tracectx(prologue_trace):
1304 for prim, *args in ctx._constraints:
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1274, in unpack_inputs.<locals>.unpack(v)
1272 from_provenance(p.history)
1273 except Exception as e:
-> 1274 raise NotImplementedError(f"Exception occured unpacking object from {p.history}") from e
1276 already_unpacked[id(p)] = p
1278 # Adds cache constraints
1279 # TODO Consider refactoring these contraints
1280 # TODO Constrain on rank, device, and dtype
NotImplementedError: Exception occured unpacking object from ProvenanceRecord(
i1 = INPUT_FN()
i2 = LOAD_ATTR(i1, '__dict__')
i3 = BINARY_SUBSCR(i2, '_modules')
i4 = BINARY_SUBSCR(i3, 'experts')
i5 = INPUT_ARGS()
i6 = BINARY_SUBSCR(i5, 0)
i7 = LOAD_ATTR(i6, '__getattr__')
i8 = LOAD_ATTR(i7, '__func__')
i9 = Instruction(opname='CALL_FUNCTION_KW', opcode=141, arg=2, argval=2, argrepr='', offset=102, starts_line=None, is_jump_target=False)()
i10 = LOAD_ATTR(i1, 'n_expert_per_token')
i11 = BINARY_SUBSCR(i3, 'gate')
i12 = LOAD_ATTR(i11, '__dict__')
i13 = BINARY_SUBSCR(i12, '_parameters')
i14 = BINARY_SUBSCR(i13, 'bias')
i15 = BINARY_SUBSCR(i13, 'weight')
i16 = BUILD_TUPLE('view', i6)
i17 = OPAQUE(i8, i16, CONSTANT({}))
i18 = LOAD_ATTR(i17, 'func')
i19 = Instruction(opname='BINARY_ADD', opcode=23, arg=None, argval=None, argrepr='', offset=10, starts_line=None, is_jump_target=False)()
i20 = BINARY_SUBSCR(i19, 2)
i21 = BINARY_SUBSCR(i19, 1)
i22 = LOAD_ATTR(i17, 'args')
i23 = BINARY_SUBSCR(i22, 0)
i24 = BUILD_TUPLE(i20, i21, i23)
i25 = OPAQUE(i18, i24, CONSTANT({}))
i26 = BUILD_TUPLE(i14, i15, i25)
i27 = OPAQUE(CONSTANT([Symbol name=linear]), i26, CONSTANT({}))
i28 = BUILD_TUPLE(i10, i27)
i29 = OPAQUE(CONSTANT([Symbol name=topk]), i28, CONSTANT({}))
i30 = BINARY_SUBSCR(i29, 1)
i31 = BUILD_TUPLE('unsqueeze', i30)
i32 = OPAQUE(i8, i31, CONSTANT({}))
i33 = LOAD_ATTR(i32, 'func')
i34 = BUILD_TUPLE(i21, i30)
i35 = OPAQUE(i33, i34, CONSTANT({}))
i36 = Instruction(opname='COMPARE_OP', opcode=107, arg=2, argval='==', argrepr='==', offset=104, starts_line=None, is_jump_target=False)(i9, i35)
i37 = BUILD_TUPLE('permute', i36)
i38 = OPAQUE(i8, i37, CONSTANT({}))
i39 = LOAD_ATTR(i38, 'func')
i40 = BINARY_SUBSCR(i19, 3)
i41 = BUILD_TUPLE(i40, i20, i21, i36)
i42 = OPAQUE(i39, i41, CONSTANT({}))
i43 = LOAD_ATTR(i1, 'forward')
i44 = LOAD_ATTR(i43, '__func__')
i45 = LOAD_ATTR(i44, '__globals__')
i46 = BINARY_SUBSCR(i45, '__builtins__')
i47 = LOAD_ATTR(i46, 'zip')
i48 = BUILD_TUPLE(i4, i42, i47)
i49 = OPAQUE(CONSTANT(<built-in method __new__ of type object at 0x55c1c13de340>), i48, CONSTANT({}))
i50 = BUILD_TUPLE(i49)
i51 = OPAQUE(CONSTANT(<slot wrapper '__next__' of 'zip' objects>), i50, CONSTANT({}))
i52 = BINARY_SUBSCR(i51, 1)
i53 = LOAD_ATTR(i52, '__dict__')
i54 = BINARY_SUBSCR(i53, '_parameters')
i55 = BINARY_SUBSCR(i54, 'weight')
)
Thank you @IvanYashchuk! The underlying issue is that the interpreter needs a lookaside for zip.
This seems like a good issue for someone who wants to take a look at our great Python interpreter (thunder/core/interpreter.py). It's not trivial, but it should be relatively self-contained.
@t-vi what would be a similar lookaside to start from for anyone wanting to approach this?
I would assume that functools.reduce is not bad to look at, especially because of its test coverage (generic iterables, custom iterables, etc.).
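For orientation, here is a pure-Python equivalent of functools.reduce, modeled on the equivalent shown in the Python documentation, together with the kind of generic-iterable and custom-iterable comparisons mentioned above. This is a sketch that assumes a lookaside is essentially a Python re-implementation of a builtin that the interpreter can step through; it is not thunder's actual reduce lookaside, and py_reduce and Countdown are hypothetical names.

```python
import functools
import operator

# Sentinel so that an explicit initializer of None is still honored.
_initial_missing = object()


def py_reduce(function, iterable, initializer=_initial_missing):
    """Pure-Python equivalent of functools.reduce (after the Python docs)."""
    it = iter(iterable)
    if initializer is _initial_missing:
        try:
            value = next(it)
        except StopIteration:
            raise TypeError("reduce() of empty iterable with no initial value") from None
    else:
        value = initializer
    for element in it:
        value = function(value, element)
    return value


class Countdown:
    """A custom (non-sequence) iterable that yields n, n-1, ..., 1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        k = self.n
        while k > 0:
            yield k
            k -= 1


# Compare against the builtin on a list, a one-shot generator and a custom iterable.
# Factories create fresh iterables because a generator can only be consumed once.
for make in (lambda: [1, 2, 3, 4], lambda: (i * i for i in range(5)), lambda: Countdown(4)):
    assert py_reduce(operator.add, make()) == functools.reduce(operator.add, make())
    assert py_reduce(operator.add, make(), 10) == functools.reduce(operator.add, make(), 10)
```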
Doesn't the error message NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> mean that the problem is in the interpretation of next() rather than in zip?
@t-vi To me it looks like the errors are in the unpacking more than in zip. Might it be the opaque ModuleList container?
I would like to participate in solving this issue. I have some knowledge of the ast library and have worked on the Clang Static Analyzer before.
Hi @Nachiket18 ,
great!
So the main issue here seems to be that we would want the Thunder Python Interpreter to be able to "see through" zip (i.e. link the things that are yielded by the zip iteration with the arguments to zip).
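To illustrate what "see through" means here, the snippet below only inspects plain CPython objects. The builtin's __next__ is a C-level slot wrapper with no Python bytecode behind it, which is consistent with the OPAQUE entry in the provenance record above, whereas a generator written in Python has a code object an interpreter can execute step by step. It also suggests why the error names zip's __next__: iterating a zip object goes through that C slot. py_zip_pairs is a hypothetical two-argument helper for illustration only.

```python
import inspect


def py_zip_pairs(a, b):
    """Hypothetical two-argument zip written in Python (illustration only)."""
    ia, ib = iter(a), iter(b)
    while True:
        try:
            x = next(ia)
            y = next(ib)
        except StopIteration:
            # Catch explicitly: inside a generator an escaping StopIteration
            # would be turned into RuntimeError (PEP 479).
            return
        yield x, y


# The builtin: __next__ is implemented in C, so there is no bytecode to step through.
print(zip.__next__)                      # <slot wrapper '__next__' of 'zip' objects>
print(type(zip.__next__).__name__)       # wrapper_descriptor
print(hasattr(zip.__next__, "__code__"))  # False

# The Python version: an ordinary generator function with a code object,
# so every next() call and every yielded tuple happens at the Python level.
print(inspect.isgeneratorfunction(py_zip_pairs))  # True
print(py_zip_pairs.__code__.co_name)              # py_zip_pairs
```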
I think it should be as easy as implementing a zip function in pure Python and putting it in a lookaside.
If you look at def _enumerate_lookaside(obj: Iterable, start: int = 0): in thunder/core/interpreter.py, that should give you a good idea of how to do it. Tests should be added to thunder/tests/test_interpreter.py to check that the new zip gives the same results as the builtin one. The repro @IvanYashchuk has in the issue could go in test_jit_general.py.
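A possible starting point, assuming the lookaside can simply be a Python generator that mirrors builtins.zip: the sketch below follows the pure-Python equivalent given in the Python documentation for zip (including strict=, which exists since Python 3.10), then compares it with the builtin on lists, a generator, a dict, a custom iterable and no arguments, in the spirit of the tests suggested for test_interpreter.py. py_zip and Experts are hypothetical names, and how such a function would be registered next to _enumerate_lookaside is not shown. The identity check reflects the point of the exercise: each yielded element is the very object taken from the inputs.

```python
def py_zip(*iterables, strict=False):
    """Pure-Python sketch of builtins.zip (after the equivalent in the Python docs)."""
    if not iterables:
        return
    iterators = tuple(iter(iterable) for iterable in iterables)
    try:
        while True:
            items = []
            for iterator in iterators:
                items.append(next(iterator))
            yield tuple(items)
    except StopIteration:
        if not strict:
            return
    # strict=True: report which argument ran out early or still has items left.
    if items:
        i = len(items)
        plural = " " if i == 1 else "s 1-"
        raise ValueError(f"zip() argument {i + 1} is shorter than argument{plural}{i}")
    sentinel = object()
    for i, iterator in enumerate(iterators[1:], 1):
        if next(iterator, sentinel) is not sentinel:
            plural = " " if i == 1 else "s 1-"
            raise ValueError(f"zip() argument {i + 1} is longer than argument{plural}{i}")


class Experts:
    """Hypothetical stand-in for a container such as nn.ModuleList."""

    def __init__(self, items):
        self._items = list(items)

    def __iter__(self):
        return iter(self._items)


# Same results as the builtin across several kinds of iterables.
# Factories are used because generators can only be consumed once.
for make in (
    lambda: ([1, 2, 3], "abc"),
    lambda: ([1, 2, 3], "ab", (10, 20, 30)),
    lambda: ((i for i in range(4)), {"x": 1, "y": 2}),
    lambda: (Experts("pqr"), range(5)),
    lambda: (),
):
    assert list(py_zip(*make())) == list(zip(*make()))

# The yielded elements are the original objects, not copies.
masks = ["m0", "m1", "m2"]
experts = Experts(object() for _ in range(3))
pairs = list(py_zip(masks, experts))
assert all(got is want for (_, got), want in zip(pairs, experts._items))

# strict=True reports the mismatch the way CPython's zip does.
try:
    list(py_zip([1, 2, 3], "ab", strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1
```

One difference to keep in mind: the builtin returns a zip object, while this sketch returns a generator, so code that checks isinstance(obj, zip) would behave differently.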
Please let me know if there is anything else you need to get started. Also don't hesitate to reach out if you find you're stuck somewhere.
Hello @t-vi, thanks for the explanation, and sorry for the late response; we were busy with work and school.
@UltraArceus3 and I wrote some code on our fork, along with test cases, and the tests pass:
https://github.com/Nachiket18/lightning-thunder/commit/bc1fafe411b8f52ad9ce533a638ee01d6fd908e6
We would like to know your thoughts on it.
@t-vi I was wondering if you could review the code we wrote. Let us know if we are making any mistakes that need correcting.
Looks very reasonable, would love to see a PR.