
(Experimental) Integrate Metal PjRt plugin

Open • jonatanklosko opened this issue 1 year ago • 1 comment

This integrates the PjRt plugin from jax-metal to run computations on the Apple GPU. To test it, set client: :mps on the EXLA backend or compiler. Since the plugin is loaded as a separate dynamic library, it can be tested without any changes to XLA (just make sure to remove the cache/ directory).
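A minimal sketch of that setup, assuming EXLA is built from this PR's branch, where the experimental :mps client exists. Nx.default_backend/1, Nx.Defn.default_options/1 and Nx.Defn.jit/2 are existing Nx functions; the :mps client itself is specific to this experimental branch. The example sticks to f32 tensors because the plugin does not support 64-bit floats.

```elixir
# Assumes EXLA compiled from this PR's branch, where the experimental
# :mps client (backed by the jax-metal PjRt plugin) is available.

# Eager Nx operations in this process go through EXLA on the :mps client.
Nx.default_backend({EXLA.Backend, client: :mps})

# defn/jit calls in this process compile with EXLA against the same client.
Nx.Defn.default_options(compiler: EXLA, client: :mps)

# Float literals default to f32, which the Metal plugin supports.
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.sum(t)

# Alternatively, pass the options to a single jitted function.
Nx.Defn.jit(&Nx.multiply(&1, 2), compiler: EXLA, client: :mps).(t)
```

Given how incomplete the plugin is, operations hitting one of the issues listed below may fail or crash the VM.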

Certain computations already run, but the plugin is still very much incomplete. This PR is a space for experimentation and is meant to track the plugin's progress. I reported a number of issues upstream, and comments in the code point to them. In a few places I applied workarounds, either as temporary solutions or simply to avoid VM crashes (segfaults); those are marked with a TODO.

Issues

For tracking purposes, here is a list of the Metal plugin issues reported upstream:

Crucial

  • https://github.com/google/jax/issues/21552
  • https://github.com/google/jax/issues/21601
  • https://github.com/google/jax/issues/21547
  • https://github.com/google/jax/issues/21590

Not implemented

  • https://github.com/google/jax/issues/21387
  • https://github.com/google/jax/issues/21384
  • https://github.com/google/jax/issues/21554
  • https://github.com/google/jax/issues/21389
  • Support for complex numbers: https://github.com/google/jax/issues/16416
  • Support for 64-bit floats: https://github.com/google/jax/issues/20938

Edge cases

  • https://github.com/google/jax/issues/21577
  • https://github.com/google/jax/issues/21397
  • https://github.com/google/jax/issues/21393
  • https://github.com/google/jax/issues/21392
  • https://github.com/google/jax/issues/21821

All issues: link.


Note: this PR targets the jk-s32 branch, which changes the default integer precision to 32 bits. This is a planned change (#1491), but it has not been merged yet, to avoid conflicts with other work in progress.

jonatanklosko • Jun 05 '24 07:06

(I pushed an update to reduce the diff, in case someone wants to look into integrating another PjRt plugin.)

jonatanklosko • Dec 10 '24 03:12