axlearn
Migrate from legacy JAX API jax.tree_util to jax.tree
Description
This PR migrates the axlearn codebase from the legacy JAX API (jax.tree_util) to the recommended jax.tree module.
The jax.tree API was introduced in JAX v0.4.25 and is now the preferred approach over jax.tree_util. Upgrading to jax.tree ensures better compatibility with future JAX versions and improves code maintainability.
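For illustration, the mechanical part of such a migration can be sketched as a text rewrite. This is a minimal sketch, not the tooling used for this PR. The mapping covers only a subset of jax.tree_util functions, and the helper name migrate_source is hypothetical. Matching the fully qualified jax.tree_util. prefix, rather than only the function name, keeps the rewrite from producing doubled prefixes such as jax.jax.tree.

```python
import re

# Hypothetical, partial mapping from legacy jax.tree_util names to their
# jax.tree equivalents (jax.tree.map, jax.tree.leaves, ...).
TREE_UTIL_TO_TREE = {
    "tree_map": "map",
    "tree_leaves": "leaves",
    "tree_structure": "structure",
    "tree_flatten": "flatten",
    "tree_unflatten": "unflatten",
    "tree_reduce": "reduce",
    "tree_all": "all",
    "tree_transpose": "transpose",
}

# Match the whole "jax.tree_util.<name>" token. The trailing \b keeps names
# such as tree_flatten_with_path from being partially rewritten.
_PATTERN = re.compile(
    r"\bjax\.tree_util\.(" + "|".join(TREE_UTIL_TO_TREE) + r")\b"
)


def migrate_source(source: str) -> str:
    """Rewrite jax.tree_util.tree_X references to jax.tree.X."""
    return _PATTERN.sub(
        lambda m: "jax.tree." + TREE_UTIL_TO_TREE[m.group(1)], source
    )


if __name__ == "__main__":
    print(migrate_source("y = jax.tree_util.tree_map(lambda a: a + 1, x)"))
```

Code that imports the submodule directly (for example, from jax import tree_util) is not matched by this pattern and would need separate handling.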
pre-commit
$ pre-commit run -a
Check Yaml...............................................................Passed
Fix End of Files.........................................................Passed
Trim Trailing Whitespace.................................................Passed
black....................................................................Passed
isort....................................................................Passed
pylint...................................................................Passed
pytype
$ pytype -j auto axlearn
...
Success: no errors found
pytest
$ pytest -v -n 96 -m "not (gs_login or tpu or high_cpu or fp64)" axlearn/common
========== 0 failed, 6220 passed, 10364 skipped in 734.23s (0:12:14) ==========
I wonder how the CI passed with those typos. Do they fail locally for you?
Apparently, jax.jax.jax.tree.leaves works fine:
import jax
import jax.numpy as jnp
x = (jnp.array(0), jnp.array(1))
# The redundant "jax.jax." prefixes still resolve to the jax module (see the printed modules below).
y = jax.jax.jax.tree.leaves(x)
print(y)
print(type(y))
print(jax.jax)
print(jax.jax.jax)
[Array(0, dtype=int32, weak_type=True), Array(1, dtype=int32, weak_type=True)]
<class 'list'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>
https://colab.research.google.com/drive/1ruOWXG6GXFSh1xdHyBVJVLyRwZTYq6GQ?usp=sharing
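As the output shows, jax.jax resolves to the jax module itself, so the doubled prefixes run without error, which would explain why CI passed. A lint-style check could catch them instead. This is a minimal sketch assuming a plain regex scan over source text is acceptable; the helper name find_repeated_jax_prefixes is hypothetical and not part of axlearn or its pre-commit configuration.

```python
import re

# Matches "jax." repeated two or more times, e.g. "jax.jax.tree.leaves".
_REPEATED_JAX = re.compile(r"\b(?:jax\.){2,}")


def find_repeated_jax_prefixes(source: str) -> list[tuple[int, str]]:
    """Return (line_number, line) pairs that contain a repeated "jax." prefix."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if _REPEATED_JAX.search(line):
            hits.append((lineno, line))
    return hits


if __name__ == "__main__":
    sample = "import jax\ny = jax.jax.jax.tree.leaves(x)\nz = jax.tree.leaves(x)\n"
    for lineno, line in find_repeated_jax_prefixes(sample):
        print(f"{lineno}: {line}")
```

Hooked into pre-commit or CI, a check like this would fail on the typo rather than relying on runtime behavior.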
Thanks @apivovarov, though this kind of cleanup would be better handled if you propose it and let us make the fix. We typically need to run many internal validations before merging a PR. The hairy part comes from our internal repo, which uses AxLearn as its core library.
We will treat this cleanup PR as low priority, as we do the other cleanups, since they are not blocking anything at the moment. It would be great if AWS could focus on prioritizing Trainium2 fixes.
Hi @apivovarov do you intend to move forward with this PR? Thanks.