Support `torch.where(condition)` with thunder.jit

Open: carmocca opened this issue 1 year ago • 12 comments

🚀 Feature

Motivation

Mixtral uses it: https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/moe_one_file_ref.py#L215

Minimal Repro

import torch
import thunder

def fn(cond):
    return torch.where(cond)

thunder.jit(fn)(torch.randn(3) > 0)

Pitch

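Support the single-argument overload torch.where(condition) documented at https://pytorch.org/docs/stable/generated/torch.where.html, which is identical to torch.nonzero(condition, as_tuple=True).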
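For reference, the PyTorch documentation describes torch.where(condition) as identical to torch.nonzero(condition, as_tuple=True). A small sketch checking that equivalence in eager PyTorch:

```python
import torch

cond = torch.tensor([[True, False, True],
                     [False, True, False]])

# Single-argument overload: returns one index tensor per dimension of `cond`,
# listing the True positions in row-major order.
rows, cols = torch.where(cond)

# Documented equivalent.
ref_rows, ref_cols = torch.nonzero(cond, as_tuple=True)

print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([0, 2, 1])
```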

Additional context

We already support torch.where(condition, input, other): https://github.com/search?q=repo%3ALightning-AI%2Flightning-thunder+%22def+where%22&type=code

cc @apaz-cli

carmocca, Apr 03 '24

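As of now, we cannot support data-dependent ops, alas: the output shape of torch.where(condition) depends on how many elements of condition are True, not just on the shape of condition.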
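To make the data dependence concrete: two inputs with the same shape can produce outputs of different lengths, so the output shape cannot be derived from input metadata alone. Computing it requires reading the values, which on a GPU means a device-to-host synchronization. A small PyTorch sketch (the helper name count_selected is illustrative only):

```python
import torch


def count_selected(cond: torch.Tensor) -> int:
    # For a 1-D condition, torch.where(cond) returns a 1-tuple of indices.
    (idx,) = torch.where(cond)
    return idx.shape[0]


a = torch.tensor([0.5, -1.0, 2.0])
b = torch.tensor([-0.5, -1.0, -2.0])

# Same input shape (3,), but the outputs have shapes (2,) and (0,).
print(count_selected(a > 0))  # 2
print(count_selected(b > 0))  # 0
```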

nikitaved, Apr 03 '24

@carmocca, looking at the code, I think the solution could be to modify the model in the package. The result of topk can be sorted, and then we do not need to apply where at all. This would also eliminate the device syncs caused by where.

nikitaved, Apr 03 '24

@nikitaved Faster and better code is very welcome in LitGPT. I benchmarked a few different implementations when this was added and this came out to be the best in general (see description and discussion in https://github.com/Lightning-AI/litgpt/pull/823). It would be useful to see them compared to whatever you propose.
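As one comparison point, here is a minimal sketch of a dense formulation that avoids torch.where(condition) entirely: every expert runs on every token, and tokens that did not route to an expert get a weight of zero. This is a swapped-in technique, not the sort-based approach proposed above; it trades roughly n_expert / top_k times more expert compute for fully static shapes. The routing loop is modeled on the LLaMAMoE.forward code visible in the traceback further down this thread; the toy sizes and variable names are assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy sizes, chosen arbitrarily for illustration.
n_tokens, n_embd, n_expert, top_k = 6, 4, 3, 2
gate = nn.Linear(n_embd, n_expert, bias=False)
experts = nn.ModuleList([nn.Linear(n_embd, n_embd, bias=False) for _ in range(n_expert)])
x = torch.randn(n_tokens, n_embd)

with torch.no_grad():
    probs, indices = torch.topk(gate(x), top_k)  # (n_tokens, top_k)
    probs = probs.softmax(dim=1)
    # masks[e, t, k] is True when token t chose expert e in its k-th slot.
    masks = (indices.unsqueeze(-1) == torch.arange(n_expert)).permute(2, 0, 1)


@torch.no_grad()
def route_with_where() -> torch.Tensor:
    y = torch.zeros_like(x)
    for mask, expert in zip(masks, experts):
        token_idx, slot_idx = torch.where(mask)  # length depends on the data
        y[token_idx] += probs[token_idx, slot_idx, None] * expert(x[token_idx])
    return y


@torch.no_grad()
def route_dense() -> torch.Tensor:
    y = torch.zeros_like(x)
    for mask, expert in zip(masks, experts):
        # Per-token weight for this expert; 0 where the token did not pick it.
        weight = (probs * mask).sum(dim=-1, keepdim=True)  # (n_tokens, 1)
        y = y + weight * expert(x)  # every expert sees every token
    return y
```

Whether the extra compute pays off against the synchronization cost depends on the model and hardware; it would need to be benchmarked alongside the variants from the linked PR.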

carmocca, Apr 03 '24

The error message is unfriendly and doesn't say that torch.where(condition) is unsupported:

In [1]: import torch

In [2]: import thunder

In [3]: from litgpt import Config

In [4]: from litgpt.model import LLaMAMoE

In [5]: config = Config.from_name("Mixtral-8x7B-v0.1")

In [6]: model = LLaMAMoE(config).to(dtype=torch.bfloat16, device="cuda")

In [7]: jit_model = thunder.jit(model)

In [8]: x = torch.randn(2, config.block_size, config.n_embd, dtype=torch.bfloat16, device="cuda")

In [9]: jit_model(x);

Traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 jit_model(x);

File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/__init__.py:194, in ThunderModule.forward(self, *args, **kwargs)
    193 def forward(self, *args, **kwargs):
--> 194     res = self._forward_fn(*args, **kwargs)
    195     return res

File ~/dev/lightning-thunder/thunder/__init__.py:629, in jit.<locals>.fn_(*args, **kwargs)
    626 cs.last_trace_host_start = time.time_ns()
    627 cs.calls += 1
--> 629 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    630 cs.last_trace_host_execution_start = time.time_ns()
    632 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:262, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    260 tok = _cache_info_ctx.set({})
    261 try:
--> 262     res = fn(*args, **kwargs)
    263 finally:
    264     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:504, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    502     prologue_trc: TraceCtx
    503     computation_trc: TraceCtx
--> 504     prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    505         fn, args, kwargs, sharp_edges=cd.sharp_edges
    506     )
    508 if maybe_epilogue:
    509     epilogue_traces = maybe_epilogue

File ~/dev/lightning-thunder/thunder/__init__.py:175, in _general_frontend(fn, args, kwargs, sharp_edges)
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1440, in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1438 with general_jit_ctx(ctx):
   1439     with tracectx(computation_trace):
-> 1440         result = jfn(*args, **kwargs)
   1441         prims.python_return(result)
   1442         process_recorded_modifications(ctx, epilogue_trace)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6684, in interpret.<locals>.fn_(*args, **kwargs)
   6682     assert isinstance(e, BaseException), e
   6683     runtimectx.curexc = None
-> 6684     raise e
   6686 return interpretation_result

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6647, in interpret.<locals>.fn_.<locals>.getfn.<locals>.fn_2()
   6646 def fn_2(args, kwargs):
-> 6647     return fn(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl()
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl()
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/litgpt/litgpt/model.py:347, in LLaMAMoE.forward()
    345 y = torch.zeros_like(x)  # (B*T, C)
    346 for mask, expert in zip(masks, self.experts):
--> 347     token_idx, expert_idx = torch.where(mask)
    348     y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
    349 return y.view(B, T, C)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:1258, in interpreter_needs_wrap.<locals>.wrapping_wrapper(*args, **kwargs)
   1255     ukwargs = kwargs
   1257 try:
-> 1258     res = ufn(*uargs, **ukwargs)
   1260     # If result is a WrappedValue, we trust its provenance record
   1261     if isinstance(res, WrappedValue):

File ~/dev/lightning-thunder/thunder/core/symbol.py:250, in Symbol.__call__(self, *args, **kwargs)
    248 else:
    249     trace.push_scope(subsymbols)
--> 250     result = self.meta(*args, **kwargs)
    251     trace.pop_scope()
    253 bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)

File ~/dev/lightning-thunder/thunder/core/langctxs.py:124, in langctx.__call__.<locals>._fn(*args, **kwargs)
    122 try:
    123     tok = set_langctx(self.langctx)
--> 124     result = fn(*args, **kwargs)
    125     return result
    126 finally:

TypeError: where() missing 2 required positional arguments: 'a' and 'b'

IvanYashchuk, Apr 09 '24

@IvanYashchuk, it looks like we should update the meta function for where to handle this overload. To be frank, I did not even know it existed...

This might be a very nice issue for external contributors.
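A minimal sketch of the argument handling such an update would need, written in plain PyTorch rather than thunder's meta-function machinery. The name where_dispatch is hypothetical and not a thunder API. The one-argument branch is the case that currently fails with the "missing 2 required positional arguments" TypeError; a tracer that cannot represent data-dependent shapes could raise an explicit NotImplementedError there instead of delegating to nonzero.

```python
import torch


def where_dispatch(condition, a=None, b=None):
    """Hypothetical front end covering both torch.where overloads (not thunder code)."""
    if a is None and b is None:
        # torch.where(condition): one index tensor per dimension, whose length
        # depends on the values in `condition`. A tracer without support for
        # data-dependent shapes could raise NotImplementedError here instead.
        return torch.nonzero(condition, as_tuple=True)
    if a is None or b is None:
        # A clear message instead of a generic "missing argument" error.
        raise TypeError("where() takes either (condition) or (condition, input, other)")
    # torch.where(condition, input, other): elementwise select with a static shape.
    return torch.where(condition, a, b)
```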

nikitaved, Apr 09 '24

triage review:

  • can the call to torch.where(condition) in Mixtral use a hypothetical shape parameter of nonzero to make the output shape known at compile time?
  • we should implement nonzero(..., shape=...)

mruberry, Apr 15 '24

nonzero doesn't have a shape= argument. Did you mean as_tuple=?

carmocca, Apr 25 '24

We were referring to a parameter analogous to the size parameter of jax.numpy.nonzero:

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html
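A minimal sketch of what such a fixed-size variant could look like for a 1-D boolean tensor, in the spirit of jax.numpy.nonzero(..., size=..., fill_value=...). The function nonzero_static_1d is hypothetical (not a PyTorch or thunder API). It uses only operations whose output shapes depend on the input shape and on size, never on the values: the three-argument torch.where, sort, cat, and slicing.

```python
import torch


def nonzero_static_1d(cond: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical fixed-size nonzero for a 1-D bool tensor.

    Returns exactly `size` indices: the first `size` True positions in order,
    padded with `fill_value` when there are fewer True entries.
    """
    n = cond.shape[0]
    positions = torch.arange(n, device=cond.device)
    # True positions keep their index; False positions are pushed past the end.
    keys = torch.where(cond, positions, positions + n)
    # Pad so that slicing to `size` is valid even when size > n.
    pad = torch.full((max(size - n, 0),), 2 * n, dtype=keys.dtype, device=cond.device)
    keys = torch.sort(torch.cat([keys, pad])).values[:size]
    # Any key >= n is not a real True position, so replace it with the fill value.
    return torch.where(keys < n, keys, torch.full_like(keys, fill_value))


print(nonzero_static_1d(torch.tensor([False, True, True, False, True]), size=4))
# tensor([ 1,  2,  4, -1])
```

The padding is the price of the static shape: downstream code has to ignore entries equal to fill_value. In the MoE loop discussed here, each token selects a given expert at most once, so the number of tokens is an upper bound that could serve as size.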

mruberry, Apr 25 '24

This prototype adds support for torch.where(boolean_tensor): https://github.com/Lightning-AI/lightning-thunder/pull/303

IvanYashchuk, May 01 '24

@IvanYashchuk @kshitij12345 Is this still unsupported when using ThunderFX? If torch.where(condition) is supported when using ThunderFX (because the operator is sent to PyTorch for execution?), then maybe we can close or amend this issue to refer more specifically to using torch.where(condition) with the Thunder interpreter as the entrypoint?

mruberry, Oct 29 '24

torch.where(condition) works on the ThunderFX path because it is sent to PyTorch for execution. We also have a test for it:

https://github.com/Lightning-AI/lightning-thunder/blob/b28d5b3536e60fb0b30896bdd4df6e288cf6a5c8/thunder/tests/test_dynamo.py#L346-L349

I will update the issue title to reflect that the request is for torch.where(condition) support in the thunder.jit entrypoint.

kshitij12345, Oct 29 '24

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot], Apr 16 '25