gpt-fast

RuntimeError: cutlassF: no kernel found to launch!

Open goodboyyes2009 opened this issue 1 year ago • 15 comments

root@md:/home/projects/gpt-fast# CUDA_VISIBLE_DEVICES=0 python3 generate.py --compile --checkpoint_path /models/huggingface_models/meta-Llama-2-7b-hf/model_int8.pth --max_new_tokens 100
Loading model ...
Using int8 weight-only quantization!
/opt/conda/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Time to load model: 2.33 seconds
Traceback (most recent call last):
  File "/home/projects/gpt-fast/generate.py", line 407, in <module>
    main(
  File "/home/projects/gpt-fast/generate.py", line 346, in main
    y, metrics = generate(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/projects/gpt-fast/generate.py", line 167, in generate
    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
  File "/home/projects/gpt-fast/generate.py", line 52, in prefill
    logits = model(x, input_pos)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/projects/gpt-fast/model.py", line 118, in forward
    x = layer(x, input_pos, freqs_cis, mask)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/projects/gpt-fast/model.py", line 137, in forward
    h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/projects/gpt-fast/model.py", line 186, in forward
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
RuntimeError: cutlassF: no kernel found to launch!

GPU: NVIDIA V100

conda list:

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
asttokens 2.0.5 pyhd3eb1b0_0
astunparse 1.6.3 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
backcall 0.2.0 pyhd3eb1b0_0
beautifulsoup4 4.12.2 py310h06a4308_0
blas 1.0 mkl
boltons 23.0.0 py310h06a4308_0
brotlipy 0.7.0 py310h7f8727e_1002
bzip2 1.0.8 h7b6447c_0
c-ares 1.19.0 h5eee18b_0
ca-certificates 2023.08.22 h06a4308_0
certifi 2023.7.22 py310h06a4308_0
cffi 1.15.1 py310h5eee18b_3
chardet 4.0.0 py310h06a4308_1003
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.0.4 py310h06a4308_0
cmake 3.26.4 h96355d8_0
conda 23.9.0 py310h06a4308_0
conda-build 3.27.0 py310h06a4308_0
conda-content-trust 0.2.0 py310h06a4308_0
conda-index 0.3.0 py310h06a4308_0
conda-libmamba-solver 23.7.0 py310h06a4308_0
conda-package-handling 2.2.0 py310h06a4308_0
conda-package-streaming 0.9.0 py310h06a4308_0
cryptography 41.0.3 py310hdda0065_0
cuda-cudart 11.8.89 0 nvidia
cuda-cupti 11.8.87 0 nvidia
cuda-libraries 11.8.0 0 nvidia
cuda-nvrtc 11.8.89 0 nvidia
cuda-nvtx 11.8.86 0 nvidia
cuda-runtime 11.8.0 0 nvidia
decorator 5.1.1 pyhd3eb1b0_0
dnspython 2.4.2 pypi_0 pypi
exceptiongroup 1.0.4 py310h06a4308_0
executing 0.8.3 pyhd3eb1b0_0
expat 2.5.0 h6a678d5_0
expecttest 0.1.6 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.9.0 py310h06a4308_0
fmt 9.1.0 hdb19cb5_0
freetype 2.12.1 h4a9f257_0
fsspec 2023.9.2 pypi_0 pypi
giflib 5.2.1 h5eee18b_3
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py310heeb90bb_0
gnutls 3.6.15 he1e5248_0
hypothesis 6.87.1 pypi_0 pypi
icu 58.2 he6710b0_3
idna 3.4 py310h06a4308_0
intel-openmp 2023.1.0 hdb19cb5_46305
ipython 8.15.0 py310h06a4308_0
jedi 0.18.1 py310h06a4308_1
jinja2 3.1.2 py310h06a4308_0
jpeg 9e h5eee18b_1
jsonpatch 1.32 pyhd3eb1b0_0
jsonpointer 2.1 pyhd3eb1b0_0
krb5 1.20.1 h143b758_1
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libarchive 3.6.2 h6ac8c49_2
libcublas 11.11.3.6 0 nvidia
libcufft 10.9.0.58 0 nvidia
libcufile 1.7.2.10 0 nvidia
libcurand 10.3.3.141 0 nvidia
libcurl 8.1.1 h251f7ec_1
libcusolver 11.4.1.48 0 nvidia
libcusparse 11.7.5.86 0 nvidia
libdeflate 1.17 h5eee18b_1
libedit 3.1.20221030 h5eee18b_0
libev 4.33 h7f8727e_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libiconv 1.16 h7f8727e_2
libidn2 2.3.4 h5eee18b_0
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
liblief 0.12.3 h6a678d5_0
libmamba 1.4.1 h2dafd23_1
libmambapy 1.4.1 py310h2dafd23_1
libnghttp2 1.52.0 h2d74bed_1
libnpp 11.8.0.86 0 nvidia
libnvjpeg 11.9.0.86 0 nvidia
libpng 1.6.39 h5eee18b_0
libsolv 0.7.22 he621ea3_0
libssh2 1.10.0 hdbd6064_2
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libuuid 1.41.5 h5eee18b_0
libuv 1.44.2 h5eee18b_0
libwebp 1.3.2 h11a3e52_0
libwebp-base 1.3.2 h5eee18b_0
libxml2 2.10.3 hcbfbd50_0
llvm-openmp 14.0.6 h9e868ea_0
lz4-c 1.9.4 h6a678d5_0
markupsafe 2.1.1 py310h7f8727e_0
matplotlib-inline 0.1.6 py310h06a4308_0
mkl 2023.1.0 h213fc3f_46343
mkl-service 2.4.0 py310h5eee18b_1
mkl_fft 1.3.8 py310h5eee18b_0
mkl_random 1.2.4 py310hdb19cb5_0
more-itertools 8.12.0 pyhd3eb1b0_0
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py310h06a4308_0
ncurses 6.4 h6a678d5_0
nettle 3.7.3 hbbd107a_1
networkx 3.1 py310h06a4308_0
numpy 1.26.0 py310h5f9d8c6_0
numpy-base 1.26.0 py310hb5e798b_0
openh264 2.1.1 h4ff587b_0
openssl 3.0.11 h7f8727e_2
packaging 23.1 py310h06a4308_0
parso 0.8.3 pyhd3eb1b0_0
patch 2.7.6 h7b6447c_1001
patchelf 0.17.2 h6a678d5_0
pcre2 10.37 he7ceb23_1
pexpect 4.8.0 pyhd3eb1b0_3
pickleshare 0.7.5 pyhd3eb1b0_1003
pillow 9.4.0 py310h6a678d5_1
pip 23.2.1 py310h06a4308_0
pkginfo 1.9.6 py310h06a4308_0
pluggy 1.0.0 py310h06a4308_1
prompt-toolkit 3.0.36 py310h06a4308_0
psutil 5.9.0 py310h5eee18b_0
ptyprocess 0.7.0 pyhd3eb1b0_2
pure_eval 0.2.2 pyhd3eb1b0_0
py-lief 0.12.3 py310h6a678d5_0
pybind11-abi 4 hd3eb1b0_1
pycosat 0.6.4 py310h5eee18b_0
pycparser 2.21 pyhd3eb1b0_0
pygments 2.15.1 py310h06a4308_1
pyopenssl 23.2.0 py310h06a4308_0
pysocks 1.7.1 py310h06a4308_0
python 3.10.13 h955ad1f_0
python-etcd 0.4.5 pypi_0 pypi
python-libarchive-c 2.9 pyhd3eb1b0_1
pytorch 2.1.0 py3.10_cuda11.8_cudnn8.7.0_0 pytorch
pytorch-cuda 11.8 h7e8668a_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2023.3.post1 py310h06a4308_0
pyyaml 6.0 py310h5eee18b_1
readline 8.2 h5eee18b_0
reproc 14.2.4 h295c915_1
reproc-cpp 14.2.4 h295c915_1
requests 2.31.0 py310h06a4308_0
rhash 1.4.3 hdbd6064_0
ruamel.yaml 0.17.21 py310h5eee18b_0
ruamel.yaml.clib 0.2.6 py310h5eee18b_1
sentencepiece 0.1.99 pypi_0 pypi
setuptools 68.0.0 py310h06a4308_0
six 1.16.0 pyhd3eb1b0_1
sortedcontainers 2.4.0 pypi_0 pypi
soupsieve 2.5 py310h06a4308_0
sqlite 3.41.2 h5eee18b_0
stack_data 0.2.0 pyhd3eb1b0_0
sympy 1.12 pypi_0 pypi
tbb 2021.8.0 hdb19cb5_0
tk 8.6.12 h1ccaba5_0
tomli 2.0.1 py310h06a4308_0
toolz 0.12.0 py310h06a4308_0
torchaudio 2.1.0 py310_cu118 pytorch
torchelastic 0.2.2 pypi_0 pypi
torchtriton 2.1.0 py310 pytorch
torchvision 0.16.0 py310_cu118 pytorch
tqdm 4.65.0 py310h2f386ee_0
traitlets 5.7.1 py310h06a4308_0
truststore 0.8.0 py310h06a4308_0
types-dataclasses 0.6.6 pypi_0 pypi
typing-extensions 4.8.0 pypi_0 pypi
typing_extensions 4.7.1 py310h06a4308_0
tzdata 2023c h04d1e81_0
urllib3 1.26.16 py310h06a4308_0
wcwidth 0.2.5 pyhd3eb1b0_0
wheel 0.41.2 py310h06a4308_0
xz 5.4.2 h5eee18b_0
yaml 0.2.5 h7b6447c_0
yaml-cpp 0.7.0 h295c915_1
zlib 1.2.13 h5eee18b_0
zstandard 0.19.0 py310h5eee18b_0
zstd 1.5.5 hc292b87_0

goodboyyes2009 • Dec 14 '23 03:12

I have the same error

merveermann • Dec 14 '23 14:12

same error

Armod-I • Dec 15 '23 05:12

My conda environment is as below:

GPU: RTX 5000, CUDA: 12.3

Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
blas 1.0 mkl
brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
bzip2 1.0.8 h7b6447c_0
ca-certificates 2023.08.22 h06a4308_0
certifi 2023.11.17 py311h06a4308_0
cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
charset-normalizer 2.0.4 pyhd3eb1b0_0
cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.3.101 0 nvidia
cuda-runtime 12.1.0 0 nvidia
ffmpeg 4.2.2 h20bf706_0
filelock 3.9.0 py311_0 pytorch-nightly
freetype 2.12.1 h4a9f257_0
fsspec 2023.12.2 pypi_0 pypi
giflib 5.2.1 h5eee18b_3
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py311hc9b5ff0_0
gnutls 3.6.15 he1e5248_0
huggingface-hub 0.19.4 pypi_0 pypi
idna 3.4 py311h06a4308_0
intel-openmp 2021.4.0 h06a4308_3561
jinja2 3.1.2 py311h06a4308_0
jpeg 9e h5eee18b_1
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.8.1.2 0 nvidia
libcurand 10.3.4.101 0 nvidia
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libdeflate 1.17 h5eee18b_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libidn2 2.3.4 h5eee18b_0
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
libnpp 12.0.2.50 0 nvidia
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libopus 1.3.1 h7b6447c_0
libpng 1.6.39 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libuuid 1.41.5 h5eee18b_0
libvpx 1.7.0 h439df22_0
libwebp 1.2.4 h11a3e52_1
libwebp-base 1.2.4 h5eee18b_1
llvm-openmp 14.0.6 h9e868ea_0
lz4-c 1.9.4 h6a678d5_0
markupsafe 2.1.1 py311h5eee18b_0
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py311h9bf148f_0 pytorch-nightly
mkl_fft 1.3.1 py311hc796f24_0 pytorch-nightly
mkl_random 1.2.2 py311hbba84a0_0 pytorch-nightly
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.2.1 py311_0 pytorch-nightly
ncurses 6.4 h6a678d5_0
nettle 3.7.3 hbbd107a_1
networkx 3.1 py311h06a4308_0
numpy 1.24.3 py311hc206e33_0
numpy-base 1.24.3 py311hfd5febd_0
openh264 2.1.1 h4ff587b_0
openssl 3.0.12 h7f8727e_0
packaging 23.2 pypi_0 pypi
pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
pip 23.3.1 py311h06a4308_0
pycparser 2.21 pyhd3eb1b0_0
pyopenssl 23.2.0 py311h06a4308_0
pysocks 1.7.1 py311_0 pytorch-nightly
python 3.11.5 h955ad1f_0
pytorch 2.3.0.dev20231214 py3.11_cuda12.1_cudnn8.9.2_0 pytorch-nightly
pytorch-cuda 12.1 ha16c6d3_5 pytorch-nightly
pytorch-mutex 1.0 cuda pytorch-nightly
pyyaml 6.0.1 py311h5eee18b_0
readline 8.2 h5eee18b_0
requests 2.28.1 py311_0 pytorch-nightly
sentencepiece 0.1.99 pypi_0 pypi
setuptools 68.2.2 py311h06a4308_0
six 1.16.0 pyhd3eb1b0_1
sqlite 3.41.2 h5eee18b_0
sympy 1.12 py311h06a4308_0
tk 8.6.12 h1ccaba5_0
torchaudio 2.2.0.dev20231214 py311_cu121 pytorch-nightly
torchtriton 2.1.0+bcad9dabe1 py311 pytorch-nightly
torchvision 0.18.0.dev20231214 py311_cu121 pytorch-nightly
tqdm 4.66.1 pypi_0 pypi
typing_extensions 4.7.1 py311h06a4308_0
tzdata 2023c h04d1e81_0
urllib3 1.26.14 py311_0 pytorch-nightly
wheel 0.41.2 py311h06a4308_0
x264 1!157.20191217 h7b6447c_0
xz 5.4.5 h5eee18b_0
yaml 0.2.5 h7b6447c_0
zlib 1.2.13 h5eee18b_0
zstd 1.5.5 hc292b87_0

merveermann • Dec 15 '23 05:12

cc: @drisspg

Chillee • Dec 15 '23 17:12

Can you try using the patch release or nightly?

drisspg • Dec 15 '23 17:12

@drisspg I believe @merveermann says they're using the nightly: https://github.com/pytorch-labs/gpt-fast/issues/46#issuecomment-1857615288

Chillee • Dec 17 '23 01:12

So this error is being thrown on nightly for these devices: V100 and RTX 5000. Are there any others?

Also, is it possible to share example inputs to SDPA that cause this error to be thrown? Does it only happen when the model is compiled?

My hunch is that compile is doing some memory planning optimizations that cause the alignment check here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L1023-L1027 to fail for all possible kernels.

drisspg • Dec 17 '23 18:12

It seems your GPU does not support bf16; changing all torch.bfloat16 to torch.float32 may work.

VendaCino • Dec 17 '23 23:12

@drisspg I tested on a V100. Both eager and compiled runs hit the same error.

I think the issue is that mem_eff_attention doesn't support bf16 on sm < 80: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h#L286

I tested with float16 and it works. Shall we default gpt-fast to float16 for V100 and under?
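A minimal sketch of the default proposed above, assuming the rule is simply "use bfloat16 only on compute capability 8.0 (sm_80) or newer, otherwise float16", based on the sm < 80 limitation of mem_eff_attention linked above. The function name `pick_dtype_name` is hypothetical and not part of gpt-fast; it returns a dtype name instead of a `torch.dtype` so the decision logic can be checked without a GPU.

```python
def pick_dtype_name(capability: tuple[int, int]) -> str:
    """Return the dtype name to use for a GPU with the given (major, minor)
    compute capability.

    Assumption (from this thread): the memory-efficient SDPA kernel has no
    bf16 path below sm_80, so older GPUs such as V100 (sm_70) use float16.
    """
    major, minor = capability
    if (major, minor) >= (8, 0):
        return "bfloat16"
    return "float16"


if __name__ == "__main__":
    for cap in [(7, 0), (7, 5), (8, 0), (9, 0)]:
        print(cap, pick_dtype_name(cap))
```

With PyTorch installed, the name can be turned into a dtype with `getattr(torch, pick_dtype_name(torch.cuda.get_device_capability()))`. A V100 (sm_70) would get float16 under this rule.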

yifuwang • Dec 18 '23 04:12

Ohh @yifuwang, thank you, that is a great catch! I will put up a PR right now to fix this in PyTorch.

drisspg • Dec 18 '23 16:12

Thank you all. After changing all torch.bfloat16 to torch.float32, the unquantized model runs well, but the output of the int8 run looks wrong:

root@md:/home/projects/gpt-fast# CUDA_VISIBLE_DEVICES=0 python3 generate.py --compile --checkpoint_path /models/huggingface_models/meta-Llama-2-7b-hf/model_int8.pth --max_new_tokens 100
Loading model ...
Using int8 weight-only quantization!
/opt/conda/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Time to load model: 2.52 seconds
[2023-12-19 00:54:26,247] [0/0] torch._dynamo.output_graph: [WARNING] nn.Module state_dict and backward hooks are not yet supported by torch.compile, but were detected in your model and will be silently ignored. See https://pytorch.org/docs/master/compile/nn-module.html for more information and limitations.
Compilation time: 101.21 seconds
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 1: 4.87 sec total, 20.53 tokens/sec
Bandwidth achieved: 141.08 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 2: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.25 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 3: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.24 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 4: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.22 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 5: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.24 GB/s
==========
Average tokens/sec: 20.55
Memory used: 8.01 GB

goodboyyes2009 • Dec 19 '23 01:12

@goodboyyes2009 Did you re-run quantize.py after changing torch.bfloat16 to torch.float32?

VendaCino • Dec 19 '23 01:12

@VendaCino Oh, sorry: I did re-run quantize.py, but I changed all torch.bfloat16 to torch.float16, not torch.float32.

goodboyyes2009 • Dec 19 '23 01:12

'⁇ ⁇ ⁇' appears because the tensor values are NaN. While debugging, I found that the kv_cache in that attention layer is NaN. The issue does not happen when every dtype is torch.float32 instead of torch.float16, and it also does not happen with TinyLlama, only with vicuna-7b.


Hope this information helps trace the problem.

Update: deeper debugging shows that it is because x.max() = inf. I think some layer's output is too large for float16 to represent.

Time to load model: 1.97 seconds
tensor(7.0664, device='cuda:0', dtype=torch.float16)
tensor(18.3906, device='cuda:0', dtype=torch.float16)
tensor(inf, device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16)
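The inf-then-NaN sequence in this log can be illustrated without a GPU. The sketch below emulates float16 rounding with Python's `struct` module (format code `e` is IEEE 754 half precision) and returns `inf` where `struct` raises `OverflowError`, since a float16 tensor holds inf in that case. The helper name `to_fp16` is hypothetical, and `inf - inf` is just one common way a NaN can arise (for example, subtracting the row maximum in a numerically stable softmax); the thread does not identify which op actually produced the NaN.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision (float16) value.

    struct raises OverflowError when |x| is too large for float16; a float16
    tensor would hold +/-inf instead, so that is what we return.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


if __name__ == "__main__":
    print(to_fp16(65504.0))  # largest finite float16 value: 65504.0
    big = to_fp16(70000.0)   # does not fit in float16
    print(big)               # inf
    # Arithmetic on inf can produce NaN, e.g. inf - inf when a softmax
    # subtracts the row maximum; NaN then propagates through later layers.
    print(big - big)         # nan
```

The usable float16 range tops out at 65504, so a single activation that grows past it becomes inf, and the NaN follows from the next arithmetic step rather than from the overflow itself.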

It depends on the model weights, which is why it works well when I test with TinyLlama.

When I use model.pth instead:

Time to load model: 10.10 seconds
tensor(7.0625, device='cuda:0', dtype=torch.float16)
tensor(18.3438, device='cuda:0', dtype=torch.float16)
tensor(1532., device='cuda:0', dtype=torch.float16)

So I guess something is wrong in WeightOnlyInt8Linear:

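import torch
import torch.nn.functional as F

class WeightOnlyInt8Linear(torch.nn.Module):
    ...

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Precision is lost here: the int8 weight is cast to the input dtype
        # (float16) and the scales are applied only after the matmul.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales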

Changing it to:

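    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Upcast the input and the int8 weight to float32 for the matmul and
        # scaling, then cast the result back to the input dtype.
        return (F.linear(input.to(dtype=torch.float32), self.weight.to(dtype=torch.float32)) * self.scales).to(dtype=input.dtype)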

Everything looks good:

Time to load model: 1.66 seconds
tensor(7.0664, device='cuda:0', dtype=torch.float16)
tensor(18.3906, device='cuda:0', dtype=torch.float16)
tensor(1535., device='cuda:0', dtype=torch.float16)
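A toy model of why the upcast helps, assuming symmetric per-output-channel int8 quantization (scale = max|w| / 127), a common scheme that may differ in detail from gpt-fast's actual quantization code. The helper names (`to_fp16`, `quantize_row`, `linear_fp16_then_scale`, `linear_fp32_then_scale`) are hypothetical. The original forward multiplies the raw int8 values (up to ±127) by the activations and applies the small scale only afterwards, so the intermediate result is roughly 1/scale times larger than the true output and can exceed the float16 maximum of 65504 even when the final answer is small.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest float16 value; out-of-range magnitudes become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def quantize_row(weights: list[float]) -> tuple[list[int], float]:
    """Symmetric int8 quantization of one output row (one scale per row)."""
    scale = max(abs(w) for w in weights) / 127.0
    q = [max(-128, min(127, round(w / scale))) for w in weights]
    return q, scale


def linear_fp16_then_scale(x: list[float], q: list[int], scale: float) -> float:
    """Mimics F.linear(input, weight.to(float16)) * scales on float16 tensors.

    Simplification: the dot product is summed exactly and only the result is
    rounded to float16, then multiplied by the float16 scale.
    """
    raw = to_fp16(sum(to_fp16(xi) * qi for xi, qi in zip(x, q)))
    return to_fp16(raw * to_fp16(scale))


def linear_fp32_then_scale(x: list[float], q: list[int], scale: float) -> float:
    """Mimics the proposed fix: matmul and scaling in wider precision
    (Python floats are float64 here), with one cast to float16 at the end."""
    raw = sum(xi * qi for xi, qi in zip(x, q))
    return to_fp16(raw * scale)


if __name__ == "__main__":
    # Hypothetical row of small weights and moderately large activations.
    weights = [0.02] * 64
    x = [10.0] * 64
    q, scale = quantize_row(weights)
    print(linear_fp16_then_scale(x, q, scale))  # inf: 64 * 10 * 127 > 65504
    print(linear_fp32_then_scale(x, q, scale))  # close to the true 12.8
```

The sketch models range only, not the kernel's accumulation order, so it shows why the unscaled intermediate overflows rather than reproducing the model's exact numbers.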

VendaCino • Dec 19 '23 03:12

OK, thank you very much, @VendaCino!

goodboyyes2009 • Dec 19 '23 04:12