torchstain
Macenko JAX backend support
This PR adds JAX backend support to Macenko.
Changes:
- Implemented the Macenko normalizer with a JAX backend and added it to base
- Added CI jobs running JAX unit tests, which check that the JAX backend yields results similar to the numpy backend
- Renamed CI jobs so their names better match their actual purpose
- Updated the README to document JAX backend support
- Fixed setup.py to support installation through `pip install torchstain[jax]`
- Fixed `np.float32` deprecation in numpy Macenko
- Removed an unnecessary numpy import from the Macenko TF backend
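For reviewers less familiar with the method, the stain-matrix estimation that each Macenko backend has to implement can be sketched directly in JAX. This is a minimal eager-mode sketch, not the code in this PR: the function name `macenko_stain_matrix` is hypothetical, and the defaults `Io=240`, `alpha=1`, `beta=0.15` are assumed from common Macenko implementations.

```python
import jax.numpy as jnp


def macenko_stain_matrix(img, Io=240.0, alpha=1.0, beta=0.15):
    """Estimate a 3x2 stain matrix (hematoxylin, eosin columns) from an RGB image.

    Sketch only: img is an (H, W, 3) array of intensities in [0, 255].
    """
    # 1. RGB intensities -> optical density (OD), one row per pixel.
    od = -jnp.log((jnp.asarray(img, dtype=jnp.float32).reshape(-1, 3) + 1.0) / Io)

    # 2. Keep only pixels whose OD is at least beta in every channel
    #    (boolean indexing: works eagerly, but not under jax.jit).
    od_hat = od[jnp.all(od >= beta, axis=1)]

    # 3. Eigenvectors of the OD covariance. eigh sorts eigenvalues in
    #    ascending order, so the last two columns span the plane of the
    #    two largest eigenvalues.
    _, eigvecs = jnp.linalg.eigh(jnp.cov(od_hat.T))
    plane = eigvecs[:, 1:3]

    # 4. Project onto that plane and take robust extreme angles
    #    (alpha-th and (100 - alpha)-th percentiles).
    proj = od_hat @ plane
    phi = jnp.arctan2(proj[:, 1], proj[:, 0])
    min_phi = jnp.percentile(phi, alpha)
    max_phi = jnp.percentile(phi, 100.0 - alpha)

    # 5. Map both extreme angles back to unit vectors in OD space.
    v_min = plane @ jnp.array([jnp.cos(min_phi), jnp.sin(min_phi)])
    v_max = plane @ jnp.array([jnp.cos(max_phi), jnp.sin(max_phi)])

    # 6. Heuristic ordering: the vector with the larger red-channel OD is
    #    taken as hematoxylin and placed in the first column.
    return jnp.where(
        v_min[0] > v_max[0],
        jnp.stack([v_min, v_max], axis=1),
        jnp.stack([v_max, v_min], axis=1),
    )
```

Because step 2 uses boolean indexing, this sketch only runs eagerly; under `jax.jit` that line would raise an error, since the number of kept pixels is not known at trace time.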
Note that, runtime-wise, the JAX backend is not yet as optimized as the other backends, so I would consider JAX support experimental for now. Here is how the JAX backend compares to the other backends:
| backends | numpy | jax | torch | tf |
|---|---|---|---|---|
| runtime [s] | 0.455 | 2.427 | 0.201 | 0.442 |
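The table does not state how the runtimes were measured. When re-benchmarking, JAX needs two precautions that numpy does not: the first call to a `jax.jit`-compiled function includes compilation time, and JAX dispatches work asynchronously, so the clock should only stop once the result is actually ready. A minimal timing helper under those assumptions (`time_fn` and `od_transform` are hypothetical names, not part of torchstain):

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, warmup=1, repeats=5):
    """Mean wall-clock seconds per call of fn(*args).

    Warm-up calls are not timed, so one-off jit compilation is excluded, and
    jax.block_until_ready waits for asynchronously dispatched work to finish.
    """
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


# Hypothetical example: the RGB -> optical density step, compiled with jit.
@jax.jit
def od_transform(img, Io=240.0):
    return -jnp.log((img + 1.0) / Io)


img = jnp.full((512, 512, 3), 128.0)
print(f"od_transform: {time_fn(od_transform, img):.6f} s per call")
```

The same helper can wrap the numpy, torch, and tf calls (for non-JAX outputs `block_until_ready` simply returns them), so all backends are timed the same way.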
Further optimization of the JAX implementation is left for future work, but it is outside my area of expertise, so contributions from more experienced JAX developers would be very welcome.
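One possible direction for that follow-up work, offered as an untested sketch rather than part of this PR: the Macenko filtering step is commonly written with boolean indexing (`od[mask]`), which in JAX yields a data-dependent shape and therefore cannot be traced by `jax.jit`. If the JAX backend follows that pattern, keeping shapes fixed (zero weight for dropped pixels in the covariance, NaN for dropped angles in the percentiles) would make these steps compilable. The helpers `masked_cov`, `masked_percentile`, and `od_statistics` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def masked_cov(x, mask):
    """Covariance (ddof=1) of the rows of x where mask is True.

    Dropped rows get weight 0 instead of being removed, so the shapes of all
    intermediate arrays are fixed and the function can be traced by jax.jit.
    """
    w = mask.astype(x.dtype)
    n = w.sum()
    mean = (w[:, None] * x).sum(axis=0) / n   # mean over kept rows only
    xc = (x - mean) * w[:, None]              # centred, dropped rows zeroed
    return xc.T @ xc / (n - 1)


def masked_percentile(values, mask, q):
    """q-th percentile of values[mask], computed via NaN masking."""
    return jnp.nanpercentile(jnp.where(mask, values, jnp.nan), q)


@jax.jit
def od_statistics(img, Io=240.0, beta=0.15):
    """Fixed-shape version of the OD conversion and filtering step.

    Returns the covariance of the kept pixels' OD values and the keep-mask.
    """
    od = -jnp.log((img.reshape(-1, 3).astype(jnp.float32) + 1.0) / Io)
    mask = jnp.all(od >= beta, axis=1)
    return masked_cov(od, mask), mask
```

In a full pipeline, `masked_percentile(phi, mask, alpha)` would replace the percentile calls on the filtered angles, with `mask` passed through from `od_statistics`. Whether this closes the runtime gap shown in the table has not been measured.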