CogVideo
RuntimeError: expected scalar type Float but found Half
System Info / 系統信息
CUDA==11.8, pytorch==2.3.0, diffusers==0.30.1, transformers==4.44.2, apex==0.1
Information / 问题信息
- [X] The official example scripts / 官方的示例脚本
- [ ] My own modified scripts / 我自己修改的脚本和任务
Reproduction / 复现过程
run the demo script:

```shell
python inference/cli_demo.py --prompt "A girl ridding a bike." --model_path THUDM/CogVideoX-2b
```

```
  File "/data/proj/CogVideo/inference/cli_demo.py", line 126, in <module>
    generate_video(
  File "/data/proj/CogVideo/inference/cli_demo.py", line 89, in generate_video
    video = pipe(
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 629, in __call__
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 297, in encode_prompt
    prompt_embeds = self._get_t5_prompt_embeds(
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 240, in _get_t5_prompt_embeds
    prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1970, in forward
    encoder_outputs = self.encoder(
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1105, in forward
    layer_outputs = layer_module(
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 685, in forward
    self_attention_outputs = self.layer[0](
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 592, in forward
    normed_hidden_states = self.layer_norm(hidden_states)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/apex/normalization/fused_layer_norm.py", line 416, in forward
    return fused_rms_norm_affine(
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/apex/normalization/fused_layer_norm.py", line 215, in fused_rms_norm_affine
    return FusedRMSNormAffineFunction.apply(*args)
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/opensora/lib/python3.9/site-packages/apex/normalization/fused_layer_norm.py", line 75, in forward
    output, invvar = fused_layer_norm_cuda.rms_forward_affine(
RuntimeError: expected scalar type Float but found Half
```
I think it is related to the T5 model. The "wo" module is kept in fp32, so when hidden_states passes through "wo" it becomes fp32, which then causes the RuntimeError. The error can be avoided by running in fp32 mode, or by explicitly casting hidden_states back to fp16 after the "wo" module. But I still wonder whether this is a common bug or just caused by a corrupted dependency in my environment.
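The cast-back workaround described above can be sketched with a PyTorch forward hook. This is a minimal sketch, not the project's official fix; the `pipe.text_encoder` traversal and the `.wo` suffix follow transformers' T5 module naming, but are assumptions about the loaded pipeline:

```python
import torch
from torch import nn

def cast_output_to(dtype):
    """Build a forward hook that casts a module's tensor output to `dtype`."""
    def hook(module, args, output):
        if isinstance(output, torch.Tensor) and output.dtype != dtype:
            return output.to(dtype)  # returning a value replaces the output
        return output
    return hook

# Hypothetical usage on the loaded pipeline: re-cast after every fp32-kept
# "wo" projection so downstream fp16 layer norms see fp16 activations.
# for name, module in pipe.text_encoder.named_modules():
#     if name.endswith("wo"):
#         module.register_forward_hook(cast_output_to(torch.float16))
```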
Expected behavior / 期待表现
model can be run in fp16 mode without error
It seems some bots have left comments here to spread malware. BE CAREFUL.
This error shouldn’t occur because, in the CLI demo, the entire model pipeline is loaded using FP16 (by default), so there shouldn’t be an issue with FP32. Can you print the dtype of the pipeline?
Sure. I set breakpoints and confirmed that the pipeline dtype is fp16. Digging into the code, the error is raised inside the T5 model (also confirmed to be fp16): just after the "wo" module, hidden_states becomes fp32, and at the SECOND layer's self-attention block (after a "wo" module) the layer norm op raises the error (transformers/models/t5/modeling_t5.py:592).
#### SECOND PASS

```
ipdb> hidden_states.dtype
torch.float32
ipdb> self.layer_norm
FusedRMSNorm(torch.Size([4096]), eps=1e-06, elementwise_affine=True)
ipdb> self.layer_norm.weight.dtype
torch.float16
```
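The upcast can be reproduced in isolation: transformers keeps T5's `wo` projection in fp32 and casts the incoming activations up to match it before the matmul, so activations leave the feed-forward block as fp32 even though the surrounding model (including the next norm) is fp16. A toy sketch of the mechanism, with a plain `nn.Linear` standing in for `wo`:

```python
import torch
from torch import nn

wo = nn.Linear(8, 8)                             # kept in fp32 by transformers
hidden = torch.randn(2, 8, dtype=torch.float16)  # rest of the model runs fp16

# The T5 feed-forward upcasts activations to wo's dtype before the matmul...
hidden = wo(hidden.to(wo.weight.dtype))

# ...so the next (fp16) fused layer norm receives fp32 input: dtype mismatch.
print(hidden.dtype)  # torch.float32
```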
This shouldn't be the case. I haven't encountered this situation myself; this part of the replacement is reasonable and shouldn't directly cause an error.
I had the same problem in T5
I met the same problem. How can I fix it, please?
This line is the root cause of the fp32 conversion. And it looks like it only affects fp16, due to the internal logic that handles `_keep_in_fp32_modules`.
A quick fix is to cast the dtype back after `T5DenseActDense` / `T5DenseGatedActDense`, or to just use bf16.
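That cast-back can be done with a small wrapper around the T5 feed-forward submodule. A sketch under assumptions: the `pipe.text_encoder.encoder.block` traversal, the `T5LayerFF` position, and the `DenseReluDense` attribute name follow transformers' T5 implementation, but the patch loop is hypothetical:

```python
import torch
from torch import nn

class CastOutput(nn.Module):
    """Wrap a module and cast its output back to a target dtype."""
    def __init__(self, inner, dtype=torch.float16):
        super().__init__()
        self.inner = inner
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs).to(self.dtype)

# Hypothetical patch: wrap the feed-forward of every encoder block so the
# fp32 activations coming out of "wo" return to fp16 before the next norm.
# for block in pipe.text_encoder.encoder.block:
#     ff = block.layer[-1]  # T5LayerFF is the last sub-layer of each block
#     ff.DenseReluDense = CastOutput(ff.DenseReluDense, torch.float16)
```

Alternatively, loading the pipeline with `torch_dtype=torch.bfloat16` sidesteps the problem: as noted above, the `_keep_in_fp32_modules` logic only kicks in for fp16, so no module is left in fp32 to upcast the activations.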