open_flamingo
Mismatched input and weight types when training with fp16 precision
Hi, thanks for making this project public.
I am trying to run training with fp16 and get the following error:
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
Training with fp32 does run, but it then fails with an out-of-memory (OOM) error.
Traceback of the error when using fp16:
Traceback (most recent call last):
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train.py", line 484, in <module>
main()
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train.py", line 465, in main
train_one_epoch(
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train_utils.py", line 111, in train_one_epoch
loss_laion = model(
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 108, in forward
self._encode_vision_x(vision_x=vision_x)
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 195, in _encode_vision_x
vision_x = self.vision_encoder(vision_x)[1]
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/open_clip/transformer.py", line 469, in forward
x = self.conv1(x) # shape = [*, width, grid, grid]
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
Environment
I am using Python 3.9.17 with V100 GPUs.
open-clip-torch 2.16.0
torch 2.0.1
torchvision 0.15.2
transformers 4.28.1
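The error says the vision encoder's patch-embedding convolution (conv1 in open_clip's transformer.py) received fp16 images while its own weights were still fp32. Below is a minimal CPU sketch of the same kind of mismatch. It assumes only that the inputs were cast to half precision and the module was not; the layer sizes are made-up stand-ins, not the real CLIP configuration.

```python
import torch
import torch.nn as nn

# Illustrative stand-in for open_clip's conv1 patch embedding (sizes are made up).
conv1 = nn.Conv2d(3, 8, kernel_size=4, stride=4, bias=False)  # float32 weights by default

# Images cast to fp16 while the module itself stays in fp32.
images = torch.randn(2, 3, 16, 16).half()

try:
    conv1(images)
except RuntimeError as err:
    # Same class of failure as the traceback: input dtype != weight dtype.
    print("mismatch:", err)

# One fix: bring the inputs to whatever dtype the weights actually have.
out = conv1(images.to(conv1.weight.dtype))
print(out.dtype, tuple(out.shape))  # torch.float32 (2, 8, 4, 4)
```

The other direction is to convert the module itself (for example with .half()) so the weights match fp16 inputs, or to keep fp32 weights and let torch.autocast choose per-op precision. This sketch does not establish which of these the training script's fp16 option is meant to do.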
Thanks for bringing this up! I will take a closer look later today. I do want to point out that we haven't gotten good performance with pure fp16 training. You may get better results by keeping fp32 and using FSDP to shard the model state across your GPUs, rather than reducing the precision.
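To see why sharding can stand in for lower precision, here is a back-of-the-envelope estimate of per-GPU model-state memory. It is a sketch under stated assumptions: fp32 weights (4 bytes per parameter); fp32 gradients (4 bytes) and two AdamW moment buffers (8 bytes) for trainable parameters only; frozen parameters holding just their weights; and a full shard that splits everything evenly across ranks. Activations, buffers and communication overhead are ignored, and the parameter counts in the usage lines are hypothetical.

```python
def model_state_bytes(total_params: int, trainable_params: int, world_size: int = 1,
                      param_bytes: int = 4, grad_bytes: int = 4,
                      optim_bytes: int = 8) -> float:
    """Per-GPU bytes for weights + grads + optimizer state (activations ignored).

    Every parameter stores its weight; only trainable parameters also carry
    a gradient and optimizer state (two fp32 AdamW moments = 8 bytes).
    With full sharding, the total is divided evenly across world_size ranks.
    """
    if not 0 <= trainable_params <= total_params:
        raise ValueError("trainable_params must be between 0 and total_params")
    total = total_params * param_bytes + trainable_params * (grad_bytes + optim_bytes)
    return total / world_size


# Hypothetical sizes: 3B parameters in total, 1B of them trainable.
for ranks in (1, 4):
    gb = model_state_bytes(3_000_000_000, 1_000_000_000, world_size=ranks) / 1e9
    print(f"{ranks} GPU(s): {gb:.1f} GB of model state per GPU")
```

Under these assumptions the estimate drops from 24.0 GB on one GPU to 6.0 GB per GPU across four, without touching precision.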
Thanks for clarifying. FSDP would be ideal, but I still have problems training with it: I am using MPT-1B, which does not have the get_output_embeddings and set_output_embeddings methods. I see there is a major refactor in progress. Looking forward to using it soon.
Got it. There is this version of MPT I use for testing, if you want to give FSDP a shot before the new refactor is merged.
Great, thanks for pointing this out. I will try FSDP with this model.
I tried FSDP with "mpt-1b-redpajama-200b-hf-style" and it got past the above error.
However, I now get another error: the input embedding weight (self.transformer.wte.weight) appears to have changed shape. I believe it should be a 2-D tensor of shape (:, 2048), not a 1-D tensor of shape (25743360), and this causes the size mismatch when computing the logits. More details below:
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 111, in forward
output = self.lang_encoder(
File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo_lm.py", line 157, in forward
return super().forward(**kwargs) # Call the other parent's forward method
File "/home/hqvo2/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-1b-redpajama-200b-hf-style/f40a2c7f92621be8b12a01ac9214d3ed4ef50f60/mosaic_gpt.py", line 379, in forward
logits = F.linear(x, self.transformer.wte.weight, None)
RuntimeError: size mismatch, got 8, 8x2048,25743360
Did you resolve this? I get a very similar error while trying to use FSDP with OpenFlamingo 9B:
File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/open_flamingo/src/flamingo.py", line 111, in forward
output = self.lang_encoder(
File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/open_flamingo/src/flamingo_lm.py", line 157, in forward
return super().forward(**kwargs) # Call the other parent's forward method
File "/gpfs/data/oermannlab/users/alyaka01/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-7b/b772e556c8e8a17d087db6935e7cd019e5eefb0f/modeling_mpt.py", line 258, in forward
logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
RuntimeError: size mismatch, got 8192, 8192x4096,51486720
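Both size-mismatch messages show a 1-D weight where F.linear expects a 2-D (vocab_size, hidden_size) matrix. One possible reading, offered as a hypothesis rather than a diagnosis, is that wte.weight is being read as FSDP's flattened, sharded parameter instead of the original embedding matrix. The check below only tests whether the reported sizes are consistent with that. It assumes the shard holds exactly the embedding weight, split evenly with no padding. The world size of 4 is also an assumption, since neither post says how many GPUs were used.

```python
from typing import Optional


def implied_vocab_size(flat_numel: int, hidden_size: int, world_size: int) -> Optional[int]:
    """Vocab size implied if flat_numel is one rank's equal share of a
    (vocab_size, hidden_size) matrix, or None if the sizes do not divide evenly."""
    rows_per_rank, remainder = divmod(flat_numel, hidden_size)
    if remainder:
        return None
    return rows_per_rank * world_size


# Sizes taken from the two error messages above; world_size=4 is assumed.
reports = {
    "mpt-1b (hidden 2048)": (25_743_360, 2048),
    "mpt-7b / OpenFlamingo 9B (hidden 4096)": (51_486_720, 4096),
}
for name, (numel, hidden) in reports.items():
    print(name, "->", implied_vocab_size(numel, hidden, world_size=4))
```

Both sizes divide evenly into 12,570 rows per rank, i.e. a 50,280-row embedding if there are 4 ranks. That is at least consistent with the same tokenizer being used in both setups, but it does not by itself confirm that FSDP flattening is the cause.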
Related: https://github.com/mlfoundations/open_flamingo/issues/129#issuecomment-1696570150