remove unnecessary backend related checks in training_args.py
Feature request
IMO these checks in transformers should be removed.
if (
    self.framework == "pt"
    and is_torch_available()
    and (self.device.type != "cuda")
    and (self.device.type != "npu")
    and (self.device.type != "xpu")
    and (get_xla_device_type(self.device) != "GPU")
    and (self.fp16 or self.fp16_full_eval)
):
    raise ValueError(
        "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
        " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
    )
if (
    self.framework == "pt"
    and is_torch_available()
    and (self.device.type != "cuda")
    and (self.device.type != "npu")
    and (self.device.type != "xpu")
    and (get_xla_device_type(self.device) != "GPU")
    and (get_xla_device_type(self.device) != "TPU")
    and (self.device.type != "cpu")
    and (self.bf16 or self.bf16_full_eval)
):
    raise ValueError(
        "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
        " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
    )
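For illustration only (this is a sketch of mine, not code that exists in transformers), a vendor with an out-of-tree backend, say a hypothetical `my_chip` device type, can only get past the first check today by patching in yet another clause:

if (
    self.framework == "pt"
    and is_torch_available()
    and (self.device.type != "cuda")
    and (self.device.type != "npu")
    and (self.device.type != "xpu")
    and (self.device.type != "my_chip")  # hypothetical vendor device type, carried as an out-of-tree patch
    and (get_xla_device_type(self.device) != "GPU")
    and (self.fp16 or self.fp16_full_eval)
):
    raise ValueError("...")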
Motivation
To make things work, each vendor needs to extend this `if` by adding yet another `and (self.device.type != "my_precious_chip")` clause.
It bloats the code in transformers.
And I don't really think it's transformers' job to determine backend capabilities. Just pass the parameters through and let the backend itself decide whether it can handle the dtype; it has enough means to report an error.
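As a rough sketch of what "let the backend decide" could look like (my own illustration using plain PyTorch, not anything that exists in transformers or accelerate), probing the device once surfaces the backend's own error message instead of a hard-coded allowlist:

import torch

def backend_supports_dtype(device: str, dtype: torch.dtype) -> bool:
    """Probe the backend for dtype support instead of keeping a device allowlist."""
    try:
        x = torch.zeros(2, device=device, dtype=dtype)
        _ = x + x  # run a trivial op, since some backends only fail once a kernel is launched
        return True
    except (RuntimeError, TypeError) as err:
        print(f"{device} rejected {dtype}: {err}")
        return False

if __name__ == "__main__":
    print(backend_supports_dtype("cpu", torch.bfloat16))     # True on recent PyTorch builds
    if torch.backends.mps.is_available():
        print(backend_supports_dtype("mps", torch.float64))  # MPS reports its own lack of float64 support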
Your contribution
I'd be glad to delete them if approved :-p
cc @muellerzr @pacman100
Happy new year! Any update?
Completely makes sense. For example, M1 does not support certain dtypes, but M2 now supports some of them, so it doesn't make sense to keep the above assumptions hard-coded.
Any updates?
hi @ArthurZucker @muellerzr
Any priority on this issue? We need to patch this piece of code just to make transformers work. I believe all vendors that are not on this list are affected.
Thanks for looking into it. Cheers
Accelerate should handle most of this now, cc @SunMarc if you want to give this a try!
@janboeye Yes, PyTorch does not have mixed-precision support on MPS at this time.
Actually yeah, makes sense to remove. Do you want to open a PR @kevint324 ?
@ArthurZucker Yes, and here it is https://github.com/huggingface/transformers/pull/30999
Thanks!