(Experimental) Integrate Metal PjRt plugin
This integrates the PjRt plugin from jax-metal to run computations on the Apple GPU. To test it, set client: :mps on the EXLA backend or compiler. Since the plugin is loaded as a separate dynamic library, it can be tested without any changes to XLA (just make sure to remove the cache/ directory first).
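A minimal sketch of what "set client: :mps" looks like in practice. It assumes EXLA is compiled from this branch on an Apple Silicon Mac, so that the :mps client is available. Nx.default_backend/1 and Nx.Defn.default_options/1 are the standard Nx entry points for choosing a backend and compiler; the tensor values are purely illustrative.

```elixir
# Sketch only: assumes EXLA was built from this branch (after removing the
# cache/ directory) on an Apple Silicon Mac, so the :mps client exists.

# Allocate tensors on the Apple GPU through the Metal PjRt plugin.
Nx.default_backend({EXLA.Backend, client: :mps})

# Compile defn/jit functions with EXLA, targeting the same client.
Nx.Defn.default_options(compiler: EXLA, client: :mps)

# Stick to f32: the plugin does not support float64 yet.
t = Nx.tensor([1.0, 2.0, 3.0], type: :f32)

# JIT-compile a simple elementwise computation and run it on the GPU.
double = Nx.Defn.jit(fn x -> Nx.multiply(x, 2) end)
result = double.(t)

IO.inspect(result)
```

Since the plugin is still incomplete, simple elementwise operations like this are a reasonable first check before moving on to larger computations.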
Certain computations can already be run, but the plugin is still very much incomplete. This PR serves as a space for experimentation and tracks the plugin's progress. I reported a number of issues upstream; comments in the code point to them. In a few places I applied workarounds, either as temporary solutions or simply to avoid VM crashes (segfaults); those are marked with a TODO.
Issues
For tracking purposes, here is a list of the Metal plugin issues reported upstream:
Crucial
- https://github.com/google/jax/issues/21552
- https://github.com/google/jax/issues/21601
- https://github.com/google/jax/issues/21547
- https://github.com/google/jax/issues/21590
Not implemented
- https://github.com/google/jax/issues/21387
- https://github.com/google/jax/issues/21384
- https://github.com/google/jax/issues/21554
- https://github.com/google/jax/issues/21389
- Support for complex numbers https://github.com/google/jax/issues/16416
- Support for float64 https://github.com/google/jax/issues/20938
Edge cases
- https://github.com/google/jax/issues/21577
- https://github.com/google/jax/issues/21397
- https://github.com/google/jax/issues/21393
- https://github.com/google/jax/issues/21392
- https://github.com/google/jax/issues/21821
All issues: link.
Note: this PR targets the jk-s32 branch, which changes the default integer precision to 32 bits. This is a planned change (#1491), but it has not been merged yet, to avoid conflicts with other work in progress.
(I pushed an update to reduce the diff, in case someone wants to use it as a reference for integrating another PjRt plugin.)