Operator support for `F.one_hot`
🐛 Bug
Thunder fails when attempting to compile a graph containing `torch.nn.functional.one_hot` in the forward pass.
The error message indicates that the `input` argument must be a `Tensor`, but a `TensorProxy` is received instead.
To Reproduce
Steps to reproduce the behavior:
- Define a PyTorch model class with a forward pass involving `F.one_hot` to convert the input tensor to a one-hot encoded representation.
- Create an instance of the model and evaluate it on a random input tensor.
- Compile the model using `thunder.jit`.
- Call the compiled model with the same input tensor.
Example
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import thunder


class MLP(nn.Module):
    def __init__(self, hidden_size=1024):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(6 * 256, hidden_size, bias=False)
        self.head = nn.Linear(hidden_size, 32000, bias=False)

    def forward(self, inputs):
        x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
        x = self.hidden(x)
        logits = self.head(x)
        return logits


x = torch.randint(0, 6, (1, 256))
model = MLP(1024).eval()
print(model(x))

model = thunder.jit(model)
print(model(x))
```
Output
```
tensor([[-0.1134, -0.0827, -0.0205,  ...,  0.0757,  0.0066,  0.0974]],
       grad_fn=<MmBackward0>)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-6425e5faad6e> in <cell line: 23>()
     21
     22 model = thunder.jit(model)
---> 23 print(model(x))

16 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509     # type ignore was added because at this point one knows that
   1510     # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511     name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator]  # noqa: B950
   1512     if name:
   1513         tracing_state.push_scope(name)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518     finally:
   1519         if recording_scopes:
-> 1520             tracing_state.pop_scope()
   1521     return result
   1522

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in forward(self, *args, **kwargs)
    192
    193     def forward(self, *args, **kwargs):
--> 194         res = self._forward_fn(*args, **kwargs)
    195         return res
    196

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in fn_(*args, **kwargs)
    609     cs.calls += 1
    610
--> 611     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    612     cs.last_trace_host_execution_start = time.time_ns()
    613

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in cache_info_wrapper(*args, **kwargs)
    260     tok = _cache_info_ctx.set({})
    261     try:
--> 262         res = fn(*args, **kwargs)
    263     finally:
    264         _cache_info_ctx.reset(tok)

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in get_computation_and_inputs(*args, **kwargs)
    496     prologue_trc: TraceCtx
    497     computation_trc: TraceCtx
--> 498     prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    499         fn, args, kwargs, sharp_edges=cd.sharp_edges
    500     )

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in _general_frontend(fn, args, kwargs, sharp_edges)
    173 # Translates the Python function to a thunder program using the thunder interpreter
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
    176
    177

/usr/local/lib/python3.10/dist-packages/thunder/core/jit_ext.py in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1384     with general_jit_ctx(ctx):
   1385         with tracectx(computation_trace):
-> 1386             result = jfn(*args, **kwargs)
   1387     prims.python_return(result)
   1388     process_recorded_modifications(ctx, epilogue_trace)

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in fn_(*args, **kwargs)
   6578         assert isinstance(e, BaseException), e
   6579         runtimectx.curexc = None
-> 6580         raise e
   6581
   6582     return interpretation_result

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in fn_2()
   6541     def getfn():
   6542         def fn_2(args, kwargs):
-> 6543             return fn(*args, **kwargs)
   6544
   6545     return fn_2

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _impl()
   5940
   5941     def _impl(fn, *args, **kwargs):
-> 5942         return fn.__func__(fn.__self__, *args, **kwargs)
   5943
   5944     return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _wrapped_call_impl()
   1509     # type ignore was added because at this point one knows that
   1510     # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511     name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator]  # noqa: B950
   1512     if name:
   1513         tracing_state.push_scope(name)

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _impl()
   5940
   5941     def _impl(fn, *args, **kwargs):
-> 5942         return fn.__func__(fn.__self__, *args, **kwargs)
   5943
   5944     return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _call_impl()
   1518     finally:
   1519         if recording_scopes:
-> 1520             tracing_state.pop_scope()
   1521     return result
   1522

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _impl()
   5940
   5941     def _impl(fn, *args, **kwargs):
-> 5942         return fn.__func__(fn.__self__, *args, **kwargs)
   5943
   5944     return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in forward()
      9
     10     def forward(self, inputs):
---> 11         x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
     12         x = self.hidden(x)
     13         logits = self.head(x)

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)
   6067     kwargs_ = {unwrap(k): unwrap(v) for k, v in kwargs.items()}
   6068     try:
-> 6069         opaque_result: Any = fn(*args_, **kwargs_)
   6070     except Exception as e:
   6071         runtimectx.curexc = e

TypeError: one_hot(): argument 'input' (position 1) must be Tensor, not TensorProxy
```
Environment
- OS: Ubuntu/Google Colab
- Python Version: 3.10
- PyTorch Version: 2.3.0.dev20240314+cu121
- Thunder Version: 0.1.0
- Installation:

```
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
pip install lightning-thunder
```
Additional context

- Other functional methods like `F.relu` don't seem to raise this issue.
Hello @kyo-takano, thank you for trying out Thunder! This is the Thunder way (we should have a FAQ) of saying "this operator is not supported yet in thunder". We are actually working on giving a better error message.
If you want a comprehensive check of whether a model you want to run uses ops we don't support yet, you can use `examine`.
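For instance, a minimal check might look like this sketch (assuming the `examine` entry point from `thunder.examine` and reusing the `MLP` and `x` from the repro above):

```python
import torch
from thunder.examine import examine

# reports, among other things, which operations in the traced
# forward pass are not yet supported by thunder
model = MLP(1024).eval()
x = torch.randint(0, 6, (1, 256))
examine(model, x)
```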
Implementing an operator like `one_hot` is not entirely trivial, but I would guess that it could be a great entry point for an aspiring Thunder developer. :wink: (Reminds me a lot of the early PyTorch days; my first PR, almost 7 years ago, was to get double derivatives for clamp going.) `one_hot` might be decomposed using `scatter_add`.
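For illustration, here is one possible decomposition, a comparison-based alternative to the `scatter_add` route, sketched in plain PyTorch rather than Thunder's internal ops:

```python
import torch

def one_hot_via_eq(a: torch.Tensor, num_classes: int) -> torch.Tensor:
    # broadcasting a trailing singleton dim of `a` against the class
    # indices 0..num_classes-1 yields the one-hot dimension
    classes = torch.arange(num_classes, device=a.device)
    return (a.unsqueeze(-1) == classes).long()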
Thanks @t-vi
This is just a heads-up, as I'd expect to see many issues of the same kind, given that functional operations like this are used everywhere in PyTorch implementations.
This bug does not occur with the native compiler (`torch.compile`), regardless of the chosen backend. So, if `thunder.jit` is intended to be a drop-in replacement for `torch.compile`, it might be necessary to address this issue to ensure a seamless transition for users, depending on the number and significance of such incompatible operations.
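For reference, a minimal sketch of that check (reusing the `MLP` class and input `x` from the repro above):

```python
import torch

# unlike thunder.jit, torch.compile traces F.one_hot without complaint
compiled = torch.compile(MLP(1024).eval())
print(compiled(x))
```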
Thank you @kyo-takano, the limitation is not related to functional operations per se, but to the number of ops we currently support. They will grow quickly, but it'll take a bit.
For this reason (and we'll add a FAQ soon), Thunder is not intended to be a drop-in replacement for torch.compile at this stage, with the same coverage as torch.compile. However, we do focus on:
- making it easy to have a list of the gaps ahead of time: run `examine` (https://github.com/Lightning-AI/lightning-thunder/blob/main/docs/source/fundamentals/examine.rst)
- making it easy to contribute missing ops through a clean API and automatic testing
What model are you targeting specifically?
Thank you @lantiga
> the limitation is [...] the number of ops we currently support

> Thunder is not intended to be a drop-in replacement for torch.compile at this stage
Understood. Thanks for clarifying these points.
> What model are you targeting specifically?
I don't have a particular model I need to compile with thunder right away.
I just thought it would be beneficial to flag an issue about a widely used operation that is currently incompatible.
Hi everyone, I'm new to contributing to open source and very interested in the Thunder project. I noticed this issue and I'd love to take the opportunity to contribute. I have some experience with Python and PyTorch and am eager to learn more about Thunder and PyTorch internals. Could I possibly take on this issue? Any guidance or suggestions on how to get started would be greatly appreciated!
Thank you for the support. I am starting to work on implementing support for the one_hot operator as discussed. I'll keep the thread updated with my progress and any questions I might have.
Could one of the collaborators please assign this issue to me? Thanks!
Awesome, welcome aboard! Looking forward to your contribution. One super-useful thing you could do is record your journey as a first contributor so we can create more onboarding material and address the papercuts.
Hey, @kyo-takano, thank you for the issue!
Hey, @shaharelys, super excited about your interest in helping us out!
A couple of, hopefully, helpful notes:
- As @t-vi pointed out, this operation might not necessarily be trivial to implement, so I would advise familiarizing oneself with how this operation is implemented in PyTorch. Special attention has to be paid to handling wrong and edge-case inputs (a quick probe of PyTorch's edge-case behavior follows after this list).
- Once the behavior of `one_hot` is understood, it is about time to think about implementing it! It is very likely that `one_hot` could be implemented as a composition of other `torch`-like operations, so it does, as we say it, "decompose" into "primitive" operations. So, if your algorithm does depend on `scatter_add`, we can check whether this operation is available to us. I like using `git grep`, and it gives the following:
```
>>> git grep "def scatter_add"
thunder/clang/__init__.py:def scatter_add(a: TensorProxy, /, indices: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
thunder/core/prims.py:def scatter_add_meta(a: TensorProxy, /, index: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
thunder/tests/opinfos.py:def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):
thunder/torch/__init__.py:def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike:
```
- Once all the required "primitive" operations are available (defined in `clang/__init__.py` and/or `core/prims.py` and/or `torch/__init__.py`), all high-level decompositions are implemented in `thunder/torch/__init__.py`, and these try to match the behavior of PyTorch. A good and rather complex example of a non-trivial decomposition could be the `interpolate` operation: https://github.com/Lightning-AI/lightning-thunder/blob/be16444471352eeb1bd11c468efecb2c3b91e5f1/thunder/torch/__init__.py#L3534
- Never forget to always test things! The grep from above contains the `thunder/tests/opinfos.py:def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):` line. It defines a generator that produces the inputs the `scatter_add` function is being tested on. You could have a look at the other entries for `scatter_add` in the file to figure out all the remaining parts which are necessary to get the testing going. Once the code is filled in, you can run the `one_hot`-related tests as `pytest -sv thunder/tests/test_ops.py -k one_hot`.
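As a concrete reference for the edge cases mentioned in the first note, here is a quick probe of PyTorch's current `F.one_hot` behavior (plain PyTorch, independent of Thunder; outputs are what I would expect from current semantics):

```python
import torch
import torch.nn.functional as F

# a scalar (0-dim) input is accepted and yields a 1-D result
print(F.one_hot(torch.tensor(2), num_classes=4))  # tensor([0, 0, 1, 0])

# with the default num_classes=-1, the class count is inferred as max(input) + 1
print(F.one_hot(torch.tensor([0, 2])).shape)      # torch.Size([2, 3])

# negative class values are rejected at runtime
try:
    F.one_hot(torch.tensor([-1]), num_classes=3)
except RuntimeError as e:
    print(e)
```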
If you have any questions, do not hesitate to ping me! Have fun and thank you!
@nikitaved, thank you for the detailed guidance!
- I will dive into PyTorch's implementation to understand `one_hot` more deeply.
- As for the mention of `def scatter_add`: as I'm still familiarizing myself with the project's structure, its role and how it aids in achieving our goal with `one_hot` isn't fully clear to me yet.
- Could you recommend specific areas of the codebase or modules as good starting points? From your comment, I gather that `clang/__init__.py`, `core/prims.py`, and `torch/__init__.py` are crucial. Are there other parts I should focus on?
- Am I correct in understanding that the approach you are suggesting is to decompose `one_hot` into a set of primitive operations already supported by Thunder?
Also, I'd like to know the typical timeline for resolving issues like this. Recognizing that I'm new and still catching up, I want to set realistic expectations for my contribution timeline. What is usually anticipated?
Thank you for your support!
> @nikitaved, thank you for the detailed guidance!
Always happy to help!
> Could you recommend specific areas of the codebase or modules as good starting points? From your comment, I gather that `clang/__init__.py`, `core/prims.py`, and `torch/__init__.py` are crucial. Are there other parts I should focus on?
These should be more than enough to start with operations.
I would recommend specifically focusing on `thunder/torch/__init__.py` for now. `prims` and `clang` could be optional.
Also have a look at the sample input generators in `thunder/tests/opinfos.py`.
> Am I correct in understanding that the approach you are suggesting is to decompose `one_hot` into a set of primitive operations already supported by Thunder?
You are correct! This simplifies things a lot!
> Also, I'd like to know the typical timeline for resolving issues like this. Recognizing that I'm new and still catching up, I want to set realistic expectations for my contribution timeline. What is usually anticipated?
Take as much time as you need. The outcome may turn out to be just as simple as adding a `gelu`: https://github.com/Lightning-AI/lightning-thunder/blob/f2e48b39eda18713439891da9fcbf0338e86e4d9/thunder/torch/__init__.py#L1182
In fact, if you think that we are missing some simple activation functions, you could open an issue, work on them, and then return to `one_hot` once you are comfortable with the codebase.
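To make "decomposes into primitives" concrete, here is roughly what such a simple op looks like as a composition of elementwise operations, written as a plain-PyTorch reference (the actual Thunder definition lives in `thunder/torch/__init__.py` and is written against Thunder's own ops):

```python
import math
import torch

def gelu_reference(a: torch.Tensor) -> torch.Tensor:
    # exact (erf-based) gelu, built only from elementwise primitives:
    # gelu(a) = a * 0.5 * (1 + erf(a / sqrt(2)))
    return a * 0.5 * (1.0 + torch.erf(a / math.sqrt(2.0)))
```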
> Awesome, welcome aboard! Looking forward to your contribution. One super-useful thing you could do is record your journey as a first contributor so we can create more onboarding material and address the papercuts.
Hey @lantiga, per your request, I will try to list things that took me some time to understand and where I think better onboarding material could help. For me, the easiest way is to just write these here under this thread from time to time. Will that work for you guys?
Just note that I am very new to open source and still a junior, so some of my comments might be due to my lack of experience rather than a lack of onboarding material. Anyway, for me it is hard to tell the difference, so this is just a little heads-up (:
Here are my initial comments (some are also questions). Context: I had my first review today over `core/prims`, `core/symbol`, `clang/__init__`, and `thunder/torch/__init__`.
- At the high level, I'd love to get some insight into the idea of Thunder. Like, how could you possibly make torch faster?
- Should all PyTorch primitives also be listed under Thunder primitives? (After a short glimpse into the PyTorch prims file I guess the answer is no, but why?)
- Could you clarify for me the roles of `clang`, `prims`, and `symbol`?
- What is the role of `TensorProxy` as opposed to `Tensor` and other `Proxy` types?
- Generally, a bit more documentation could help. A good example is from `clang/__init__`, where it states '# This file defines the operations in thunder.jit's "core" language.' As for other modules, it took some time for me to understand that the 'header' comment is actually referring to the whole file.
Hey, @shaharelys ! Congrats on your hardswish PR! Do you need more help with this one? I can expand more on your questions posted above in your next PR. What do you think?
Hey @nikitaved! Actually, I think I'm handling this for now. We will likely have the PR for this ready soon, I believe.
But since you asked, just a quick question; I've already got the basic structure set up:
```python
import torch
from torch import Tensor

def one_hot(tensor: Tensor, num_classes: int = -1) -> Tensor:
    if num_classes == -1:
        # infer the class count from the largest value in the input
        num_classes = int(torch.max(tensor)) + 1
    # scatter ones into a zero canvas along a new trailing dimension
    canvas = torch.zeros(*tensor.shape, num_classes, dtype=torch.long)
    index = tensor.unsqueeze(-1)
    src = torch.ones_like(index, dtype=torch.long)
    return canvas.scatter_add_(dim=-1, index=index, src=src)
```
I could find all the needed methods as building blocks, but not `torch.max`.
How would you go about this?
@shaharelys I guess we should have argmax.
@shaharelys, I also suspect that we might have issues with `num_classes == -1` because we need to get `max.item` to be able to implement that. Or do we properly proxify scalars now, @mruberry?
We can still implement this function at least partially and file an issue for potential extensions.
> @shaharelys, I also suspect that we might have issues with `num_classes == -1` because we need to get `max.item` to be able to implement that. Or do we properly proxify scalars now, @mruberry? We can still implement this function at least partially and file an issue for potential extensions.
I believe the latest is that we treat numeric inputs as constants, and `NumberProxy`s should only be generated for numbers whose value is only determined at runtime (like those coming from calls to `.item()`). If that's not the case, there's probably a bug.
`num_classes == -1` is an issue because the shape of the output tensor would only be determined at runtime, which we do not support. We could look at extending the `one_hot` function in the future to support a shape parameter, like JAX's `nonzero` (it calls the parameter `size`), but I'm not immediately sure how that would work, and I agree with you that we should just limit the functionality for now and raise a `NotImplementedError` if `num_classes == -1`.
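Putting the thread's conclusion together, a restricted version might look like this plain-PyTorch sketch (hypothetical; the actual Thunder decomposition would be written against Thunder's `scatter_add` and `TensorProxy` types in `thunder/torch/__init__.py`):

```python
import torch
from torch import Tensor

def one_hot_restricted(a: Tensor, num_classes: int = -1) -> Tensor:
    if num_classes == -1:
        # inferring the class count needs max(a).item(), a runtime value,
        # which would make the output shape data-dependent under tracing
        raise NotImplementedError("num_classes=-1 is not supported")
    canvas = torch.zeros(*a.shape, num_classes, dtype=torch.long)
    index = a.unsqueeze(-1)
    return canvas.scatter_add_(-1, index, torch.ones_like(index))
```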