ColabDesign

Replace calls to jax.tree_[func] with jax.tree_util.tree_[func]

Open • noahharrison64 opened this issue 8 months ago • 1 comment

As of JAX 0.6.0 (April 16, 2025), several tree functions have been removed:

From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map, jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and jax.tree_unflatten. Replacements can be found in jax.tree or jax.tree_util.
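For illustration, here is a minimal sketch of the jax.tree_util replacements for the most commonly used removed calls, applied to a hypothetical params dictionary (the dictionary and its keys are made up for this example and are not taken from ColabDesign). Each comment shows the removed call that the line replaces.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree of parameters, for illustration only.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Removed in JAX 0.6.0: jax.tree_map(lambda x: x * 2, params)
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)

# Removed: jax.tree_leaves(params)
# Dict keys are visited in sorted order, so the leaves are [b, w].
leaves = jax.tree_util.tree_leaves(params)

# Removed: jax.tree_flatten(params) / jax.tree_unflatten(treedef, flat)
flat, treedef = jax.tree_util.tree_flatten(params)
rebuilt = jax.tree_util.tree_unflatten(treedef, flat)

# Removed: jax.tree_structure(params)
structure = jax.tree_util.tree_structure(params)
```

Newer JAX releases also expose the same operations under the jax.tree namespace (for example jax.tree.map), but jax.tree_util is the choice that also works on older releases.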

This PR replaces all calls to these functions with their jax.tree_util equivalents (jax.tree_util.tree_[func]), which should be compatible with both older and newer versions of JAX.
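A change like this can be applied mechanically. The sketch below is a hypothetical helper (migrate_tree_calls is not part of JAX or ColabDesign) that rewrites the removed top-level names to their jax.tree_util equivalents with a regular expression. It assumes the code calls them as jax.<name>; aliased imports such as `import jax as j` or `from jax import tree_map` are not handled and would need to be fixed by hand.

```python
import re

# Names removed from the top-level jax namespace in JAX 0.6.0.
# Each one exists under the same name in jax.tree_util.
REMOVED = (
    "treedef_is_leaf",
    "tree_flatten",
    "tree_map",
    "tree_leaves",
    "tree_structure",
    "tree_transpose",
    "tree_unflatten",
)

# Match "jax.<name>" only when "jax" is not part of a longer identifier
# or attribute chain (lookbehind), and only when <name> is a whole
# identifier (\b), so e.g. jax.tree_map_with_path is left alone.
_PATTERN = re.compile(r"(?<![\w.])jax\.(" + "|".join(REMOVED) + r")\b")


def migrate_tree_calls(source: str) -> str:
    """Rewrite removed jax.<tree func> calls to jax.tree_util.<tree func>."""
    return _PATTERN.sub(r"jax.tree_util.\1", source)


if __name__ == "__main__":
    old = "params = jax.tree_map(lambda x: x * 0.5, params)\n"
    print(migrate_tree_calls(old), end="")
```

Running it twice is harmless: already-migrated calls such as jax.tree_util.tree_map do not match the pattern.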

noahharrison64 • Apr 24 '25

I've used this branch to fix the JAX issues I was having after the latest JAX update, and it's working well for me. Thanks! Hopefully @sokrypton can review it soon!

BrandonFrenz • Apr 30 '25

Looks like this was fixed in commit d024c4e846fea83c090afcbe89a313eeee8ec01e.

noahharrison64 • Aug 04 '25