cuBLAS Error
With PyTorch 2.5 and Transformer Engine 2.1, I run the following original network structure:
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

    def forward(self, x, freqs):
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        x = flash_attention(
            q=rope_apply(q, freqs, self.num_heads),
            k=rope_apply(k, freqs, self.num_heads),
            v=v,
            num_heads=self.num_heads,
        )
        return self.o(x)
Note that at this point I have already swapped the modules in this network for their Transformer Engine (te.*) equivalents, roughly as sketched below.
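For reference, this is roughly what the swapped module looks like (a minimal sketch of the change; te.Linear and te.RMSNorm are the Transformer Engine counterparts, everything else is unchanged):

import transformer_engine.pytorch as te

class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Transformer Engine drop-in replacements for the torch modules
        self.q = te.Linear(dim, dim)
        self.k = te.Linear(dim, dim)
        self.v = te.Linear(dim, dim)
        self.o = te.Linear(dim, dim)
        self.norm_q = te.RMSNorm(dim, eps=eps)
        self.norm_k = te.RMSNorm(dim, eps=eps)

    # forward() is identical to the original above

With this change, the first te.Linear call in forward fails with the following error: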
[rank1]:   File "/output/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 128, in forward
[rank1]:     q = self.norm_q(self.q(x))
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/module/linear.py", line 1085, in forward
[rank1]:     out = linear_fn(*args)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/module/linear.py", line 231, in forward
[rank1]:     out, *_, rs_out = general_gemm(
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 141, in general_gemm
[rank1]:     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
[rank1]: RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:282 in function cublas_gemm: cuBLAS Error: an unsupported value or parameter was passed to the function
Epoch 0: 0%| | 0/63 [00:04<?, ?it/s]
But when my PyTorch is 2.4 and my Transformer Engine is 1.3, it does not report any errors; instead it hangs, with memory usage frozen and no training progress, as if the process has stalled.
I am also running into this issue with v2.1, whereas I did not have this problem in v2.0. I am using CUDA 12.4, cuDNN 9.0.0.
Here is a simple snippet to reproduce the issue:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16

model = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    attn_input_format="bshd",
    params_dtype=dtype,
).cuda()

x = torch.rand(batch_size, sequence_length, hidden_size, dtype=dtype).cuda()

# Explicit FP8 recipe (fp8_autocast defaults to DelayedScaling anyway)
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = model(x).sum()
y.backward()
print("Done")
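For completeness, the relevant versions can be collected with a few standard calls (a quick sketch; importlib.metadata reads the installed package's version, and torch reports the CUDA/cuDNN it was built against):

import torch
from importlib.metadata import version

print("torch:", torch.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("transformer-engine:", version("transformer-engine"))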
Environment
(llm) pvalois@pegasus02:~/llm$ uv pip list
Package Version
----------------------------- --------------
absl-py 2.1.0
accelerate 1.5.2
aiohappyeyeballs 2.6.1
aiohttp 3.11.14
aiosignal 1.3.2
alabaster 1.0.0
antlr4-python3-runtime 4.9.3
anyio 4.9.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asttokens 3.0.0
async-lru 2.0.5
attrs 25.3.0
babel 2.17.0
beautifulsoup4 4.13.3
bitsandbytes 0.44.1
black 25.1.0
bleach 6.2.0
boto3 1.37.16
botocore 1.37.16
certifi 2025.1.31
cffi 1.17.1
cfgv 3.4.0
chardet 5.2.0
charset-normalizer 3.4.1
click 8.1.8
colorama 0.4.6
comm 0.2.2
contourpy 1.3.1
curio 1.6
cycler 0.12.1
dataproperty 1.1.0
datasets 3.4.1
debugpy 1.8.13
decorator 5.2.1
deepspeed 0.9.3
defusedxml 0.7.1
dill 0.3.8
distlib 0.3.9
docrepr 0.2.0
docstring-parser 0.16
docutils 0.21.2
einops 0.8.1
evaluate 0.4.3
exceptiongroup 1.2.2
executing 2.2.0
fastapi 0.115.11
fastjsonschema 2.21.1
filelock 3.18.0
fire 0.7.0
flash-attn 2.7.3
fonttools 4.56.0
fqdn 1.5.1
frozenlist 1.5.0
fsspec 2024.12.0
grpcio 1.71.0
h11 0.14.0
hf-transfer 0.1.9
hjson 3.1.0
httpcore 1.0.7
httptools 0.6.4
httpx 0.28.1
huggingface-hub 0.29.3
hydra-core 1.3.2
identify 2.6.9
idna 3.10
imagesize 1.4.1
importlib-metadata 8.6.1
importlib-resources 6.5.2
iniconfig 2.1.0
intersphinx-registry 0.2501.23
ipykernel 6.29.5
ipyparallel 9.0.1
ipython 8.34.0
ipywidgets 8.1.5
isoduration 20.11.0
jedi 0.19.2
jinja2 3.1.6
jmespath 1.0.1
joblib 1.4.2
json5 0.10.0
jsonargparse 4.32.1
jsonlines 4.0.0
jsonpointer 3.0.0
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter-client 8.6.3
jupyter-core 5.7.2
jupyter-events 0.12.0
jupyter-lsp 2.2.5
jupyter-server 2.15.0
jupyter-server-terminals 0.5.3
jupyterlab 4.3.6
jupyterlab-pygments 0.3.0
jupyterlab-server 2.27.3
jupyterlab-widgets 3.0.13
kiwisolver 1.4.8
lightning 2.5.0.post0
lightning-thunder 0.2.1
lightning-utilities 0.14.1
litdata 0.2.17
litgpt 0.5.7
litserve 0.2.4
lm-eval 0.4.8
loguru 0.7.3
looseversion 1.3.0
lxml 5.3.1
markdown 3.7
markdown-it-py 3.0.0
markupsafe 3.0.2
matplotlib 3.10.1
matplotlib-inline 0.1.7
mbstrdecoder 1.1.4
mdurl 0.1.2
mistune 3.1.3
more-itertools 10.6.0
mpi4py 4.0.3
mpmath 1.3.0
multidict 6.2.0
multiprocess 0.70.16
mypy-extensions 1.0.0
nbclient 0.10.2
nbconvert 7.16.6
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.4.2
ninja 1.11.1.3
nltk 3.9.1
nodeenv 1.9.1
notebook 7.3.3
notebook-shim 0.2.4
numexpr 2.10.2
numpy 1.26.4
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
omegaconf 2.3.0
opt-einsum 3.4.0
optree 0.14.1
outcome 1.3.0.post0
overrides 7.7.0
packaging 24.2
pandas 2.2.3
pandocfilters 1.5.1
parso 0.8.4
pathspec 0.12.1
pathvalidate 3.2.3
peft 0.15.0
pexpect 4.9.0
pickleshare 0.7.5
pillow 11.1.0
platformdirs 4.3.7
pluggy 1.5.0
portalocker 3.1.1
pre-commit 4.2.0
prometheus-client 0.21.1
prompt-toolkit 3.0.50
propcache 0.3.0
protobuf 6.30.1
psutil 7.0.0
ptyprocess 0.7.0
pure-eval 0.2.3
py-cpuinfo 9.0.0
pyarrow 19.0.1
pybind11 2.13.6
pycparser 2.22
pydantic 1.10.21
pygments 2.19.1
pyparsing 3.2.1
pytablewriter 1.2.1
pytest 8.3.5
pytest-asyncio 0.21.2
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-json-logger 3.3.0
pytorch-lightning 2.5.1
pytz 2025.1
pyyaml 6.0.2
pyzmq 26.3.0
qtconsole 5.6.1
qtpy 2.4.3
referencing 0.36.2
regex 2024.11.6
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.9.4
roman-numerals-py 3.1.0
rouge-score 0.1.2
rpds-py 0.23.1
s3transfer 0.11.4
sacrebleu 2.5.1
safetensors 0.5.3
scikit-learn 1.6.1
scipy 1.15.2
send2trash 1.8.3
sentencepiece 0.2.0
setuptools 77.0.1
six 1.17.0
sniffio 1.3.1
snowballstemmer 2.2.0
sortedcontainers 2.4.0
soupsieve 2.6
sphinx 8.2.3
sphinx-rtd-theme 3.0.2
sphinxcontrib-applehelp 2.0.0
sphinxcontrib-devhelp 2.0.0
sphinxcontrib-htmlhelp 2.1.0
sphinxcontrib-jquery 4.1
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 2.0.0
sphinxcontrib-serializinghtml 2.0.0
sqlitedict 2.1.0
stack-data 0.6.3
starlette 0.46.1
sympy 1.13.1
tabledata 1.3.4
tabulate 0.9.0
tcolorpy 0.1.7
tenacity 9.0.0
tensorboard 2.19.0
tensorboard-data-server 0.7.2
tensorboardx 2.6.2.2
termcolor 2.5.0
terminado 0.18.1
testpath 0.6.0
threadpoolctl 3.6.0
tinycss2 1.4.0
tokenizers 0.21.1
torch 2.5.1
torchmetrics 1.6.3
torchvision 0.20.1
tornado 6.4.2
tqdm 4.67.1
tqdm-multiprocess 0.0.11
traitlets 5.14.3
transformer-engine 2.1.0
transformer-engine-cu12 2.1.0
transformer-engine-torch 2.1.0
transformers 4.47.1
trio 0.29.0
triton 3.1.0
typepy 1.3.4
types-python-dateutil 2.9.0.20241206
typeshed-client 2.7.0
typing-extensions 4.12.2
tzdata 2025.1
uri-template 1.3.0
urllib3 2.3.0
uvicorn 0.34.0
uvloop 0.21.0
virtualenv 20.29.3
watchfiles 1.0.4
wcwidth 0.2.13
webcolors 24.11.1
webencodings 0.5.1
websocket-client 1.8.0
websockets 15.0.1
werkzeug 3.1.3
widgetsnbextension 4.0.13
word2number 1.1
xxhash 3.5.0
yarl 1.18.3
zipp 3.21.0
zstandard 0.23.0
Hi, I run into this issue with both TE 2.1.0 and 2.0.0, with torch 2.6.0+cu126; CUDA is 12.8 and cuDNN is 9.8.0.
Same issue: TE 2.1.0, torch 2.5.1+cu124, CUDA 12.4, cuDNN 9.8.0. TE 1.13.0 works fine in my environment.
I am also running into this issue: torch 2.6.0+cu126, CUDA 12.6.2, cuDNN 9.8.0, TE 2.1.0.
Same issue: torch 2.6.0 / TE 2.1.0, CUDA 12.4.
Only TE 1.13.0 works fine.
FYI, I was able to run 2.0.0 by building it from source, but not 2.1.x from PyPI. (2.0.x itself is not available on PyPI.)
I think I see the problem. Just to confirm: all of these issues are when running TE installed from PyPI, correct?
@ksivaman @timmoon10 FYI. I believe the problem is that when we create the wheel, the image building TE has a newer cuBLAS (e.g. 12.8 or 12.9), so compile-time checks like this one http://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/gemm/cublaslt_gemm.cu#L411 pass in that environment. But when people then run it with an older cuBLAS, it complains. We need both a compile-time and a runtime check for the cuBLAS version. @cyanguwa FYI, as a similar problem might happen on the cuDNN side (although there I hope that the cuDNN frontend shields us from that).
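For anyone who wants to check which cuBLAS their process actually loads at runtime, here is a small diagnostic sketch (cublasGetProperty is the public cuBLAS version-query API; the soname below assumes a CUDA 12 library on the loader path):

import ctypes

# Load whichever libcublas the dynamic loader resolves for this process
cublas = ctypes.CDLL("libcublas.so.12")

# libraryPropertyType enum: MAJOR_VERSION=0, MINOR_VERSION=1, PATCH_LEVEL=2
value = ctypes.c_int()
parts = []
for prop in (0, 1, 2):
    cublas.cublasGetProperty(prop, ctypes.byref(value))
    parts.append(str(value.value))
print("cuBLAS runtime version:", ".".join(parts))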
Hello. I would like to know if there is a known solution or workaround. I updated my environment to CUDA 12.8 (cuBLAS 12.9.0.13) and Transformer Engine to 2.3.0, but I still see the error:
[rank0]: File "/home/pedro/Documents/llm/.venv/lib/python3.11/site-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 111, in general_gemm
[rank0]: out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:409 in function cublas_gemm: cuBLAS Error: an unsupported value or parameter was passed to the function
Downgrading transformer_engine seems to be a workaround.
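If you go that route, pinning the last release reported good above looks like this (adjust the version to your setup; --no-build-isolation per the TE install docs):

pip install --no-build-isolation "transformer_engine[pytorch]==1.13.0"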
CUDA 12.2, PyTorch 2.7.1+cu126, cuDNN 9.5.1.
Error with: transformer-engine (torch) 2.4.0.
No error with: transformer-engine (torch) 1.13.0.
> I think I see the problem. Just to confirm: all of these issues are when running TE installed from PyPI, correct?
@ptrendx Yes, I can confirm this! After reading this post, I uninstalled the PyPI version and built TE from source using the following command, as suggested in the official documentation, and things work properly now :-)
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
My previous TE version installed via PyPI was 2.2.0, and I am now using 2.4.0+b43596b.
Thanks again for the suggestion!