Support for torchvision models, e.g., a simple ViT
## 🐛 Bug
I was trying to run a simple torchvision ViT and am getting the following error:
```
  File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 136, in <module>
    train(
  File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 31, in train
    logits = model(features)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 194, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 611, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 262, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 498, in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 175, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/jit_ext.py", line 1386, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6580, in fn_
    raise e
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6543, in fn_2
    return fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(x, x, x, need_weights=False)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1236, in forward
    any_nested = query.is_nested or key.is_nested or value.is_nested
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 1253, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/proxies.py", line 1234, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method is_nested
```
I'm not sure how to go about debugging this, but I thought that sharing it might help improve Thunder's support for more models and edge cases.
## To Reproduce

Steps to reproduce the behavior (I attached self-contained code in the zip):

```sh
# Runs PyTorch eager, works ok!
python 01_pytorch-vit.py

# Runs torch.compile, works ok!
python 01_pytorch-vit.py --compilation_option "torch.compile"

# Runs thunder.jit(), fails! (See error above)
python 01_pytorch-vit.py --compilation_option "thunder_default"
```

## Code sample

See the zip attached.
## Expected behavior

Either a clearer error message, or ideally it should just work :)

## Environment

Same as the Zero to Thunder studio.
cc @apaz-cli
Hey Seb! @nikitaved just merged a PR to improve the messaging here: #78
The TLDR is that you want to run examine on the model to get a report of what's not working:

```python
from thunder.examine import examine

x = ...
model = ...
examine(model, x)
```
It would be useful if you can include here what it reports for those models.
This is nice, thanks! The report is:

```
Files already downloaded and verified
Found 18 distinct operations, of which 15 (83.3%) are supported
Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new
    TensorBase.is_nested
    multi_head_attention_forward of torch.nn.functional
    _assert of torch
```
So the culprit seems to be https://github.com/pytorch/pytorch/blob/1e8d4b389b5f03cea191ed558051f036fe04f92d/torch/nn/functional.py#L5163
Triage review: we think there are three issues here.

First issue:
- We should improve the error messages for attribute access on tensors to make it clear that Thunder does not support that attribute (yet).
- Maybe we should maintain a list of the attributes on the torch tensor object that are not yet supported by Thunder.
Second issue:
- We could start by always setting `is_nested` to `False` and ensuring that programs passed nested tensors fail.
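The idea in the second bullet can be sketched in plain Python. This is an illustrative stub only, not Thunder's actual proxy classes or registration machinery; the names `FakeTensorProxy` and `make_proxy` are hypothetical.

```python
# Illustrative sketch: a stand-in tensor proxy that hard-codes is_nested
# to False, paired with an entry check that rejects genuine nested
# tensors before tracing ever starts.
class FakeTensorProxy:
    def __init__(self, shape):
        self.shape = shape

    @property
    def is_nested(self):
        # Traced programs may query the attribute freely; the answer is
        # always False because nested inputs never reach tracing.
        return False


def make_proxy(t):
    # Reject real nested tensors up front, before building a proxy.
    if getattr(t, "is_nested", False):
        raise NotImplementedError("nested tensors are not supported yet")
    return FakeTensorProxy(getattr(t, "shape", None))
```

With a scheme like this, the `query.is_nested` check inside `multi_head_attention_forward` would evaluate to `False` during tracing instead of raising an `AttributeError`.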
Third issue:
- There's the question of supporting nested tensors (see https://pytorch.org/docs/stable/nested.html).
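For context on that third point, `is_nested` is what distinguishes regular tensors from nested (ragged) ones in eager PyTorch. A quick check, assuming a recent torch build with the prototype `torch.nested` API:

```python
import torch

# Regular tensors report is_nested == False.
x = torch.randn(2, 3)
print(x.is_nested)  # False

# Nested tensors (ragged batches of tensors with differing shapes)
# report is_nested == True.
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
print(nt.is_nested)  # True
```

It is exactly this attribute access that Thunder's tensor proxy could not resolve in the traceback above.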
Can we break this issue up into those three, @rasbt?
This sounds totally reasonable; please feel free to break it up into these three.
Re the first issue: I'm not sure whether it's feasible, but perhaps even automatically calling examine upon failure would not be a bad thing for users.
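The auto-examine idea could be prototyped as a generic wrapper. A minimal sketch with injected callables (`compile_fn` standing in for `thunder.jit`, `examine_fn` for `thunder.examine.examine`), so the pattern itself needs no Thunder internals:

```python
def compile_with_diagnostics(compile_fn, examine_fn, model, *args):
    """Run the compiled model; on failure, attach an examine() report.

    compile_fn and examine_fn are passed in to keep the sketch
    library-agnostic; in Thunder they would be thunder.jit and
    thunder.examine.examine.
    """
    compiled = compile_fn(model)
    try:
        return compiled(*args)
    except Exception as exc:
        # Re-run examine on the original model so the user sees which
        # operations are unsupported, not just the raw failure.
        report = examine_fn(model, *args)
        raise RuntimeError(
            f"compiled model failed; examine report:\n{report}"
        ) from exc
```

Whether this should be opt-in or the default is a UX question, since examine re-runs the model and is not free.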
As of today we are able to run ResNet and vit_b_16 (at least), thanks to #584 and #633. :tada: