autokeras
JAX Support
Feature Description
To further increase performance, would you consider optional JAX support, implemented by swapping the NumPy import for jax.numpy when JAX is installed?
Code Example
try:
    import jax.numpy as np
except ModuleNotFoundError:
    import numpy as np
pip install "autokeras[jax]"
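A minimal sketch of how code could stay backend-agnostic under the proposed import. jax.numpy is not a complete drop-in replacement for NumPy, so code written this way should stick to functions that exist in both and avoid mutating arrays. The `one_hot` helper below is hypothetical and only illustrates the pattern; it is not part of AutoKeras.

```python
# Proposed optional backend: jax.numpy if JAX is installed, otherwise NumPy.
try:
    import jax.numpy as np
except ModuleNotFoundError:
    import numpy as np


def one_hot(indices, num_classes):
    """Hypothetical helper: one-hot encode integer class indices.

    Uses only functions available in both numpy and jax.numpy and never
    modifies an array in place, so it behaves the same with either backend.
    """
    indices = np.asarray(indices)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes)[indices]


# Reports which backend was picked up: "jax.numpy" or "numpy".
print(np.__name__)
```

Keeping helpers free of in-place updates is what lets the single import line switch backends without further changes.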
Reason
This would particularly speed up heavy datasets (> 10 GB) when running on a GPU, or on a TPU in Google Colab, for example in autokeras/preprocessors/encoders.py. Code sections where NumPy arrays are modified through slice assignment or use the object dtype would have to be rewritten, because JAX arrays are immutable and jax.numpy has no object dtype.
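A sketch of how those two cases might be rewritten, assuming only NumPy or JAX arrays are passed in. JAX replaces in-place assignment `x[idx] = value` with the functional `x.at[idx].set(value)`, which returns a new array; object-dtype columns can be encoded to integers with NumPy before anything reaches JAX. Both helper names are hypothetical, not existing AutoKeras functions.

```python
import numpy as onp


def set_rows(x, idx, value):
    """Hypothetical helper: return a copy of x with x[idx] replaced by value.

    JAX arrays are immutable and expose x.at[idx].set(value). NumPy arrays
    have no .at attribute, so they are copied and assigned instead. Assumes x
    is a NumPy or JAX array (other types with .at, e.g. pandas, are not handled).
    """
    if hasattr(x, "at"):  # JAX array: functional update, returns a new array
        return x.at[idx].set(value)
    out = x.copy()  # NumPy: copy first so the caller's array is left untouched
    out[idx] = value
    return out


def encode_categories(column):
    """Hypothetical helper: map an object-dtype column to integer codes.

    Runs in NumPy because jax.numpy has no object dtype; the integer codes
    can then be converted to a JAX array.
    """
    vocab, codes = onp.unique(onp.asarray(column, dtype=object), return_inverse=True)
    return vocab, codes
```

Returning a new array in both branches keeps the calling code identical regardless of backend, at the cost of one copy on the NumPy path.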
Solution
- Update the NumPy import at the top of every module
- Check whether any array functions have to be rewritten
- Extend setup.py with an optional extra:
  extras_require = {'jax': ['jax>=0.3.1,<0.4.0', 'jaxlib>=0.3.0,<0.4.0']}
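Checking which array code has to be rewritten could start with a rough automated pass. The sketch below is a hypothetical heuristic, not an existing tool: it uses Python's `ast` module to flag subscript assignments (which also catches harmless dict assignments) and `dtype=object` / `dtype="O"` arguments, so each hit still needs manual review.

```python
import ast


def find_jax_blockers(source: str) -> list[tuple[int, str]]:
    """Hypothetical heuristic: return (line, reason) pairs that may not port to JAX."""
    findings = []
    for node in ast.walk(ast.parse(source)):
        # x[i] = v and x[i] += v mutate in place; JAX arrays are immutable.
        if isinstance(node, (ast.Assign, ast.AugAssign)):
            targets = node.targets if isinstance(node, ast.Assign) else [node.target]
            if any(isinstance(t, ast.Subscript) for t in targets):
                findings.append((node.lineno, "in-place item/slice assignment"))
        # dtype=object or dtype="O" has no jax.numpy equivalent.
        if isinstance(node, ast.keyword) and node.arg == "dtype":
            value = node.value
            if (isinstance(value, ast.Name) and value.id == "object") or (
                isinstance(value, ast.Constant) and value.value == "O"
            ):
                findings.append((value.lineno, "object dtype"))
    # ast.walk is breadth-first, so sort the hits by line number.
    return sorted(findings)
```

Running it over a module's source text gives a starting list of locations to inspect, for example in autokeras/preprocessors/encoders.py.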