Support for JAX versions > 0.5.4 (current AXLearn pinned to 0.4.38)

Open Steboss opened this issue 8 months ago • 2 comments

AXLearn currently supports JAX 0.4.38, while JAX has now progressed to 0.6.0. When attempting to use AXLearn with newer JAX releases (≥ 0.5.4), we see several compatibility issues:

  1. jax.core.Primitive has moved to jax.extend.core.Primitive.
  2. jax.tree_map is now jax.tree.map. (These first two points have been addressed in PRs #1082 and #1106.)
  3. In axlearn/common/utils.py, the call into jax._src.tree_util references _registry_with_keypaths, which no longer exists. I've proposed a fix in my fork here.
  4. The config option jax_spmd_mode has been removed from JAX (see here). This causes the following error:
  File "/opt/axlearn/axlearn/common/launch.py", line 115, in setup
    setup_spmd(
  File "/opt/axlearn/axlearn/common/utils_spmd.py", line 47, in setup
    jax.config.update("jax_spmd_mode", "allow_all")
  File "/opt/jaxlibs/jax/jax/_src/config.py", line 111, in update
    raise AttributeError(f"Unrecognized config option: {name}")
AttributeError: Unrecognized config option: jax_spmd_mode
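
As a rough illustration of how points 1, 2 and 4 could be tolerated across JAX versions, here is a minimal sketch. The helper names resolve_first and update_config_if_supported are hypothetical and are not part of AXLearn or JAX. The sketch assumes only what the traceback above shows: that jax.config.update raises AttributeError for a removed option. Whether silently skipping jax_spmd_mode is semantically safe for AXLearn is a separate question that would need checking.

```python
import importlib
from typing import Any


def resolve_first(*dotted_paths: str) -> Any:
    """Return the first attribute that resolves among paths like 'pkg.mod.attr'.

    Hypothetical helper: candidates are tried in order, so list the new
    location first and the legacy location second.
    """
    for path in dotted_paths:
        module_name, _, attr = path.rpartition(".")
        if not module_name:
            continue  # A bare name has no module part to import.
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError):
            continue  # Not available in this version; try the next candidate.
    raise ImportError(f"None of {dotted_paths} could be resolved")


def update_config_if_supported(config: Any, name: str, value: Any) -> bool:
    """Call config.update(name, value); return False if the option is unknown.

    Hypothetical helper: `config` is any object with an `update` method that
    raises AttributeError for unrecognized options, as in the traceback above.
    """
    try:
        config.update(name, value)
    except AttributeError:
        return False
    return True


# Intended usage with JAX installed (left as comments so the sketch runs without it):
#   tree_map = resolve_first("jax.tree.map", "jax.tree_map")
#   Primitive = resolve_first("jax.extend.core.Primitive", "jax.core.Primitive")
#   update_config_if_supported(jax.config, "jax_spmd_mode", "allow_all")
```

Listing the new path first means newer JAX releases never touch the deprecated name, while older releases fall through to the legacy one.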

I can contribute patches to address these errors. However, additional optimizations or performance checks might be required before merging. Let me know how the NVIDIA team can collaborate with you on this. Thank you.

Steboss avatar Apr 23 '25 09:04 Steboss

It looks like there's a new error we should be aware of:

  File "/opt/axlearn/axlearn/common/config.py", line 282, in validate_config_field_value
    raise InvalidConfigValueError(
axlearn.common.config.InvalidConfigValueError: Invalid config value type <class 'jax._src.partition_spec.PartitionSpecImpl'> for value "PartitionSpec(('data', 'expert', 'fsdp', 'seq'),)". Consider registering a custom validator with `register_validator`.

This happens during validation in config.py and is caused by this new JAX commit, which changed PartitionSpec to PartitionSpecImpl.

Steboss avatar May 08 '25 14:05 Steboss

One fix for this error is to register a dedicated validator in axlearn/common/config.py:

from jax.sharding import PartitionSpec
...
register_validator(
    match_fn=lambda v: isinstance(v, PartitionSpec),
    validate_fn=lambda v: validate_config_field_value(tuple(v)),
)
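
To make the idea behind this fix concrete, here is a toy model of a match_fn/validate_fn registry. It is not AXLearn's actual implementation: register, validate, InvalidValueError and the Spec class are all hypothetical stand-ins, written so the sketch runs without JAX or AXLearn. It only illustrates the pattern of converting an unrecognized container type to a tuple and delegating to the existing validation.

```python
from typing import Any, Callable, List, Tuple

# Toy registry of (match_fn, validate_fn) pairs; hypothetical, not AXLearn's code.
_validators: List[Tuple[Callable[[Any], bool], Callable[[Any], None]]] = []


class InvalidValueError(TypeError):
    """Raised when no registered validator accepts a value."""


def register(match_fn: Callable[[Any], bool], validate_fn: Callable[[Any], None]) -> None:
    _validators.append((match_fn, validate_fn))


def validate(value: Any) -> None:
    """Run the first validator whose match_fn accepts the value."""
    for match_fn, validate_fn in _validators:
        if match_fn(value):
            validate_fn(value)
            return
    raise InvalidValueError(f"Invalid config value type {type(value)} for value {value!r}")


def _validate_tuple(value: tuple) -> None:
    for item in value:
        validate(item)  # Recurse so nested tuples of strings are accepted.


# Types the toy validator understands out of the box.
register(lambda v: isinstance(v, (int, float, str, type(None))), lambda v: None)
register(lambda v: isinstance(v, tuple), _validate_tuple)


class Spec:
    """Stand-in for a sharding-spec type that is iterable but not a tuple."""

    def __init__(self, *axes: Any):
        self._axes = axes

    def __iter__(self):
        return iter(self._axes)


spec = Spec(("data", "expert", "fsdp", "seq"))

# Before registering a validator, the unknown type is rejected.
try:
    validate(spec)
    rejected_before = False
except InvalidValueError:
    rejected_before = True

# Same shape as the fix above: match on the type, validate its tuple form.
register(lambda v: isinstance(v, Spec), lambda v: validate(tuple(v)))
validate(spec)  # Accepted once the validator is registered.
```

Delegating to the tuple validator keeps the per-element checks in one place, so an invalid element inside the spec is still rejected.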

Steboss avatar May 08 '25 16:05 Steboss