TransformerEngine

cuBLAS Error

Open • wangli68 opened this issue 9 months ago • 8 comments

With PyTorch 2.5 and Transformer Engine 2.1, I run the following original network structure:

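import torch.nn as nn

# RMSNorm, flash_attention and rope_apply are DiffSynth-Studio helpers (not shown here).
class SelfAttention(nn.Module):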
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # In the failing run these nn.Linear projections were replaced with te.Linear.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

    def forward(self, x, freqs):
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        x = flash_attention(
            q=rope_apply(q, freqs, self.num_heads),
            k=rope_apply(k, freqs, self.num_heads),
            v=v,
            num_heads=self.num_heads
        )
        return self.o(x)

Note that at this point I have already changed the layers in the Lightning network to the corresponding te.* modules (the traceback goes through te.Linear), but it reports the following error:

[rank1]:   File "/output/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 128, in forward
[rank1]:     q = self.norm_q(self.q(x))
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/module/linear.py", line 1085, in forward
[rank1]:     out = linear_fn(*args)
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]:     return super().apply(*args, **kwargs) # type: ignore[misc]
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/module/linear.py", line 231, in forward
[rank1]:     out, *, rs_out = general_gemm(
[rank1]:   File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 141, in general_gemm
[rank1]:     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
[rank1]: RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:282 in function cublas_gemm: cuBLAS Error: an unsupported value or parameter was passed to the function

Epoch 0: 0%| | 0/63 [00:04<?, ?it/s]

With PyTorch 2.4 and Transformer Engine 1.3, however, no error is reported; instead, the run stalls with memory usage unchanged, as if it had hung.

wangli68, Mar 18 '25

I am also running into this issue with v2.1; I did not have this problem with v2.0. I am using CUDA 12.4 and cuDNN 9.0.0.

Here is a simple snippet to reproduce the issue:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16

model = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    attn_input_format="bshd",
    params_dtype=dtype,
).cuda()

x = torch.rand(batch_size, sequence_length, hidden_size, dtype=dtype).cuda()

with te.fp8_autocast(enabled=True):
    y = model(x).sum()
y.backward()

print("Done")

Environment

(llm) pvalois@pegasus02:~/llm$ uv pip list
Package                       Version
----------------------------- --------------
absl-py                       2.1.0
accelerate                    1.5.2
aiohappyeyeballs              2.6.1
aiohttp                       3.11.14
aiosignal                     1.3.2
alabaster                     1.0.0
antlr4-python3-runtime        4.9.3
anyio                         4.9.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
asttokens                     3.0.0
async-lru                     2.0.5
attrs                         25.3.0
babel                         2.17.0
beautifulsoup4                4.13.3
bitsandbytes                  0.44.1
black                         25.1.0
bleach                        6.2.0
boto3                         1.37.16
botocore                      1.37.16
certifi                       2025.1.31
cffi                          1.17.1
cfgv                          3.4.0
chardet                       5.2.0
charset-normalizer            3.4.1
click                         8.1.8
colorama                      0.4.6
comm                          0.2.2
contourpy                     1.3.1
curio                         1.6
cycler                        0.12.1
dataproperty                  1.1.0
datasets                      3.4.1
debugpy                       1.8.13
decorator                     5.2.1
deepspeed                     0.9.3
defusedxml                    0.7.1
dill                          0.3.8
distlib                       0.3.9
docrepr                       0.2.0
docstring-parser              0.16
docutils                      0.21.2
einops                        0.8.1
evaluate                      0.4.3
exceptiongroup                1.2.2
executing                     2.2.0
fastapi                       0.115.11
fastjsonschema                2.21.1
filelock                      3.18.0
fire                          0.7.0
flash-attn                    2.7.3
fonttools                     4.56.0
fqdn                          1.5.1
frozenlist                    1.5.0
fsspec                        2024.12.0
grpcio                        1.71.0
h11                           0.14.0
hf-transfer                   0.1.9
hjson                         3.1.0
httpcore                      1.0.7
httptools                     0.6.4
httpx                         0.28.1
huggingface-hub               0.29.3
hydra-core                    1.3.2
identify                      2.6.9
idna                          3.10
imagesize                     1.4.1
importlib-metadata            8.6.1
importlib-resources           6.5.2
iniconfig                     2.1.0
intersphinx-registry          0.2501.23
ipykernel                     6.29.5
ipyparallel                   9.0.1
ipython                       8.34.0
ipywidgets                    8.1.5
isoduration                   20.11.0
jedi                          0.19.2
jinja2                        3.1.6
jmespath                      1.0.1
joblib                        1.4.2
json5                         0.10.0
jsonargparse                  4.32.1
jsonlines                     4.0.0
jsonpointer                   3.0.0
jsonschema                    4.23.0
jsonschema-specifications     2024.10.1
jupyter-client                8.6.3
jupyter-core                  5.7.2
jupyter-events                0.12.0
jupyter-lsp                   2.2.5
jupyter-server                2.15.0
jupyter-server-terminals      0.5.3
jupyterlab                    4.3.6
jupyterlab-pygments           0.3.0
jupyterlab-server             2.27.3
jupyterlab-widgets            3.0.13
kiwisolver                    1.4.8
lightning                     2.5.0.post0
lightning-thunder             0.2.1
lightning-utilities           0.14.1
litdata                       0.2.17
litgpt                        0.5.7
litserve                      0.2.4
lm-eval                       0.4.8
loguru                        0.7.3
looseversion                  1.3.0
lxml                          5.3.1
markdown                      3.7
markdown-it-py                3.0.0
markupsafe                    3.0.2
matplotlib                    3.10.1
matplotlib-inline             0.1.7
mbstrdecoder                  1.1.4
mdurl                         0.1.2
mistune                       3.1.3
more-itertools                10.6.0
mpi4py                        4.0.3
mpmath                        1.3.0
multidict                     6.2.0
multiprocess                  0.70.16
mypy-extensions               1.0.0
nbclient                      0.10.2
nbconvert                     7.16.6
nbformat                      5.10.4
nest-asyncio                  1.6.0
networkx                      3.4.2
ninja                         1.11.1.3
nltk                          3.9.1
nodeenv                       1.9.1
notebook                      7.3.3
notebook-shim                 0.2.4
numexpr                       2.10.2
numpy                         1.26.4
nvidia-cublas-cu12            12.4.5.8
nvidia-cuda-cupti-cu12        12.4.127
nvidia-cuda-nvrtc-cu12        12.4.127
nvidia-cuda-runtime-cu12      12.4.127
nvidia-cudnn-cu12             9.1.0.70
nvidia-cufft-cu12             11.2.1.3
nvidia-curand-cu12            10.3.5.147
nvidia-cusolver-cu12          11.6.1.9
nvidia-cusparse-cu12          12.3.1.170
nvidia-nccl-cu12              2.21.5
nvidia-nvjitlink-cu12         12.4.127
nvidia-nvtx-cu12              12.4.127
omegaconf                     2.3.0
opt-einsum                    3.4.0
optree                        0.14.1
outcome                       1.3.0.post0
overrides                     7.7.0
packaging                     24.2
pandas                        2.2.3
pandocfilters                 1.5.1
parso                         0.8.4
pathspec                      0.12.1
pathvalidate                  3.2.3
peft                          0.15.0
pexpect                       4.9.0
pickleshare                   0.7.5
pillow                        11.1.0
platformdirs                  4.3.7
pluggy                        1.5.0
portalocker                   3.1.1
pre-commit                    4.2.0
prometheus-client             0.21.1
prompt-toolkit                3.0.50
propcache                     0.3.0
protobuf                      6.30.1
psutil                        7.0.0
ptyprocess                    0.7.0
pure-eval                     0.2.3
py-cpuinfo                    9.0.0
pyarrow                       19.0.1
pybind11                      2.13.6
pycparser                     2.22
pydantic                      1.10.21
pygments                      2.19.1
pyparsing                     3.2.1
pytablewriter                 1.2.1
pytest                        8.3.5
pytest-asyncio                0.21.2
python-dateutil               2.9.0.post0
python-dotenv                 1.0.1
python-json-logger            3.3.0
pytorch-lightning             2.5.1
pytz                          2025.1
pyyaml                        6.0.2
pyzmq                         26.3.0
qtconsole                     5.6.1
qtpy                          2.4.3
referencing                   0.36.2
regex                         2024.11.6
requests                      2.32.3
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          13.9.4
roman-numerals-py             3.1.0
rouge-score                   0.1.2
rpds-py                       0.23.1
s3transfer                    0.11.4
sacrebleu                     2.5.1
safetensors                   0.5.3
scikit-learn                  1.6.1
scipy                         1.15.2
send2trash                    1.8.3
sentencepiece                 0.2.0
setuptools                    77.0.1
six                           1.17.0
sniffio                       1.3.1
snowballstemmer               2.2.0
sortedcontainers              2.4.0
soupsieve                     2.6
sphinx                        8.2.3
sphinx-rtd-theme              3.0.2
sphinxcontrib-applehelp       2.0.0
sphinxcontrib-devhelp         2.0.0
sphinxcontrib-htmlhelp        2.1.0
sphinxcontrib-jquery          4.1
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          2.0.0
sphinxcontrib-serializinghtml 2.0.0
sqlitedict                    2.1.0
stack-data                    0.6.3
starlette                     0.46.1
sympy                         1.13.1
tabledata                     1.3.4
tabulate                      0.9.0
tcolorpy                      0.1.7
tenacity                      9.0.0
tensorboard                   2.19.0
tensorboard-data-server       0.7.2
tensorboardx                  2.6.2.2
termcolor                     2.5.0
terminado                     0.18.1
testpath                      0.6.0
threadpoolctl                 3.6.0
tinycss2                      1.4.0
tokenizers                    0.21.1
torch                         2.5.1
torchmetrics                  1.6.3
torchvision                   0.20.1
tornado                       6.4.2
tqdm                          4.67.1
tqdm-multiprocess             0.0.11
traitlets                     5.14.3
transformer-engine            2.1.0
transformer-engine-cu12       2.1.0
transformer-engine-torch      2.1.0
transformers                  4.47.1
trio                          0.29.0
triton                        3.1.0
typepy                        1.3.4
types-python-dateutil         2.9.0.20241206
typeshed-client               2.7.0
typing-extensions             4.12.2
tzdata                        2025.1
uri-template                  1.3.0
urllib3                       2.3.0
uvicorn                       0.34.0
uvloop                        0.21.0
virtualenv                    20.29.3
watchfiles                    1.0.4
wcwidth                       0.2.13
webcolors                     24.11.1
webencodings                  0.5.1
websocket-client              1.8.0
websockets                    15.0.1
werkzeug                      3.1.3
widgetsnbextension            4.0.13
word2number                   1.1
xxhash                        3.5.0
yarl                          1.18.3
zipp                          3.21.0
zstandard                     0.23.0
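Since most reports in this thread come down to a handful of versions, a short script can print just those instead of the full `uv pip list`. This is a minimal sketch, not a TE tool: the package names are the pip distribution names from the listing above, and the PyTorch fields (`torch.version.cuda`, `torch.backends.cudnn.version()`) describe the CUDA/cuDNN that PyTorch was built with or loaded, which can differ from the system toolkit.

```python
from importlib import metadata

# pip distribution names taken from the environment listing above
PACKAGES = (
    "torch",
    "transformer-engine",
    "transformer-engine-cu12",
    "transformer-engine-torch",
    "nvidia-cublas-cu12",
    "nvidia-cudnn-cu12",
)


def installed_versions(names=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def torch_runtime_info():
    """CUDA/cuDNN versions as reported by PyTorch; empty dict if torch is missing."""
    try:
        import torch
    except ImportError:
        return {}
    return {
        "torch.version.cuda": torch.version.cuda,
        "torch cudnn": torch.backends.cudnn.version(),
        "cuda available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in {**installed_versions(), **torch_runtime_info()}.items():
        print(f"{key:28} {value}")
```

Note that `nvidia-cublas-cu12` is the cuBLAS wheel that pip installed; it is the library version that matters at run time, not the one TE was compiled against.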

Pedrexus, Mar 20 '25

Hi, I run into this issue with both TE 2.1.0 and 2.0.0, using torch 2.6.0+cu126, CUDA 12.8, and cuDNN 9.8.0.

chky1997, Mar 21 '25

Same issue with TE 2.1.0, torch 2.5.1+cu124, CUDA 12.4, cuDNN 9.8.0. TE 1.13.0 works fine in my environment.

haitian-jiang, Mar 23 '25

I am also running into this issue: torch 2.6.0+cu126, CUDA 12.6.2, cuDNN 9.8.0, TE 2.1.0.

pranayj77, Mar 25 '25

Same issue: torch 2.6.0, TE 2.1.0, CUDA 12.4.

is, Mar 26 '25

Only TE 1.13.0 works fine for me.

flyingmanPan, Mar 27 '25

FYI, I was able to run 2.0.0 by building it from source, but 2.1.x installed from PyPI does not work. However, 2.0.x is not available on PyPI.

Pedrexus, Mar 31 '25

I think I see the problem. Just to confirm: all of these issues occur when running TE installed from PyPI, correct?

@ksivaman @timmoon10 FYI. I believe the problem is that when we create the wheel, the image that builds TE has a new cuBLAS (e.g. 12.8 or 12.9), so checks like this one http://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/gemm/cublaslt_gemm.cu#L411 pass in that environment. But when people then run it with an older cuBLAS, the call fails with the "unsupported value or parameter" error. We need both a compile-time and a runtime check for the cuBLAS version. @cyanguwa FYI, a similar problem might happen on the cuDNN side (although there I hope the cuDNN frontend shields us from that).
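The fix described in this comment, pairing the build-time guard with a run-time one, can be modelled in a few lines of Python. This is an illustrative sketch, not TE's actual code: the 12.8 threshold is hypothetical, the version encoding assumes cuBLAS's `major*10000 + minor*100 + patch` convention (as in the `CUBLAS_VERSION` macro), and `runtime_cublaslt_version` assumes a Linux soname `libcublasLt.so.12` and the cuBLASLt entry point `cublasLtGetVersion()`.

```python
import ctypes


def encode_cublas_version(major, minor, patch):
    """Assumed encoding, mirroring CUBLAS_VERSION: major*10000 + minor*100 + patch."""
    return major * 10000 + minor * 100 + patch


# Hypothetical minimum version for some newer cuBLASLt GEMM feature (illustrative only).
FEATURE_MIN_VERSION = encode_cublas_version(12, 8, 0)


def feature_enabled_compile_only(compiled):
    # A build-time-only guard trusts the headers of the machine that built the wheel.
    return compiled >= FEATURE_MIN_VERSION


def feature_enabled(compiled, runtime):
    # Require the feature to be compiled in AND supported by the library loaded now.
    # An unknown runtime version (None) is treated as "not supported".
    return (
        compiled >= FEATURE_MIN_VERSION
        and runtime is not None
        and runtime >= FEATURE_MIN_VERSION
    )


def runtime_cublaslt_version(soname="libcublasLt.so.12"):
    """Query the cuBLASLt that the dynamic loader finds; None if it cannot be loaded."""
    try:
        lib = ctypes.CDLL(soname)
    except OSError:
        return None
    lib.cublasLtGetVersion.restype = ctypes.c_size_t
    lib.cublasLtGetVersion.argtypes = []
    return int(lib.cublasLtGetVersion())


built_with = encode_cublas_version(12, 8, 0)  # wheel built against new headers
running = encode_cublas_version(12, 4, 5)     # e.g. nvidia-cublas-cu12 12.4.5.8
print(feature_enabled_compile_only(built_with))  # True: new code path is taken
print(feature_enabled(built_with, running))      # False: fall back instead of erroring
```

The compile-time check still matters, because code that uses newer API symbols cannot be built against old headers; the run-time check is what protects users whose installed cuBLAS is older than the build image's.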

ptrendx, Apr 08 '25

Hello. Is there a known solution or workaround? I updated my environment to CUDA 12.8 (cuBLAS 12.9.0.13) and Transformer Engine to 2.3.0, but I still see the error:

[rank0]:   File "/home/pedro/Documents/llm/.venv/lib/python3.11/site-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 111, in general_gemm
[rank0]:     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
[rank0]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:409 in function cublas_gemm: cuBLAS Error: an unsupported value or parameter was passed to the function

Pedrexus, May 26 '25

Downgrading transformer_engine seems to be a workaround.

CUDA 12.2, PyTorch 2.7.1+cu126, cuDNN 9.5.1
Error with: transformer engine (torch) 2.4.0
No error with: transformer engine (torch) 1.13.0

yejoon-lee, Jul 01 '25

> I think I see the problem. Just to confirm - all of these issues are when running TE installed from PyPi, correct?

@ptrendx Yes, I can confirm this! After reading this post, I uninstalled the PyPI version and built TE from source with the following command, as suggested in the official documentation, and things now work properly :-)

pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

My previous TE version, installed via PyPI, was 2.2.0; I am now using 2.4.0+b43596b.

Thanks again for the suggestion!
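After switching installs, the version string gives a quick hint about which build is active: the source build above reports `2.4.0+b43596b`, while the PyPI wheels in this thread report plain versions such as `2.1.0`. A minimal sketch, assuming that a PEP 440 local label (`+<git hash>`) marks a source build; this is a heuristic, not a documented TE guarantee, and the helper names are hypothetical.

```python
from importlib import metadata


def local_version_label(version):
    """Return the PEP 440 local label after '+', or None if there is none."""
    _, sep, local = version.partition("+")
    return local if sep else None


def looks_like_source_build(version):
    # Heuristic (assumption): source builds append a git hash, e.g. "2.4.0+b43596b".
    return local_version_label(version) is not None


def installed_te_version():
    """Installed TE version string, or None; tries both spellings of the dist name."""
    for name in ("transformer-engine", "transformer_engine"):
        try:
            return metadata.version(name)
        except metadata.PackageNotFoundError:
            continue
    return None


version = installed_te_version()
if version is not None:
    kind = "source build" if looks_like_source_build(version) else "release wheel"
    print(f"transformer-engine {version} ({kind})")
```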

haok1402, Jul 06 '25