axlearn
[UPDATE JAX API] Update `trainer_test` `_normalized_spec`
Found by running the unit tests:
`XLA_FLAGS='--xla_force_host_platform_device_count=8' pytest --durations=100 -v -n auto -v -m "for_8_devices" --dist worksteal ${UNQUOTED_PYTEST_FILES}`
E File "/opt/axlearn/axlearn/common/trainer_test.py", line 1093, in <lambda>
E AttributeError: 'PartitionSpec' object has no attribute '_normalized_spec'
The private `PartitionSpec` API has changed in JAX: `_normalized_spec` is now `_normalized_spec_for_aval` ([reference](https://github.com/jax-ml/jax/blob/609fb7f6085b52861f65c7aa3b339c40dfd207fa/jax/_src/partition_spec.py#L166)). This PR updates `trainer_test` to call the new method.
@matthew-e-hopkins for visibility
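If the tests ever need to run against both older and newer JAX releases, a small duck-typed shim could pick whichever private method exists. This is a minimal sketch, not part of this PR: `normalized_spec` is a hypothetical helper name, and it assumes both `_normalized_spec` and `_normalized_spec_for_aval` take the number of array dimensions as their single positional argument. Both are private JAX APIs and may change again.

```python
from typing import Any


def normalized_spec(spec: Any, ndim: int) -> Any:
    """Hypothetical helper: normalize a PartitionSpec across JAX versions.

    Prefers the newer private method `_normalized_spec_for_aval` and falls
    back to the older `_normalized_spec`. Both are private JAX APIs.
    """
    # Newer JAX releases expose `_normalized_spec_for_aval`.
    method = getattr(spec, "_normalized_spec_for_aval", None)
    if method is None:
        # Older JAX releases expose `_normalized_spec`. If neither exists,
        # getattr raises AttributeError, surfacing a further API change early.
        method = getattr(spec, "_normalized_spec")
    # Assumption: both methods accept the number of dimensions positionally.
    return method(ndim)
```

With such a helper, a call like `spec._normalized_spec(n)` could become `normalized_spec(spec, n)`. Calling `_normalized_spec_for_aval` directly, as this PR does, is simpler when only one JAX version needs to be supported.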
Hey @ruomingp, I'm seeing the following error in CI for many tests:
#22 157.6 /opt/venv/lib/python3.10/site-packages/transformers/integrations/tensor_parallel.py:465: in __init__
#22 157.6 self.input_layouts = (input_layouts or Replicate(),)
#22 157.6 E NameError: name 'Replicate' is not defined
Any idea? I can investigate this.
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.