Support for JAX versions > 0.5.4 (current AXLearn pinned to 0.4.38)
AXLearn currently supports JAX 0.4.38, while JAX has now progressed to 0.6.0. When attempting to use AXLearn with newer JAX releases (≥ 0.5.4), we see several compatibility issues:
- jax.core.Primitive has moved to jax.extend.core.Primitive.
- jax.tree_map is now jax.tree.map. (These first two points have been addressed in PRs #1082 and #1106.)
- In axlearn/common/utils.py, the call to jax._src.tree_util references _registry_with_keypaths, which no longer exists. I've proposed a fix in my fork here.
- The config option jax_spmd_mode has been removed in JAX (see here). This causes the following error:
File "/opt/axlearn/axlearn/common/launch.py", line 115, in setup
setup_spmd(
File "/opt/axlearn/axlearn/common/utils_spmd.py", line 47, in setup
jax.config.update("jax_spmd_mode", "allow_all")
File "/opt/jaxlibs/jax/jax/_src/config.py", line 111, in update
raise AttributeError(f"Unrecognized config option: {name}")
AttributeError: Unrecognized config option: jax_spmd_mode
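One possible way to tolerate the removed option is to guard the config update, sketched below. This is only a sketch under assumptions: the helper name update_config_if_supported is hypothetical, and it assumes that skipping jax_spmd_mode on newer JAX leaves the behaviour AXLearn needs, which would have to be verified. It relies only on what the traceback shows, namely that config.update raises AttributeError for an unrecognized option.

```python
from typing import Any


def update_config_if_supported(config: Any, name: str, value: Any) -> bool:
    """Set a config option; return False if this JAX version no longer has it.

    `config` is expected to be `jax.config`. Per the traceback above, its
    `update` method raises AttributeError for an unrecognized option name.
    """
    try:
        config.update(name, value)
    except AttributeError:
        # The option was removed (e.g. jax_spmd_mode in newer JAX); skip it.
        return False
    return True


# Intended call site in setup (hypothetical placement, not the merged fix):
#   import jax
#   update_config_if_supported(jax.config, "jax_spmd_mode", "allow_all")
```

The helper takes the config object as a parameter so the same code path works on both old and new JAX without a version check.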
I can contribute patches to address these errors. However, additional optimizations or performance checks might be required before merging. Let me know how the NVIDIA team can collaborate with you on this. Thank you.
It looks like there's a new error we should be aware of:
File "/opt/axlearn/axlearn/common/config.py", line 282, in validate_config_field_value
raise InvalidConfigValueError(
axlearn.common.config.InvalidConfigValueError: Invalid config value type <class 'jax._src.partition_spec.PartitionSpecImpl'> for value "PartitionSpec(('data', 'expert', 'fsdp', 'seq'),)". Consider registering a custom validator with `register_validator`.
This happens during validation in config.py and is caused by this new JAX commit, which changed PartitionSpec to PartitionSpecImpl.
One solution is to add a dedicated validator in axlearn/common/config.py:
from jax.sharding import PartitionSpec
...
# Validate a PartitionSpec by validating its entries as a plain tuple.
register_validator(
    match_fn=lambda v: isinstance(v, PartitionSpec),
    validate_fn=lambda v: validate_config_field_value(tuple(v)),
)