
Migrate from Legacy JAX APIs jax.tree_util to jax.tree

Open apivovarov opened this issue 9 months ago • 3 comments

Description

This PR migrates the axlearn codebase from Legacy JAX APIs (jax.tree_util) to the recommended jax.tree module.

The jax.tree API was introduced in JAX v0.4.25 and is now the preferred approach over jax.tree_util. Upgrading to jax.tree ensures better compatibility with future JAX versions and improves code maintainability.

jax.tree doc

jax.tree_util doc
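For reference, the change is largely mechanical: structure-manipulation calls move from `jax.tree_util.tree_*` to the shorter `jax.tree.*` names. A minimal before/after sketch (using `tree_leaves` and `tree_map` as representative call sites; the exact call sites in axlearn vary):

```python
import jax

params = {"w": 1.0, "b": 2.0}

# Legacy API (still available, but no longer the preferred spelling):
old_leaves = jax.tree_util.tree_leaves(params)
old_mapped = jax.tree_util.tree_map(lambda x: x * 2, params)

# Preferred API since JAX v0.4.25:
new_leaves = jax.tree.leaves(params)
new_mapped = jax.tree.map(lambda x: x * 2, params)

# Both spellings dispatch to the same implementation.
assert old_leaves == new_leaves
assert old_mapped == new_mapped
```

Note that registration helpers such as `jax.tree_util.register_pytree_node` have no `jax.tree` equivalent and stay in `jax.tree_util`.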

pre-commit

$ pre-commit run -a      
Check Yaml...............................................................Passed
Fix End of Files.........................................................Passed
Trim Trailing Whitespace.................................................Passed
black....................................................................Passed
isort....................................................................Passed
pylint...................................................................Passed

pytype

$ pytype -j auto axlearn
...
Success: no errors found

pytest

pytest -v -n 96 -m "not (gs_login or tpu or high_cpu or fp64)" axlearn/common

========== 0 failed, 6220 passed, 10364 skipped in 734.23s (0:12:14) ==========

apivovarov avatar Feb 12 '25 22:02 apivovarov

I wonder how the CI passed with those typos. Do they fail locally for you?

markblee avatar Feb 14 '25 00:02 markblee

Apparently, jax.jax.jax.tree.leaves works fine:

import jax
import jax.numpy as jnp

x = (jnp.array(0), jnp.array(1))
y = jax.jax.jax.tree.leaves(x)

print(y)
print(type(y))
print(jax.jax)
print(jax.jax.jax)

Output:

[Array(0, dtype=int32, weak_type=True), Array(1, dtype=int32, weak_type=True)]
<class 'list'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>

https://colab.research.google.com/drive/1ruOWXG6GXFSh1xdHyBVJVLyRwZTYq6GQ?usp=sharing
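This behavior is most likely a general Python import quirk rather than anything JAX-specific: a plain `import pkg.sub` statement executed inside a package's own `__init__.py` binds the top-level name `pkg` in the package's namespace, so the package becomes an attribute of itself and the lookup can be repeated indefinitely. A minimal sketch with a hypothetical package `pkg` built in memory:

```python
import sys
import types

# Build a fake package and submodule without touching the filesystem.
pkg = types.ModuleType("pkg")
sub = types.ModuleType("pkg.sub")
sys.modules["pkg"] = pkg
sys.modules["pkg.sub"] = sub

# Simulate pkg/__init__.py executing `import pkg.sub`: the import
# statement binds the *top-level* name "pkg" in the executing namespace,
# which here is the package's own namespace.
exec("import pkg.sub", pkg.__dict__)

# The package is now an attribute of itself, so chained lookups work.
assert pkg.pkg is pkg
assert pkg.pkg.pkg.pkg is pkg
```

Under that assumption, `jax.jax` resolves to the `jax` module itself, so typos like `jax.jax.jax.tree.leaves` silently resolve to `jax.tree.leaves` and pass CI.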

apivovarov avatar Feb 14 '25 00:02 apivovarov

Thanks @apivovarov, though this kind of cleanup would be better handled if you propose it and let us fix it. We typically need to run a lot of internal validation before merging a PR. The hairy part is our internal repo, which uses AxLearn as its core library.

We will treat this cleanup PR as low priority, as with the other cleanups, since they are not blocking anything at the moment. It would be great if AWS could focus more on prioritizing Trainium2 fixes.

kelvin-zou avatar Feb 14 '25 02:02 kelvin-zou

Hi @apivovarov do you intend to move forward with this PR? Thanks.

changlan avatar Jul 26 '25 00:07 changlan