bitsandbytes 8-bit quantized Llama 3.1 sometimes gets stuck while producing output
System Info
PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.28

Python version: 3.8.5 (default, Sep 4 2020, 07:30:14) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-176-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090 Ti
Nvidia driver version: 535.161.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 2
NUMA node(s): 4
Vendor ID: AuthenticAMD
CPU family: 21
Model: 1
Model name: AMD Opteron(TM) Processor 6212
Stepping: 2
CPU MHz: 1420.365
CPU max MHz: 2600.0000
CPU min MHz: 1400.0000
BogoMIPS: 5199.80
Virtualization: AMD-V
L1d cache: 16K
L1i cache: 64K
L2 cache: 2048K
L3 cache: 6144K
NUMA node0 CPU(s): 0-3
NUMA node1 CPU(s): 4-7
NUMA node2 CPU(s): 8-11
NUMA node3 CPU(s): 12-15

Versions of relevant libraries:
[pip3] flake8==3.8.4
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] torch==2.2.2
[pip3] triton==2.2.0
[conda] blas 1.0 mkl
[conda] libblas 3.9.0 12_linux64_mkl conda-forge
[conda] libcblas 3.9.0 12_linux64_mkl conda-forge
[conda] liblapack 3.9.0 12_linux64_mkl conda-forge
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.23.5 py38h14f4228_0
[conda] numpy-base 1.23.5 py38h31eccc5_0
[conda] torch 2.2.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
Reproduction
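import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Placeholders (not in the original report): fill in your own values.
access_token = "<your Hugging Face access token>"
prompts = "<prompt text>"  # a single prompt string, since padding=False

# Load Llama 3.1 8B Instruct with bitsandbytes 8-bit weights.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    token=access_token,
    trust_remote_code=True,
    device_map="auto",
    load_in_8bit=True,
    offload_folder="offload/",
)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    token=access_token,
)
streamer = TextStreamer(tokenizer)  # prints tokens as they are generated

inputs = tokenizer(prompts, return_tensors="pt", padding=False)
inputs = {key: value.to(model.device) for key, value in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=128000,  # Adjusted to 128000 to accommodate long sequences
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        output_scores=True,
        return_dict_in_generate=True,
        streamer=streamer,
    )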
I am able to run it and the results look good, but sometimes it gets stuck while producing output: it streams part of the response and then just keeps waiting, without producing any further tokens.
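When a call hangs like this, Python's standard-library faulthandler module can dump the Python stack of every thread after a timeout while leaving the process running, which records where the call is blocked without interrupting it. A minimal sketch, assuming a timeout longer than a normal generation; run_with_watchdog is a hypothetical helper name:

```python
import faulthandler
import sys


def run_with_watchdog(fn, timeout_s, dump_file=None):
    """Call fn() and return its result.

    If fn() is still running after timeout_s seconds, faulthandler writes the
    Python stack of every thread to dump_file (stderr by default). The call
    itself is not interrupted. dump_file must be a real file with a fileno().
    """
    target = dump_file if dump_file is not None else sys.stderr
    # Arm a one-shot watchdog; exit=False keeps the process alive after the dump.
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=target, exit=False)
    try:
        return fn()
    finally:
        # Disarm the watchdog whether fn() returned normally or raised.
        faulthandler.cancel_dump_traceback_later()
```

In the reproduction above, the call would be wrapped as something like run_with_watchdog(lambda: model.generate(**inputs, ...), timeout_s=600), where 600 seconds is an arbitrary example value. The dump shows Python-level frames only, so a hang inside a CUDA operation appears as the Python line that is waiting on it.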
This is the traceback showing the point where it gets stuck:
    outputs = model.generate(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1141, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 944, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 562, in forward
    value_states = self.v_proj(hidden_states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/bitsandbytes/nn/modules.py", line 812, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/opt/conda/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py", line 556, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py", line 415, in forward
    output += torch.matmul(subA, state.subB)
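The last frame is inside bitsandbytes' MatMul8bitLt, which implements LLM.int8() matrix multiplication: activation columns that contain outliers (magnitude at or above a threshold, 6.0 by default through llm_int8_threshold in transformers' BitsAndBytesConfig) are multiplied in 16-bit floating point, and the remaining columns in int8. The NumPy sketch below illustrates that decomposition under those assumptions; it is not the bitsandbytes kernel, it uses float32 instead of fp16, and absmax_quantize_int8 and mixed_int8_matmul are hypothetical names.

```python
import numpy as np


def absmax_quantize_int8(x, axis):
    """Symmetric absmax quantization of x to int8 along `axis`.

    Returns (q, scale) such that x is approximately q * scale.
    """
    scale = np.max(np.abs(x), axis=axis, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # all-zero slices: avoid divide-by-zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def mixed_int8_matmul(A, W, threshold=6.0):
    """Approximate A @ W with an LLM.int8()-style mixed-precision decomposition.

    A: (tokens, in_features) activations, W: (in_features, out_features) weights.
    Feature columns of A with any |value| >= threshold are treated as outliers
    and multiplied in floating point; all other columns go through int8.
    Returns (output, number_of_outlier_columns).
    """
    outliers = np.any(np.abs(A) >= threshold, axis=0)
    regular = ~outliers
    out = np.zeros((A.shape[0], W.shape[1]), dtype=np.float32)

    if regular.any():
        # int8 path: one scale per row of A and one per output column of W.
        qA, sA = absmax_quantize_int8(A[:, regular], axis=1)
        qW, sW = absmax_quantize_int8(W[regular, :], axis=0)
        acc = qA.astype(np.int32) @ qW.astype(np.int32)  # exact integer accumulation
        out += acc * sA * sW  # dequantize back to floating point

    if outliers.any():
        # Floating-point path for the outlier columns, analogous to
        # `output += torch.matmul(subA, state.subB)` in the traceback.
        subA = A[:, outliers]
        subB = W[outliers, :]
        out += subA @ subB

    return out, int(outliers.sum())
```

Keeping the few outlier columns out of the int8 path stops a single large activation from inflating the quantization scale for its whole row.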
Expected behavior
It should continue producing output and not get stuck.
The issue only happens intermittently, and I am not sure why. When it was stuck, I checked GPU memory usage and it was not full. Thanks!