Replace calls to jax.tree_[func] with jax.tree_util.tree_[func]
As of JAX 0.6.0 (April 16, 2025), several tree functions have been removed from jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map, jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and jax.tree_unflatten. Replacements can be found in jax.tree or jax.tree_util.
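This PR replaces every call to these functions with a call to its jax.tree_util.tree_[func] equivalent. Because the jax.tree_util versions predate 0.6.0 and were not removed, the change should work on JAX releases both before and after 0.6.0.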
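For anyone applying the same rename to their own code, here is a minimal sketch of how the mechanical part could be done with a regular expression. It is not necessarily how this PR was produced. The helper name migrate_tree_calls is hypothetical. The sketch covers only the six jax.tree_* names from the list above and leaves jax.treedef_is_leaf untouched, because that name does not follow the tree_[func] pattern. It also does not handle imports such as `from jax import tree_map`.

```python
import re

# The six removed top-level names that follow the jax.tree_[func] pattern.
# jax.treedef_is_leaf is intentionally NOT handled by this sketch.
REMOVED_TREE_FUNCS = (
    "tree_flatten",
    "tree_map",
    "tree_leaves",
    "tree_structure",
    "tree_transpose",
    "tree_unflatten",
)

# Match "jax.<name>" only when "jax" is not the tail of a longer identifier
# or attribute chain (e.g. "myjax" or "foo.jax"), and "<name>" is a whole word.
# Already-migrated "jax.tree_util.tree_map" is not matched, so the rewrite
# is idempotent.
_PATTERN = re.compile(
    r"(?<![\w.])jax\.(" + "|".join(REMOVED_TREE_FUNCS) + r")\b"
)


def migrate_tree_calls(source: str) -> str:
    """Hypothetical helper: rewrite jax.tree_[func] to jax.tree_util.tree_[func]."""
    return _PATTERN.sub(r"jax.tree_util.\1", source)


if __name__ == "__main__":
    before = "leaves = jax.tree_leaves(params)\nout = jax.tree_map(f, x)\n"
    print(migrate_tree_calls(before))
```

Because it is a plain text substitution, the result should still be reviewed (and tests run) before committing.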
I've used this branch to fix the jax issues I was having after the latest update. It's working well for me. Thanks! Hopefully @sokrypton can review it soon!
Looks like this was fixed with d024c4e846fea83c090afcbe89a313eeee8ec01e