# T5 models fail when loaded with `torch_dtype=torch.half`
### System Info

- `transformers` version: 4.45.0.dev0
- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
- Python version: 3.10.15
- Huggingface_hub version: 0.26.0
- Safetensors version: 0.4.5
- Accelerate version: 1.0.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0a0+gitd2f9472 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
- Using GPU in script?: Yes
- GPU type: AMD Instinct MI250X/MI250
### Who can help?
@ArthurZucker
### Information
- [X] The official example scripts
- [ ] My own modified scripts
### Tasks

- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
### Reproduction

```python
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.half)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)
```
Error:

```
Traceback (most recent call last):
  File "/workspace/repro.py", line 10, in <module>
    outputs = model(input_ids)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 1996, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 1131, in forward
    layer_outputs = layer_module(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 711, in forward
    self_attention_outputs = self.layer[0](
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 616, in forward
    normed_hidden_states = self.layer_norm(hidden_states)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 386, in forward
    return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 189, in fused_rms_norm_affine
    return FusedRMSNormAffineFunction.apply(*args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 69, in forward
    output, invvar = fused_layer_norm_cuda.rms_forward_affine(
RuntimeError: expected scalar type Float but found Half
```
### Expected behavior

With the default fp32 inference:

```python
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", device_map="auto")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)
# Outputs `torch.float32`
```
I assume this issue occurs with all other T5 models. (It was found while trying to run `stabilityai/stable-diffusion-3-medium-diffusers` in half precision, which uses the T5 encoder.)
Related issues: #20287 #21391
Note: uninstalling `accelerate` from the environment fixes the issue. More specifically, the issue is caused by `keep_in_fp32_modules=["wo"]` for the T5 model (see https://github.com/huggingface/transformers/issues/20287#issuecomment-1342219429 and https://github.com/huggingface/transformers/pull/20683), which force-sets `low_cpu_mem_usage=True` when `accelerate` is present (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4072).
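To make the mixed dtypes visible, here is a minimal inspection sketch (assuming `accelerate` is installed, as in the report, so that `keep_in_fp32_modules` takes effect; the module paths follow the current `modeling_t5.py` layout):

```python
import torch
from transformers import T5EncoderModel

# With accelerate installed, keep_in_fp32_modules=["wo"] keeps the
# feed-forward output projections in float32 even though we ask for half.
model = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.half)

ff = model.encoder.block[0].layer[-1]      # T5LayerFF
print(ff.DenseReluDense.wo.weight.dtype)   # torch.float32 (kept in fp32)
print(ff.layer_norm.weight.dtype)          # torch.float16

# The fp32 output of `wo` propagates through the residual stream into fp16
# layer norms; apex's fused RMS norm kernel requires input and weight dtypes
# to match, hence the RuntimeError above.
```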
The following script also fails, even if we explicitly disable low-CPU-memory loading:

```python
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.half, low_cpu_mem_usage=False).to("cuda")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)
```
I think one solution would be to use the apex `FusedRMSNorm` only when `dtype == torch.float32`; I will look further into it.
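A rough sketch of what that guard might look like (hypothetical: `make_t5_layer_norm` is an illustrative helper, not transformers API, and since the apex import currently rebinds `T5LayerNorm` at module import time, the check would have to move to where the norm is instantiated):

```python
import torch
from torch import nn


class T5LayerNormFP32Safe(nn.Module):
    """Pure-PyTorch RMS norm (mirrors transformers' T5LayerNorm): computes
    in fp32 and casts back, so it tolerates mixed fp16/fp32 hidden states."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


def make_t5_layer_norm(hidden_size: int, eps: float, dtype: torch.dtype) -> nn.Module:
    # Only use the fused apex kernel for all-fp32 models; it requires the
    # input and weight dtypes to match exactly.
    if dtype == torch.float32:
        try:
            from apex.normalization import FusedRMSNorm
            return FusedRMSNorm(hidden_size, eps=eps)
        except ImportError:
            pass
    return T5LayerNormFP32Safe(hidden_size, eps=eps)
```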
Hey @Rohan138, are you sure `torch.half` should be used this way? According to the docs (https://pytorch.org/docs/stable/generated/torch.Tensor.half.html), it should be used as a replacement for `.to(torch.float16)`; here you're invoking it within a `.to`.
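For what it's worth, `torch.half` is not only the `Tensor.half()` method; it is also a module-level alias for the dtype `torch.float16`, so passing it as `torch_dtype` is valid. A quick check:

```python
import torch

# torch.half is a dtype alias, interchangeable with torch.float16
assert torch.half is torch.float16
assert torch.zeros(2, dtype=torch.half).dtype == torch.float16
```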