torchstain

Macenko JAX backend support

Open · andreped opened this issue 2 years ago · 0 comments

This PR adds JAX backend support to the Macenko normalizer.

Changes:

  • Implemented Macenko with a JAX backend and added it to base
  • Added CI jobs running JAX unit tests (checking that the JAX backend yields results similar to the numpy backend)
  • Renamed CI jobs so their names better reflect their actual purpose
  • Updated the README regarding JAX backend support
  • Fixed setup.py so the package can be installed through pip install torchstain[jax]
  • Fixed an np.float32 deprecation issue in the numpy Macenko backend
  • Removed an unneeded numpy import from the TensorFlow Macenko backend
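For readers new to the method, the core of a Macenko stain estimator written with jax.numpy might look roughly like the sketch below. It is not the code from this PR: the function name macenko_stain_matrix is hypothetical, the defaults Io=240, alpha=1, beta=0.15 and the reference stain vectors are assumed from common reference implementations, and the eigenvector sign flip is an extra step added here so that numpy and JAX (whose eigenvector signs are arbitrary) produce the same angles. In the spirit of the numpy-vs-JAX consistency tests, the demo recovers a known stain matrix from a synthetic image.

```python
import jax.numpy as jnp


def macenko_stain_matrix(img, Io=240.0, alpha=1.0, beta=0.15):
    """Hypothetical sketch: estimate a 3x2 H&E stain matrix and concentrations.

    img: RGB image as an (H, W, 3) array. Runs in eager mode only: the
    boolean mask below has a data-dependent shape, so it is not jit-able.
    """
    # Optical density per pixel: OD = -log((I + 1) / Io)
    od = -jnp.log((jnp.asarray(img, dtype=jnp.float32).reshape(-1, 3) + 1.0) / Io)

    # Drop (nearly) transparent pixels: any channel below the OD threshold beta
    od_hat = od[~jnp.any(od < beta, axis=1)]

    # Plane spanned by the two largest eigenvectors of the OD covariance
    _, eigvecs = jnp.linalg.eigh(jnp.cov(od_hat.T))
    v = eigvecs[:, 1:3]  # eigh returns eigenvalues in ascending order

    # Assumed extra step: eigenvector signs are arbitrary, so flip each axis
    # to make the mean projection positive (keeps angles away from +-pi)
    t_hat = od_hat @ v
    signs = jnp.where(jnp.mean(t_hat, axis=0) < 0, -1.0, 1.0)
    v, t_hat = v * signs, t_hat * signs

    # Robust extreme angles of the projected points -> two stain directions
    phi = jnp.arctan2(t_hat[:, 1], t_hat[:, 0])
    min_phi = jnp.percentile(phi, alpha)
    max_phi = jnp.percentile(phi, 100.0 - alpha)
    v_min = v @ jnp.array([jnp.cos(min_phi), jnp.sin(min_phi)])
    v_max = v @ jnp.array([jnp.cos(max_phi), jnp.sin(max_phi)])

    # Heuristic ordering: hematoxylin first (larger red-channel OD)
    he = jnp.where(v_min[0] > v_max[0],
                   jnp.stack([v_min, v_max], axis=1),
                   jnp.stack([v_max, v_min], axis=1))

    # Per-pixel concentrations: least-squares solve of he @ C = OD^T
    conc = jnp.linalg.lstsq(he, od.T)[0]
    return he, conc


# Synthetic demo: build an image from a known stain matrix (example values)
HE_REF = jnp.array([[0.5626, 0.2159],
                    [0.7201, 0.8012],
                    [0.4062, 0.5581]])
ramp = jnp.linspace(0.8, 1.5, 100)
zeros = jnp.zeros(100)
t = jnp.linspace(0.0, 1.0, 100)
c_true = jnp.concatenate([
    jnp.stack([ramp, zeros]),                   # pure hematoxylin pixels
    jnp.stack([zeros, ramp]),                   # pure eosin pixels
    jnp.stack([0.2 + 0.8 * t, 1.0 - 0.7 * t]),  # mixed pixels
], axis=1)                                      # shape (2, 300)
img = (240.0 * jnp.exp(-(HE_REF @ c_true).T) - 1.0).reshape(10, 30, 3)

he, conc = macenko_stain_matrix(img)
print(jnp.round(he, 3))
```

Because od[mask] has a data-dependent shape, this eager version cannot be wrapped in jax.jit as written, so each jnp call runs op by op.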

Note that the JAX backend is not yet as optimized for runtime as the other backends, so JAX support should be considered experimental for now. Here is how the JAX backend compares to the other backends:

Backend       numpy   jax     torch   tf
Runtime [s]   0.455   2.427   0.201   0.442

Further optimization of the JAX implementation is left for future work. Since this is outside my area of expertise, contributions from more experienced JAX developers would be very welcome.
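The runtime table does not state how the numbers were measured. When timing JAX code, two common pitfalls are asynchronous dispatch (results must be waited on with block_until_ready) and compilation on the first call of a jitted function. The sketch below shows a timing helper that accounts for both, plus one possible optimization direction: replacing boolean filtering (which jax.jit cannot trace) with fixed-shape weights and NaN masks. All names here (masked_od_cov, masked_percentile, time_jax, BETA) are hypothetical and this is not what the PR implements.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

BETA = 0.15  # OD threshold for transparent pixels (assumed Macenko default)


@jax.jit
def masked_od_cov(od):
    """Covariance of non-transparent OD rows without boolean indexing.

    od[mask] has a data-dependent shape and cannot be traced by jax.jit;
    0/1 frequency weights give the same covariance with a fixed shape.
    """
    keep = ~jnp.any(od < BETA, axis=1)
    return jnp.cov(od.T, fweights=keep.astype(jnp.int32))


@partial(jax.jit, static_argnames="q")
def masked_percentile(x, keep, q):
    """Percentile over kept entries only: excluded entries become NaN."""
    return jnp.nanpercentile(jnp.where(keep, x, jnp.nan), q)


def time_jax(fn, *args, repeats=10):
    """Mean wall-clock seconds per call, excluding compilation.

    The untimed warm-up call triggers jit compilation; block_until_ready
    waits for asynchronously dispatched work before the clock stops.
    """
    jax.block_until_ready(fn(*args))  # warm-up / compile
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    od = jax.random.uniform(jax.random.PRNGKey(0), (512 * 512, 3), maxval=2.0)
    print(f"masked_od_cov: {time_jax(masked_od_cov, od):.6f} s per call")
```

With fixed shapes throughout, a full estimator built from steps like these could in principle be compiled as a single jitted function instead of running op by op.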

andreped · Jan 29 '23 11:01