Remove unnecessary backend-related checks in training_args.py

Open kevint324 opened this issue 1 year ago • 10 comments

Feature request

IMO these checks in transformers' training_args.py should be removed:

        if (
            self.framework == "pt"
            and is_torch_available()
            and (self.device.type != "cuda")
            and (self.device.type != "npu")
            and (self.device.type != "xpu")
            and (get_xla_device_type(self.device) != "GPU")
            and (self.fp16 or self.fp16_full_eval)
        ):
            raise ValueError(
                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
                " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
            )

        if (
            self.framework == "pt"
            and is_torch_available()
            and (self.device.type != "cuda")
            and (self.device.type != "npu")
            and (self.device.type != "xpu")
            and (get_xla_device_type(self.device) != "GPU")
            and (get_xla_device_type(self.device) != "TPU")
            and (self.device.type != "cpu")
            and (self.bf16 or self.bf16_full_eval)
        ):
            raise ValueError(
                "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
            )

Motivation

To make things work, each vendor needs to extend this if by adding yet another and (self.device.type != "my_precious_chip") line. It bloats the code in transformers.

And I don't really think it's transformers' job to determine capabilities for backends. Just pass the parameters through and let the backend itself decide whether it can handle the dtype; it should have enough means to report an error.
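
For illustration, here is a minimal sketch (an assumed example, not code from transformers; the helper name backend_supports_dtype is hypothetical) of how the backend itself already reports an unsupported dtype, which is why an allowlist in training_args.py shouldn't be needed:

    import torch

    def backend_supports_dtype(device: str, dtype: torch.dtype) -> bool:
        # Run a tiny op and let the backend decide: if the device/dtype
        # combination is unsupported, PyTorch raises its own error with a
        # clear message, so no per-vendor allowlist is required here.
        try:
            x = torch.ones((2, 2), dtype=dtype, device=device)
            _ = x @ x
            return True
        except (RuntimeError, TypeError) as err:
            print(f"{device} rejected {dtype}: {err}")
            return False

    # Example: probe fp16 on the CPU backend.
    backend_supports_dtype("cpu", torch.float16)

Whether transformers should probe at all is exactly the point of this issue; the sketch only shows that the failure surfaces from the backend either way.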

Your contribution

I'd be glad to delete them if approved :-p

kevint324 avatar Dec 18 '23 10:12 kevint324

cc @muellerzr @pacman100

amyeroberts avatar Dec 18 '23 14:12 amyeroberts

Happy new year! Any update?

kevint324 avatar Jan 02 '24 01:01 kevint324

Completely makes sense. For example, the M1 does not support certain dtypes, but the M2 now supports some of them, so the assumptions above no longer hold.

vstoyanoff avatar Jan 04 '24 15:01 vstoyanoff

any updates?

b8kings0ga avatar Jan 10 '24 11:01 b8kings0ga

hi @ArthurZucker @muellerzr

Any priority on this issue? We need to patch this piece of code to make transformers work. I believe all vendors that are not on this list are affected.

Thanks for looking into it. Cheers

kevint324 avatar Apr 29 '24 09:04 kevint324

Accelerate should handle most of this now, cc @SunMarc if you want to give this a try!
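
For context, a minimal sketch (assumed usage, not taken from this thread) of deferring the mixed-precision decision to Accelerate: the Accelerator is configured with the requested precision and is expected to validate it against the backend it detects, raising its own error if the combination is unsupported.

    from accelerate import Accelerator

    # Request fp16 mixed precision; Accelerate resolves the device and
    # applies its own compatibility checks for that backend.
    accelerator = Accelerator(mixed_precision="fp16")
    print(accelerator.device, accelerator.mixed_precision)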

muellerzr avatar Apr 29 '24 17:04 muellerzr

@muellerzr here, Accelerate does not support fp16 on MPS, right?

Thanks

janboeye avatar May 19 '24 04:05 janboeye

@janboeye yes, PyTorch does not have mixed-precision support on MPS at this time.

muellerzr avatar May 19 '24 09:05 muellerzr

Actually yeah, it makes sense to remove them. Do you want to open a PR, @kevint324?

ArthurZucker avatar May 23 '24 13:05 ArthurZucker

@ArthurZucker Yes, and here it is https://github.com/huggingface/transformers/pull/30999

Thanks!

kevint324 avatar May 24 '24 01:05 kevint324