lightning-thunder
Transformer Engine (TE) executor + CUDA graphs
🐛 Bug
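Compiling a model with the Transformer Engine executor and CUDA graphs enabled (`use_cudagraphs=True`) is not supported: the first call of the jitted module fails with `KeyError: 'scaling_fwd'`.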
To Reproduce
Code sample
import torch
import thunder
class Module(torch.nn.Module):
    """Minimal module wrapping a single ``torch.nn.Linear`` layer.

    Used as the smallest model that reproduces the TE + CUDA graphs failure.
    """

    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to ``x``."""
        return self.linear(x)
# Create the model and its input directly on the GPU.
with torch.device('cuda'):
    m = 1
    in_features = 4096 * m
    out_features = 4096 * m
    model = Module(in_features, out_features)
    x = torch.randn(768, in_features, requires_grad=True)

# Compile with the Transformer Engine executor and CUDA graphs enabled.
jmodel_def = thunder.jit(model, executors=['transformer_engine'], use_cudagraphs=True)
# This first call is where the traceback below ends in KeyError: 'scaling_fwd'.
y = jmodel_def(x)
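Expected behaviour
The jitted model should compile and run. Instead, the call `y = jmodel_def(x)` fails with the following traceback: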
Traceback (most recent call last):
File "/workspace/workdir/examples/dev/te.py", line 32, in <module>
y = jmodel_def(x)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/workdir/thunder/core/module.py", line 63, in forward
res = self._forward_fn(*args, **kwargs)
File "/workspace/workdir/thunder/__init__.py", line 781, in fn_
result = cache_entry.computation_fn(*inps)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "thunder.augmented_forward_fn_3", line 12, in augmented_forward_fn
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/workdir/thunder/executors/transformer_engineex.py", line 212, in forward
weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
File "/workspace/workdir/thunder/executors/transformer_engineex.py", line 293, in get_fp8_weight_version_compat
weight_fp8 = self.get_fp8_workspace(
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 965, in get_fp8_workspace
out.cast_transpose_(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/float8_tensor.py", line 732, in cast_transpose_
fp8_meta = self._fp8_meta[fp8_meta_key]
KeyError: 'scaling_fwd'
Environment
- PyTorch Version (e.g., 1.0): 2.5.0a0+gitb0fc6aa
- Thunder: f9dbf9ce6e2b9c3e6885a8ef84d3ffcadbff5f87
- OS (e.g., Linux): Linux
- Python version: 3.10.12
- CUDA/cuDNN version: 12.6
- GPU models and configuration: RTX ADA 6000
- Any other relevant information: Tested on NVIDIA internal docker containers