
JAX Support

Open Anselmoo opened this issue 2 years ago • 0 comments

Feature Description

To further increase performance, I was wondering whether you would consider optional JAX support, implemented by rewriting the NumPy imports.

Code Example


try:
    import jax.numpy as np
except ModuleNotFoundError:
    import numpy as np

pip install autokeras[jax]
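As a variation on the snippet above, the backend choice could live in one small module so the try/except is not repeated in every file. This is only a sketch: `pick_backend_name` and `load_backend` are hypothetical names, not part of autokeras, and it assumes that JAX should be used whenever the `jax` package can be found.

```python
import importlib
import importlib.util


def pick_backend_name():
    """Return the module name of the array backend to use."""
    # find_spec only looks the package up; it does not import it.
    if importlib.util.find_spec("jax") is not None:
        return "jax.numpy"
    return "numpy"


def load_backend():
    """Import and return the selected array module (hypothetical helper)."""
    return importlib.import_module(pick_backend_name())
```

Other modules would then call `np = load_backend()` once instead of importing NumPy directly.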

Reason

This would be especially useful when running on a GPU, or on a TPU via Google Colab, where it could further speed up the processing of heavy data sets (> 10 GB), for example in autokeras/preprocessors/encoders.py.

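If NumPy arrays are assigned to through slices or have object dtype, these code sections have to be rewritten, because JAX arrays are immutable and JAX does not support the object dtype.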
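To illustrate the kind of rewrite a slice assignment would need, here is a minimal sketch. `set_slice` is a hypothetical helper, not an autokeras function; it relies on JAX arrays exposing the functional `.at[...].set(...)` update, while plain NumPy arrays have no `at` attribute.

```python
try:
    import jax.numpy as np
except ModuleNotFoundError:
    import numpy as np


def set_slice(arr, index, value):
    """Return a new array equal to `arr` with `arr[index]` set to `value`."""
    if hasattr(arr, "at"):
        # JAX arrays are immutable, so `arr[index] = value` would raise;
        # the functional update returns a modified copy instead.
        return arr.at[index].set(value)
    # Plain NumPy path: copy first so both backends leave the input untouched.
    out = arr.copy()
    out[index] = value
    return out


x = np.zeros(5)
y = set_slice(x, slice(1, 3), 1.0)
```

Returning a copy on both paths keeps the behaviour the same regardless of which backend was imported, at the cost of an extra copy under plain NumPy.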

Solution

  1. Update the header of every file that imports NumPy
  2. Check whether any array functions have to be rewritten
  3. Extend setup.py with an optional extra:
extras_require = {'jax': ['jax>=0.3.1,<0.4.0', 'jaxlib>=0.3.0,<0.4.0']}

Anselmoo, Mar 13 '22 18:03