[JAX API UPDATE] Update `utils.py`, as `_registry_with_keypaths` was removed

Open • Steboss opened this issue 8 months ago • 0 comments

JAX removed the private attribute `jax._src.tree_util._registry_with_keypaths`, so you may run into this error:

```
File "/opt/axlearn/axlearn/common/utils.py", line 1870, in pytree_children
    registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
AttributeError: module 'jax._src.tree_util' has no attribute '_registry_with_keypaths'
```

This PR fixes that error. The function's behavior is unchanged: instead of reading `_registry_with_keypaths`, it now obtains the key children via `flatten_one_level_with_keys` later in the function. All tests pass, performance is on par with the previous implementation, and the outputs match those of the previous version. @apghml, could you please take a look at this new approach? Thank you.
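To illustrate the general idea in isolation, here is a minimal sketch that lists the immediate `(key, child)` pairs of a pytree node without touching any private JAX registry. It is not the code in this PR (which relies on `flatten_one_level_with_keys`); instead it swaps in the public `jax.tree_util.tree_flatten_with_path` with an `is_leaf` predicate that stops traversal one level below the root. `pytree_children_sketch` is a hypothetical name, and the sketch assumes a JAX version that provides `tree_flatten_with_path`.

```python
from typing import Any, Sequence

import jax


def pytree_children_sketch(node: Any) -> Sequence[tuple[Any, Any]]:
    """Returns (key_entry, child) pairs for the immediate children of `node`.

    Hypothetical helper (not the PR's implementation): uses only the public
    tree_flatten_with_path API instead of jax._src.tree_util internals.
    """
    # Treat every subtree except the root itself as a leaf, so flattening
    # descends exactly one level below `node`.
    path_leaves, _ = jax.tree_util.tree_flatten_with_path(
        node, is_leaf=lambda x: x is not node
    )
    # A leaf root flattens to a single entry with an empty key path;
    # leaves have no children.
    if len(path_leaves) == 1 and path_leaves[0][0] == ():
        return []
    # One level deep, every key path holds exactly one key entry.
    return [(path[0], child) for path, child in path_leaves]


print(pytree_children_sketch({"b": [1, 2], "a": 3}))
```

Keys come back as JAX key entries (`DictKey`, `SequenceKey`, `GetAttrKey`, ...), and dict children are ordered by sorted key, as in any JAX flattening. Nested containers such as the list under `"b"` are returned whole rather than flattened further.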

Steboss, May 06 '25 13:05