Replace jax.ops.index_update with the .at property (x.at[idx].set(y))
jax.ops.index_update has been removed from JAX (see https://github.com/google/jax/commit/f51a05a889f2fcb19946352b9d65f2b6c49fec4a), which breaks all of the bandit code. Please use the .at indexed-update property instead, e.g. x.at[idx].set(y).
See the following lines:
https://github.com/probml/bandits/blob/4c686514da53dd0272e143b87c04828f16a6dfe7/bandits/agents/linear_kf_bandit.py#L2
https://github.com/probml/bandits/blob/181e60d45916622d4c15b2ebf86464f767c19780/bandits/agents/linear_bandit.py#L4
https://github.com/probml/bandits/blob/737ce05db48385df90be7a1f9a62f16e47459c4d/bandits/agents/limited_memory_neural_linear.py#L6
https://github.com/probml/bandits/blob/737ce05db48385df90be7a1f9a62f16e47459c4d/bandits/agents/neural_linear.py#L4
https://github.com/probml/bandits/blob/181e60d45916622d4c15b2ebf86464f767c19780/bandits/agents/linear_bandit_wide.py#L13
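A minimal sketch of the migration, assuming a recent JAX release. The helper set_row and the example arrays are hypothetical and not taken from the bandit code; they only show how the removed call maps onto the .at API. Because JAX arrays are immutable, .at[...] returns a new array rather than modifying the original, which matches the functional behavior of the old jax.ops.index_update.

```python
import jax.numpy as jnp


def set_row(matrix, row, values):
    """Hypothetical helper standing in for the removed call
    jax.ops.index_update(matrix, jax.ops.index[row], values).

    .at[row].set(values) returns a new array; `matrix` is left unchanged.
    """
    return matrix.at[row].set(values)


# Hypothetical example state: a small 3x2 matrix of zeros.
A = jnp.zeros((3, 2))
A_new = set_row(A, 1, jnp.array([5.0, 6.0]))

# The same .at API also covers in-place-style accumulation,
# e.g. .at[idx].add(y) as the replacement for jax.ops.index_add.
counts = jnp.zeros(4)
counts = counts.at[2].add(1.0)
```

The .at form also works inside jax.jit-compiled functions, where XLA can typically perform the update in place, so it is a drop-in replacement in jitted agent code.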
I ran into this issue and resolved it as described, but I'm having a really difficult time getting the code to run because I can't get any of the Kalman-related classes (KalmanFilterNoiseEstimation and DiagonalExtendedKalmanFilter) to import from JSL. It's as if they've been removed. I've been trying to reverse-engineer it, but there's a lot going on in the code. I don't want to open an issue because I'm not sure if it's just me. Any advice would be much appreciated.
We plan to reimplement the bandit code on top of our new rebayes library in the next few weeks. Please check back later.
Thank you, Kevin! On a related note, I'm very grateful for all the work you do/have done. Best, David
Closed in https://github.com/probml/bandits/commit/3ade11e128e14e13284082b93d863d0cb398ec4b