
[UPDATE JAX API] Update `trainer_test` `_normalized_spec`

Open Steboss opened this issue 7 months ago • 2 comments

Running the unit tests with the command below raises an `AttributeError`:

XLA_FLAGS='--xla_force_host_platform_device_count=8' pytest --durations=100 -v   -n auto -v -m "for_8_devices" --dist worksteal ${UNQUOTED_PYTEST_FILES}

E             File "/opt/axlearn/axlearn/common/trainer_test.py", line 1093, in <lambda>
E           AttributeError: 'PartitionSpec' object has no attribute '_normalized_spec'
The internal API for `PartitionSpec` has changed: `_normalized_spec` is now `_normalized_spec_for_aval` ([reference](https://github.com/jax-ml/jax/blob/609fb7f6085b52861f65c7aa3b339c40dfd207fa/jax/_src/partition_spec.py#L166)), so the lambda at `trainer_test.py` line 1093 needs to call the new name.
@matthew-e-hopkins for visibility
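
One possible way to keep the test working across JAX versions is a small helper that looks up whichever private method the installed `PartitionSpec` exposes. This is only a sketch: the helper name `normalized_spec` is hypothetical, and it assumes both `_normalized_spec` and `_normalized_spec_for_aval` accept a single `ndim` argument, which should be checked against the JAX version in use. The two stand-in classes exist only to make the example runnable without JAX; they are not JAX types.

```python
def normalized_spec(spec, ndim):
    """Hypothetical helper: call the private normalization method by either name.

    Assumption: both the old and the new method take a single `ndim` argument.
    Both names are private JAX internals and may change again.
    """
    # Prefer the newer name, then fall back to the older one.
    for name in ("_normalized_spec_for_aval", "_normalized_spec"):
        method = getattr(spec, name, None)
        if method is not None:
            return method(ndim)
    raise AttributeError(
        f"{type(spec).__name__!r} has neither '_normalized_spec_for_aval' "
        "nor '_normalized_spec'"
    )


# Stand-ins for illustration only; in the real test, `spec` would be a
# jax.sharding.PartitionSpec instance.
class _OldStyleSpec:
    def _normalized_spec(self, ndim):
        return ("old", ndim)


class _NewStyleSpec:
    def _normalized_spec_for_aval(self, ndim):
        return ("new", ndim)


old_result = normalized_spec(_OldStyleSpec(), 2)
new_result = normalized_spec(_NewStyleSpec(), 3)
```

If only one JAX version needs to be supported, renaming the call in the lambda directly is simpler than a helper like this.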

Steboss, May 21 '25 16:05

Hey @ruomingp, I can see the following error in CI for many tests:

#22 157.6 /opt/venv/lib/python3.10/site-packages/transformers/integrations/tensor_parallel.py:465: in __init__
#22 157.6     self.input_layouts = (input_layouts or Replicate(),)
#22 157.6 E   NameError: name 'Replicate' is not defined

Any idea? I can investigate this.

Steboss, May 22 '25 20:05

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

github-actions[bot], Oct 20 '25 02:10