axlearn
[JAX API UPDATE] Update `utils.py`, as `_registry_with_keypaths` was removed
JAX removed the private attribute `_registry_with_keypaths`, so you may run into this error:
```
File "/opt/axlearn/axlearn/common/utils.py", line 1870, in pytree_children
  registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
AttributeError: module 'jax._src.tree_util' has no attribute '_registry_with_keypaths'
```
This PR fixes the error. `pytree_children` behaves exactly as before. The only difference is that the key children are now obtained with `flatten_one_level_with_keys` at a later stage instead of through `_registry_with_keypaths`.
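For anyone hitting the same `AttributeError` in their own code, here is a minimal sketch of one way to collect the `(key, child)` pairs one level down without touching JAX internals. It uses only the public `jax.tree_util.tree_flatten_with_path`. This is not the code in this PR, which uses `flatten_one_level_with_keys`. The helper name `pytree_children_sketch` is hypothetical. The trick of treating every node except the root as a leaf through `is_leaf` is an assumption of this sketch.

```python
from typing import Any

import jax


def pytree_children_sketch(node: Any) -> list[tuple[Any, Any]]:
    """Returns (key, child) pairs for the immediate children of `node`.

    Hypothetical helper, not part of axlearn: it flattens exactly one level
    by telling JAX that every node other than the root is a leaf.
    """
    # Only the root is expanded; each of its children is reported as a leaf.
    flat, _ = jax.tree_util.tree_flatten_with_path(
        node, is_leaf=lambda x: x is not node
    )
    # A leaf root comes back with an empty key path, i.e. it has no children.
    return [(path[0], child) for path, child in flat if path]


children = pytree_children_sketch({"a": 1, "b": [2, 3]})
```

Because the returned keys are JAX key entries (e.g. `DictKey`, `SequenceKey`), they can be used with the rest of the key-path utilities in `jax.tree_util`.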
All tests pass, performance is unchanged, and the outputs of the previous and new versions match.
@apghml, could you please take a look at this new approach? Thank you!