LLaVA
Device mismatch error during pre-training
Describe the issue
Issue:
I wanted to run the pre-training script https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/pretrain.sh, but it fails with a device-mismatch error. The whole LLM appears to stay on the CPU, while the data and the vision tower (whose device is set via `vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)` in `train.py`) are on the GPU, which causes the error. I checked the `train.py` code (https://github.com/haotian-liu/LLaVA/blob/main/llava/train/train.py) that pre-training calls, and the issue seems to be where it loads the model without a specific `device` or `device_map`:
```python
model = LlavaLlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    attn_implementation=attn_implementation,
    torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
    **bnb_model_from_pretrained_args
)
```
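For what it's worth, this matches plain PyTorch behaviour: a module constructed (or loaded) without an explicit device lands on the CPU. A tiny stand-in illustrates this (the `nn.Linear` here is just a placeholder for the LLM, not LLaVA code):

```python
import torch
import torch.nn as nn

# Stand-in for a model loaded via from_pretrained() without a
# device or device_map: its parameters default to the CPU.
llm = nn.Linear(8, 8)
print(next(llm.parameters()).device)  # cpu
```

Running a forward pass that mixes these CPU parameters with GPU inputs is exactly what triggers the `RuntimeError` in the traceback.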
This leads to the following error:
```
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/contextlib.py", line 153, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/accelerator.py", line 1058, in accumulate
    yield
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3264, in compute_loss
    outputs = model(**inputs)
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/afshin/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu
```
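To pin down which submodule is on the wrong device, I used a small helper like the following (my own sketch, not part of LLaVA) that lists where each parameter lives:

```python
import torch.nn as nn

def device_report(model: nn.Module) -> dict:
    """Map each named parameter to the device it lives on."""
    return {name: str(p.device) for name, p in model.named_parameters()}

# Toy example mixing a "language model" and a "projector" submodule;
# in a fresh process both start on the CPU.
toy = nn.ModuleDict({"llm": nn.Linear(4, 4), "mm_projector": nn.Linear(4, 4)})
print(device_report(toy))
```

Calling `device_report(model)` on the real model right before `Trainer` starts makes the CPU-resident parameters easy to spot.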
If I add `device_map="auto"` to `LlavaLlamaForCausalLM.from_pretrained()`, the model ends up on CUDA, but the projector is still on the CPU. I can also move that to CUDA manually, but since you have run this code, it should work out of the box with no changes, so I am probably not following all the steps correctly. I'm posting it here to find out what I am missing.
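For reference, the manual workaround I mentioned is roughly the following; this is a hedged sketch, the attribute path `model.get_model().mm_projector` is my assumption about LLaVA's internals, and the snippet substitutes a plain `nn.Linear` so it runs without the real model:

```python
import torch
import torch.nn as nn

def move_to_training_device(module: nn.Module) -> nn.Module:
    # Fall back to CPU when no GPU is available so the sketch stays runnable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return module.to(device)

# Stand-in for model.get_model().mm_projector (assumed attribute path):
projector = nn.Linear(1024, 4096)
projector = move_to_training_device(projector)
print(next(projector.parameters()).device.type)
```

This silences the error for me, but it feels like a band-aid rather than the intended setup.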