keras-tuner
Implement Bayesian optimization with TF or JAX instead of using sklearn
Is your feature request related to a problem? Please describe.
Bayesian optimization may suffer from performance issues because of how the Einstein summation (einsum) is evaluated here:
https://github.com/keras-team/keras-tuner/blob/a7a361f9521cb1033a05aba865c86eb30784d907/keras_tuner/tuners/bayesian.py#L124-L127
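For context, here is a minimal sketch of the kind of einsum a Gaussian-process surrogate typically evaluates: the diagonal of the posterior covariance. This is not the code at the linked lines. The helpers rbf_kernel and predictive_variance are hypothetical, and the RBF kernel with unit length scale is an assumption. The einsum computes only the n_test diagonal entries instead of materializing the full n_test x n_test matrix.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0):
    """Hypothetical RBF kernel between the rows of a and the rows of b."""
    # Squared Euclidean distances, shape (len(a), len(b)).
    sq_dist = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.exp(-0.5 * sq_dist / length_scale**2)


def predictive_variance(x_train, x_test, noise=1e-6):
    """Hypothetical GP posterior variance at x_test (assumes k(x, x) = 1)."""
    k = rbf_kernel(x_train, x_train) + noise * np.eye(len(x_train))
    k_star = rbf_kernel(x_train, x_test)  # shape (n_train, n_test)
    chol = np.linalg.cholesky(k)          # k = chol @ chol.T
    v = np.linalg.solve(chol, k_star)     # v = chol^{-1} @ k_star
    # diag(k_star.T @ k^{-1} @ k_star) equals the column-wise sum of v * v,
    # which einsum computes without building the full (n_test, n_test) matrix.
    reduction = np.einsum("ij,ij->j", v, v)
    prior_var = np.ones(x_test.shape[0])  # RBF kernel gives k(x, x) = 1
    return prior_var - reduction
```

Because the reduction is a single einsum call, swapping in a different einsum implementation (for example jax.numpy.einsum) only touches that one line.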
Describe the solution you'd like
Performing this computation with jax.numpy.einsum could help.
Partial support could look like this:
try:
    from jax.numpy import einsum
except ImportError:
    from numpy import einsum
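The import fallback above can be exercised as follows. This is a minimal sketch: the names BACKEND and row_dot are hypothetical, and np.asarray is used so downstream NumPy code always receives a NumPy array regardless of which backend was imported.

```python
import numpy as np

try:
    # Prefer JAX's einsum when JAX is installed.
    from jax.numpy import einsum
    BACKEND = "jax"
except ImportError:
    # Fall back to NumPy's einsum when JAX is not available.
    from numpy import einsum
    BACKEND = "numpy"


def row_dot(a, b):
    """Hypothetical helper: row-wise dot product, out[i] = sum_j a[i, j] * b[i, j]."""
    # np.asarray turns a JAX array back into a NumPy array for the caller.
    return np.asarray(einsum("ij,ij->i", a, b))
```

Note that JAX defaults to 32-bit floats unless the jax_enable_x64 flag is enabled, so results can differ slightly from NumPy's float64 output.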
Or, platform-dependent:
import sys

if sys.platform != "win32":
    import jax.numpy as np
else:
    import numpy as np
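A minimal sketch of the platform-dependent variant, assuming JAX may be unavailable on Windows. It also falls back to NumPy when JAX is not installed, and it uses a separate alias xnp (hypothetical) so that plain NumPy stays available. One catch when using jax.numpy as a drop-in NumPy replacement is that JAX arrays are immutable, so in-place item assignment such as a[i] = v raises an error. The hypothetical helper set_item shows one way to cover both backends.

```python
import sys

import numpy as onp  # plain NumPy, always available

if sys.platform != "win32":
    try:
        import jax.numpy as xnp
    except ImportError:
        import numpy as xnp
else:
    import numpy as xnp


def set_item(arr, index, value):
    """Hypothetical helper: return a copy of arr with arr[index] = value."""
    if hasattr(arr, "at"):
        # JAX arrays are immutable; .at[...].set(...) returns an updated copy.
        return arr.at[index].set(value)
    # NumPy arrays: copy first so both backends behave functionally.
    arr = arr.copy()
    arr[index] = value
    return arr
```

Returning a new array in both branches keeps the calling code identical whichever backend was imported.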