[Feature request] Support for load_in_8bit
Being able to run LLaVA in 8-bit mode would allow better support for inference on consumer GPUs, thanks to the lower memory requirements. However, passing `load_in_8bit=True` to `from_pretrained` in `eval/run_llava.py` doesn't work; I'm testing with the 7B v1.1 model. Do you know what might need to be changed in `llava/model/llava.py` to support 8-bit inference?
```python
# 7B v1.1 model
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    use_cache=True,
    device_map='auto',
    load_in_8bit=True
).cuda()
```
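One thing I noticed while putting this together: with `device_map='auto'`, accelerate already places the weights on the GPU, so the trailing `.cuda()` should be redundant, and recent transformers releases refuse to move 8-bit models anyway. Here is a minimal loading sketch without it, assuming a transformers version that ships `BitsAndBytesConfig` (`model_name` and the model class are as in `eval/run_llava.py`):

```python
import torch
from transformers import BitsAndBytesConfig

# Sketch only: LlavaLlamaForCausalLM and model_name as in eval/run_llava.py.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    use_cache=True,
    device_map='auto',  # accelerate dispatches the weights; no .cuda() needed
    quantization_config=quant_config,
)
```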
Traceback (from the original snippet above):
File "/home/ubuntu/llava/LLaVA/llava/eval/run_llava.py", line 191, in <module>
eval_model(args)
File "/home/ubuntu/llava/LLaVA/llava/eval/run_llava.py", line 146, in eval_model
output_ids = model.generate(
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 1462, in generate
return self.sample(
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 2478, in sample
outputs = self(
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/llava/LLaVA/llava/model/llava.py", line 222, in forward
outputs = self.model(
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/llava/LLaVA/llava/model/llava.py", line 133, in forward
image_features = self.mm_projector(image_features)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 320, in forward
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 500, in matmul
return MatMul8bitLt.apply(A, B, out, bias, state)
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 322, in forward
A = A.view(-1, A.shape[-1]).contiguous()
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
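From the trace, the failure is inside bitsandbytes' `MatMul8bitLt`, which calls `A.view(-1, A.shape[-1])` on the incoming activations, and `view` only works on tensors with compatible strides. `image_features` is presumably a slice of the vision tower output (the CLS token gets dropped), which leaves it non-contiguous; the fp16 `nn.Linear` path tolerates that, but the 8-bit path does not. Two possible workarounds, both untested sketches:

```python
# Option 1: make the activations contiguous before the projector.
# In llava/model/llava.py, around the line 133 shown in the traceback:
image_features = self.mm_projector(image_features.contiguous())
```

```python
# Option 2: quantize only the LLM and keep the small projector in fp16.
# Assumes a transformers version that supports llm_int8_skip_modules:
from transformers import BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mm_projector", "lm_head"],
)
```

Option 2 would also avoid pushing the projector through int8 at all; it's a tiny fraction of the parameters compared to the LLM, so keeping it in fp16 costs almost nothing.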