transformer-debugger
Compatibility with MPS backend
While running inference on my Mac with macOS 13.1, I received the following error:
RuntimeError: MPS does not support cumsum_out_mps op with int64 input. Support has been added in macOS 13.3
The error comes from the call to cumsum in prep_pos_from_pad_and_prev_lens. cumsum is also used in other places in the repository.
The same error appears when running several of the tests on the MPS backend, as noted near the function get_default_device. I have confirmed that these test failures are also caused by MPS not supporting cumsum on macOS versions earlier than 13.3.
Should we modify the function get_default_device to return torch.device("mps", 0) only when the macOS version is >= 13.3? We could then remove the current workaround that avoids running the pytests on this backend.
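A rough sketch of the proposed check follows. It assumes the macOS release string from platform.mac_ver() is enough to detect the 13.3 cutoff named in the error message. The helpers parse_macos_version and mps_cumsum_supported are hypothetical names. The CUDA/CPU fallback order in get_default_device is also an assumption about how the existing function is structured, not a copy of it.

```python
import platform
from typing import Optional, Tuple

# Hypothetical constant: first macOS release whose MPS backend supports
# cumsum on int64 input, per the RuntimeError quoted above.
MIN_MACOS_FOR_MPS_CUMSUM: Tuple[int, int] = (13, 3)


def parse_macos_version(release: str) -> Optional[Tuple[int, ...]]:
    """Turn "13.1" into (13, 1) or "13.3.1" into (13, 3, 1).

    Returns None for an empty string, which is what platform.mac_ver()
    reports when not running on macOS.
    """
    if not release:
        return None
    return tuple(int(part) for part in release.split("."))


def mps_cumsum_supported(release: Optional[str] = None) -> bool:
    """True if the given (or running) macOS release is at least 13.3."""
    if release is None:
        release = platform.mac_ver()[0]
    version = parse_macos_version(release)
    # Tuple comparison: (13, 1) < (13, 3) <= (13, 3, 1) < (14, 0)
    return version is not None and version >= MIN_MACOS_FOR_MPS_CUMSUM


def get_default_device():
    # torch is imported lazily so the version helpers above work without it.
    import torch

    # Assumed fallback order: CUDA first, then MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", 0)
    # Use MPS only if the backend is available *and* macOS is new enough
    # for int64 cumsum (as used in prep_pos_from_pad_and_prev_lens).
    if torch.backends.mps.is_available() and mps_cumsum_supported():
        return torch.device("mps", 0)
    return torch.device("cpu")
```

With this gate, a machine on macOS 13.1 would fall back to the CPU instead of failing inside cumsum. Machines on 13.3 or later would keep using MPS, so the pytest workaround would no longer be needed there.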
If this seems like a useful change, I will be happy to submit a pull request.
Thank you!