
Compatibility of DESC with Apple Silicon Devices (jax-metal)

Open • rogeriojorge opened this issue 2 years ago • 1 comment

This PR adds support for the JAX METAL device, which is available on Apple Silicon devices (M1, M2, or M3) with jax-metal installed (https://developer.apple.com/metal/jax/). However, not all functionality is available yet:

  • Only 32-bit (single-precision) functionality is available.
  • Because jax.scipy.special functions are only partially available on METAL, gammaln and logsumexp are given special definitions when METAL is the active device.
  • An equilibrium cannot be created yet, because calling the Equilibrium class fails with the following error:
    jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: DESC/desc/backend.py:116:15: error: failed to legalize operation 'mhlo.scatter'
        return jnp.asarray(arr).at[inds].set(vals)
    According to this thread (https://developer.apple.com/forums/thread/738697), mhlo.scatter appears to be coming in the next version of jax-metal. I will therefore wait for that update to see whether this PR can go through.
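The METAL-specific definitions of gammaln and logsumexp mentioned in the list above could look roughly like the sketch below. This is a hypothetical illustration, not the code in this PR: the helper names (`on_metal`, `gammaln_metal`, `logsumexp_metal`, `special_functions`) are invented here, it assumes jax-metal reports its backend name as "METAL" through `jax.default_backend()`, and it swaps in two standard techniques, the Lanczos approximation (g = 7, 9 coefficients) for `gammaln` and the max-shift trick for `logsumexp`, built only from basic `jax.numpy` operations. It falls back to NumPy when JAX is not installed so that it can run anywhere.

```python
import math

try:
    import jax
    import jax.numpy as jnp
    import jax.scipy.special as jss
except ImportError:  # let the sketch run on machines without JAX
    jax = None
    jss = None
    import numpy as jnp

# Lanczos approximation parameters (g = 7, 9 coefficients), a common double-precision choice.
_LANCZOS_G = 7.0
_LANCZOS_COEFFS = (
    0.99999999999980993, 676.5203681218851, -1259.1392167224028,
    771.32342877765313, -176.61502916214059, 12.507343278686905,
    -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
)


def on_metal():
    """True if the default JAX backend is jax-metal (assumed to report the name "METAL")."""
    return jax is not None and jax.default_backend().lower() == "metal"


def logsumexp_metal(a, axis=None):
    """log(sum(exp(a))) via the max-shift trick, using only max/exp/sum/log/where."""
    a = jnp.asarray(a)
    m = jnp.max(a, axis=axis, keepdims=True)
    # If the max is +-inf, shift by 0 instead to avoid inf - inf = nan.
    m = jnp.where(jnp.isfinite(m), m, 0.0)
    out = jnp.log(jnp.sum(jnp.exp(a - m), axis=axis, keepdims=True)) + m
    return jnp.squeeze(out, axis=axis) if axis is not None else jnp.squeeze(out)


def _lanczos_lgamma(x):
    """log(Gamma(x)) for x >= 0.5 from the Lanczos series."""
    z = x - 1.0
    series = _LANCZOS_COEFFS[0]
    for i, c in enumerate(_LANCZOS_COEFFS[1:], start=1):
        series = series + c / (z + i)
    t = z + _LANCZOS_G + 0.5
    return 0.5 * math.log(2.0 * math.pi) + (z + 0.5) * jnp.log(t) - t + jnp.log(series)


def gammaln_metal(x):
    """log|Gamma(x)|, using the reflection formula for x < 0.5 (x must not be 0, -1, -2, ...)."""
    x = jnp.asarray(x)
    small = x < 0.5
    # Both branches are evaluated, so always feed the Lanczos series an argument >= 0.5.
    lg = _lanczos_lgamma(jnp.where(small, 1.0 - x, x))
    reflected = jnp.log(math.pi / jnp.abs(jnp.sin(math.pi * x))) - lg
    return jnp.where(small, reflected, lg)


def special_functions():
    """Return (gammaln, logsumexp): hand-written versions on METAL, jax.scipy.special otherwise."""
    if jss is None or on_metal():
        return gammaln_metal, logsumexp_metal
    return jss.gammaln, jss.logsumexp
```

Callers would do `gammaln, logsumexp = special_functions()` once and use the returned functions everywhere. Because METAL is limited to 32-bit floats, the Lanczos series may lose a few digits to cancellation between its large alternating coefficients, so results are worth checking against the tolerances DESC needs.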

rogeriojorge, Nov 08 '23 10:11

Codecov Report

Attention: Patch coverage is 25.00000% with 9 lines in your changes missing coverage. Please review.

Project coverage is 95.49%. Comparing base (831e7bd) to head (9df74ef). Report is 61 commits behind head on master.

Files with missing lines   Patch %   Lines
desc/__init__.py           10.00%    9 missing ⚠
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #741      +/-   ##
==========================================
- Coverage   95.52%   95.49%   -0.04%     
==========================================
  Files          96       96              
  Lines       24010    24021      +11     
==========================================
+ Hits        22936    22938       +2     
- Misses       1074     1083       +9     
Files with missing lines   Coverage Δ
desc/backend.py            90.30% <100.00%> (+0.05%) ↑
desc/__init__.py           41.79% <10.00%> (-5.58%) ↓

... and 1 file with indirect coverage changes
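The percentages in this report follow from the line counts in the coverage diff. The quick check below assumes Codecov truncates to two decimals (its default `round: down` setting) rather than rounding, which explains why master shows 95.52% even though 22936/24010 is about 95.527%. The helper name `codecov_percent` is made up for illustration.

```python
import math


def codecov_percent(fraction):
    """Percentage truncated (floored) to two decimals, assuming Codecov's round: down."""
    return math.floor(fraction * 10000) / 100


base = codecov_percent(22936 / 24010)  # master (831e7bd)
head = codecov_percent(22938 / 24021)  # this PR (9df74ef)
delta = codecov_percent(22938 / 24021 - 22936 / 24010)  # change, floored on unrounded values
print(base, head, delta)  # 95.52 95.49 -0.04

# Patch coverage: 9 missing lines at 25% coverage implies 12 patch lines, 3 of them hit.
patch_lines = round(9 / (1 - 0.25))
patch_hits = patch_lines - 9
print(patch_lines, patch_hits, codecov_percent(patch_hits / patch_lines))  # 12 3 25.0
```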

codecov[bot], Nov 22 '23 20:11