DESC
Compatibility of DESC with Apple Silicon Devices (jax-metal)
This PR adds support for the JAX device METAL, available on Apple Silicon devices (M1, M2 or M3) that have jax-metal installed (https://developer.apple.com/metal/jax/). However, not all functionality is available yet:
- Only 32-bit functionality is available.
- Because jax.scipy.special functions have limited availability, gammaln and logsumexp have special definitions when METAL is the device in use.
- An equilibrium cannot be created yet due to the following error when calling the Equilibrium class:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: DESC/desc/backend.py:116:15: error: failed to legalize operation 'mhlo.scatter'
return jnp.asarray(arr).at[inds].set(vals)

It appears that `mhlo.scatter` will be implemented in the next version of jax-metal, according to this thread: https://developer.apple.com/forums/thread/738697. Therefore, I will wait for the next update to see whether this PR can go through.
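
For illustration, a special definition of logsumexp that avoids jax.scipy.special could look roughly like the sketch below. This is not the code in this PR: the name `logsumexp_fallback` is hypothetical, the NumPy fallback import is only there so the snippet runs without JAX, and the cast to float32 reflects the 32-bit limitation noted above. It uses the usual max-shift trick and does not handle the case where every input is -inf.

```python
try:
    import jax.numpy as jnp  # preferred backend when JAX is installed
except ImportError:
    import numpy as jnp  # assumption: NumPy fallback so the sketch runs anywhere


def logsumexp_fallback(a, axis=None):
    """Hypothetical stand-in for jax.scipy.special.logsumexp.

    Computes log(sum(exp(a))) along `axis` using only basic array ops.
    """
    a = jnp.asarray(a, dtype=jnp.float32)  # jax-metal is 32-bit only
    # Subtract the maximum before exponentiating to avoid overflow.
    a_max = jnp.max(a, axis=axis, keepdims=True)
    shifted_sum = jnp.sum(jnp.exp(a - a_max), axis=axis, keepdims=True)
    out = jnp.log(shifted_sum) + a_max
    # Drop the reduced dimension(s) kept by keepdims=True.
    return jnp.squeeze(out, axis=axis)
```

Calling `logsumexp_fallback([1000.0, 1000.0])` stays finite because of the shift, whereas a naive `log(sum(exp(a)))` would overflow in float32.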
Codecov Report
Attention: Patch coverage is 25.00000% with 9 lines in your changes missing coverage. Please review.
Project coverage is 95.49%. Comparing base (831e7bd) to head (9df74ef). Report is 61 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| desc/__init__.py | 10.00% | 9 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## master #741 +/- ##
==========================================
- Coverage 95.52% 95.49% -0.04%
==========================================
Files 96 96
Lines 24010 24021 +11
==========================================
+ Hits 22936 22938 +2
- Misses 1074 1083 +9
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/backend.py | 90.30% <100.00%> (+0.05%) | :arrow_up: |
| desc/__init__.py | 41.79% <10.00%> (-5.58%) | :arrow_down: |