get-started-with-JAX
Replace jax.tree_util.tree_multimap() with jax.tree_util.tree_map()
jax.tree_util.tree_multimap() has been removed in JAX 0.3.16.
jax.tree_util.tree_map() is a direct replacement.
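jax.tree_map() also works as a replacement for jax.tree_util.tree_multimap(), although jax.tree_map() has itself been deprecated in later JAX releases, so jax.tree_util.tree_map() is the safer choice.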
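
A minimal sketch of the migration, assuming a JAX release in which tree_multimap() is already gone (0.3.16 or later). The names tree_a, tree_b, and summed and their values are illustrative and not taken from any particular codebase.

```python
import jax

# Two pytrees with the same structure (illustrative values).
tree_a = {"a": 1, "b": (2, 3)}
tree_b = {"a": 10, "b": (20, 30)}

# Old call, which no longer exists as of JAX 0.3.16:
#   jax.tree_util.tree_multimap(lambda x, y: x + y, tree_a, tree_b)

# Replacement: tree_map() takes the extra trees as additional
# positional arguments and applies the function leaf by leaf.
summed = jax.tree_util.tree_map(lambda x, y: x + y, tree_a, tree_b)
print(summed)
```

The call signature is unchanged apart from the function name, so in most code the migration should be a plain rename.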