Remove unnecessary backend-related checks in training_args.py
Feature request
IMO these checks in transformers should be removed.
if (
    self.framework == "pt"
    and is_torch_available()
    and (self.device.type != "cuda")
    and (self.device.type != "npu")
    and (self.device.type != "xpu")
    and (get_xla_device_type(self.device) != "GPU")
    and (self.fp16 or self.fp16_full_eval)
):
    raise ValueError(
        "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
        " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
    )

if (
    self.framework == "pt"
    and is_torch_available()
    and (self.device.type != "cuda")
    and (self.device.type != "npu")
    and (self.device.type != "xpu")
    and (get_xla_device_type(self.device) != "GPU")
    and (get_xla_device_type(self.device) != "TPU")
    and (self.device.type != "cpu")
    and (self.bf16 or self.bf16_full_eval)
):
    raise ValueError(
        "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
        " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
    )
Motivation
To make things work, each vendor has to extend this `if` with yet another `and (self.device.type != "my_precious_chip")` clause. That bloats the code in transformers. And I don't really think it's transformers' job to determine backend capabilities. Just pass the parameters through and let the backend itself decide whether it can handle the dtype; the backends should have enough means to report an error themselves (a minimal sketch of this pass-through idea is below).
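To make the idea concrete, here is a minimal sketch of the pass-through approach (not the actual transformers implementation; `select_half_precision_dtype` and `autocast_context` are hypothetical helpers): the `--fp16`/`--bf16` flags map to a torch dtype, that dtype goes straight to `torch.autocast` for whatever device type is in use, and an unsupported device/dtype combination surfaces as the backend's own error.

import contextlib

import torch


def select_half_precision_dtype(fp16: bool, bf16: bool):
    # Hypothetical helper: map the user-facing flags to a torch dtype,
    # with no device gating at all.
    if fp16 and bf16:
        raise ValueError("`--fp16` and `--bf16` are mutually exclusive.")
    if fp16:
        return torch.float16
    if bf16:
        return torch.bfloat16
    return None


def autocast_context(device_type: str, fp16: bool = False, bf16: bool = False):
    # Hypothetical helper: build an autocast context for whatever backend
    # `device_type` names ("cuda", "cpu", "xpu", ...). If the backend cannot
    # handle the requested dtype, it raises its own error instead of
    # transformers pre-checking capabilities with a hard-coded device list.
    dtype = select_half_precision_dtype(fp16, bf16)
    if dtype is None:
        return contextlib.nullcontext()
    return torch.autocast(device_type=device_type, dtype=dtype)

For example, autocast_context("cuda", bf16=True) just works on a capable GPU, while an unsupported combination fails with the backend's own error message, which is usually more precise than a device allow-list maintained in training_args.py.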
Your contribution
I'd be glad to delete them if approved :-p