Steboss

Results: 8 issues by Steboss

This refers to PR #1136. - `jax_spmd_mode` is now [obsolete](https://github.com/jax-ml/jax/commit/7634230cdcd2d3cb42d1093f6ab255f47f9869d5). - Performance tests have been run for `fuji-3B-v3-flash-attention`, and the results still match the previous implementation: Metrics | This...
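Because `jax_spmd_mode` is obsolete, the simplest fix is to delete any `jax.config.update("jax_spmd_mode", ...)` call. If a codebase has to run on both older and newer JAX releases, a guarded update is one option. The sketch below assumes that `jax.config.update` raises `AttributeError` for option names the installed JAX does not define. The helper name `update_config_if_supported` is hypothetical.

```python
import jax


def update_config_if_supported(name: str, value) -> bool:
    """Hypothetical helper: set a JAX config option only if this JAX version defines it."""
    try:
        jax.config.update(name, value)
    except AttributeError:
        # Assumption: JAX reports unknown option names (e.g. removed ones) via AttributeError.
        return False
    return True


# On a JAX release where the option was removed, this returns False instead of raising.
update_config_if_supported("jax_spmd_mode", "allow_all")
```

Deleting the call outright is cleaner once older JAX versions no longer need support, because the guard also hides typos in option names.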

JAX removed the attribute `_registry_with_keypaths`, so you may run into this error: ```bash File "/opt/axlearn/axlearn/common/utils.py", line 1870, in pytree_children egistry_with_keypaths = jax._src.tree_util._registry_with_keypaths AttributeError: module 'jax._src.tree_util' has no attribute '_registry_with_keypaths' ```...
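One way to avoid depending on the private `_registry_with_keypaths` is to get a node's direct children through the public `jax.tree_util.tree_flatten_with_path`. Here, `is_leaf` treats every subtree except the root as a leaf, so flattening stops one level down. This is a minimal sketch. It is not AXLearn's actual fix, and the function name `pytree_children_public` is hypothetical.

```python
import jax


def pytree_children_public(node):
    """Hypothetical helper: return (key, child) pairs one level below `node`."""
    # Every subtree other than `node` itself counts as a leaf, so only one level is flattened.
    flat, treedef = jax.tree_util.tree_flatten_with_path(
        node, is_leaf=lambda x: x is not node
    )
    if jax.tree_util.treedef_is_leaf(treedef):
        return []  # `node` is itself a leaf, so it has no children.
    # Each path holds exactly one key entry (DictKey, SequenceKey, GetAttrKey, ...).
    return [(path[0], child) for path, child in flat]


children = pytree_children_public({"a": 1, "b": [2, 3]})
```

The key objects are the public key-path entries (`DictKey.key`, `SequenceKey.idx`, and so on), so callers can use them without touching `jax._src`.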

AXLearn currently supports JAX 0.4.38, while JAX has now progressed to 0.6.0. When attempting to use AXLearn with newer JAX releases (≥ 0.5.4), we see several compatibility issues: 1. `jax.core.Primitive`...
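Supporting both 0.4.38 and ≥ 0.5.4 usually means branching on the installed version. The sketch below compares only the leading numeric release components. The helper names `version_tuple` and `jax_at_least` are hypothetical.

```python
import re


def version_tuple(version: str) -> tuple:
    """Extract up to three leading numeric components, e.g. '0.6.0rc1' -> (0, 6, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break  # Stop at the first non-numeric component.
        parts.append(int(match.group()))
    return tuple(parts)


def jax_at_least(version: str, minimum: tuple) -> bool:
    """True if `version` is at or above `minimum`, using tuple ordering."""
    return version_tuple(version) >= minimum
```

With JAX installed, a call such as `jax_at_least(jax.__version__, (0, 5, 4))` would choose between the old and new code paths. If an extra dependency is acceptable, `packaging.version.Version` handles pre-release ordering more completely.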

@matthew-e-hopkins Hey people, this is a huge update that allows us to use JAX > 0.5.3 (we're currently testing AXLearn with JAX 0.7.2). I've implemented the following changes: - I've...

In the latest JAX version, the behaviour of `jax.experimental.array_serialization.serialization` has changed: the attributes - `_spec_has_metadata` - `ts` have been moved to `ts_impl` ([here](https://github.com/jax-ml/jax/blob/main/jax/experimental/array_serialization/tensorstore_impl.py)). Thus I...
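When an attribute moves between modules, one version-tolerant pattern is to look it up in the new module first and then fall back to the old one. The sketch below uses `types.SimpleNamespace` stand-ins so it runs without JAX or tensorstore. The helper `resolve_attr` is hypothetical, and the real-module wiring in the comments assumes the `tensorstore_impl` module path from the linked file.

```python
import types
from typing import Any


def resolve_attr(name: str, *modules: Any) -> Any:
    """Hypothetical compat helper: return `name` from the first module that defines it."""
    for module in modules:
        if module is not None and hasattr(module, name):
            return getattr(module, name)
    raise AttributeError(f"none of the given modules defines {name!r}")


# Stand-ins for the new (ts_impl) and old (serialization) locations.
new_ts_impl = types.SimpleNamespace(ts="tensorstore (from ts_impl)")
old_serialization = types.SimpleNamespace(
    ts="tensorstore (old location)", _spec_has_metadata=lambda spec: False
)

ts = resolve_attr("ts", new_ts_impl, old_serialization)  # Found in the new module.
spec_has_metadata = resolve_attr("_spec_has_metadata", new_ts_impl, old_serialization)

# Assumed real-world wiring:
#   from jax.experimental.array_serialization import serialization
#   try:
#       from jax.experimental.array_serialization import tensorstore_impl as ts_impl
#   except ImportError:  # Older JAX without tensorstore_impl.
#       ts_impl = None
#   ts = resolve_attr("ts", ts_impl, serialization)
```

Both names are private or implementation details of JAX, so a lookup like this only delays breakage. Moving to the public serialization API is the longer-term fix.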

Running the unit tests with: ```bash XLA_FLAGS='--xla_force_host_platform_device_count=8' pytest --durations=100 -v -n auto -v -m "for_8_devices" --dist worksteal ${UNQUOTED_PYTEST_FILES} E File "/opt/axlearn/axlearn/common/trainer_test.py", line 1093, in E AttributeError: 'PartitionSpec' object has no attribute '_normalized_spec'...
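Test code that relied on the private `PartitionSpec._normalized_spec` can often do the padding it needs with public operations, because a `PartitionSpec` can be iterated and measured with `len`. The sketch below only pads a spec with `None` up to a given rank. It does not claim to reproduce every detail of the removed private method, and the helper name `pad_spec` is hypothetical.

```python
from jax.sharding import PartitionSpec


def pad_spec(spec: PartitionSpec, ndim: int) -> PartitionSpec:
    """Hypothetical helper: pad `spec` with trailing None entries up to `ndim` dimensions."""
    entries = tuple(spec)  # Public iteration over the per-dimension partitions.
    if len(entries) > ndim:
        raise ValueError(f"{spec} has more entries than ndim={ndim}")
    return PartitionSpec(*entries, *([None] * (ndim - len(entries))))


padded = pad_spec(PartitionSpec("data"), 3)
```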

This PR gathers all the related PRs that work on the `jax.tree` structure: #1202, #1200, #1199.
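Work on the `jax.tree` structure may mean moving calls from the older `jax.tree_util.*` spelling to the `jax.tree.*` namespace available in recent JAX releases; that interpretation is an assumption. The sketch below shows that both spellings give the same result on a small nested dict.

```python
import jax

tree = {"params": {"w": 1.0, "b": 2.0}, "step": 3}

# Newer namespace vs. the older tree_util spelling: same traversal, same result.
doubled_new = jax.tree.map(lambda x: x * 2, tree)
doubled_old = jax.tree_util.tree_map(lambda x: x * 2, tree)

# Dict keys are visited in sorted order: params/b, params/w, step.
leaves = jax.tree.leaves(tree)
```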

Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) - [x] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?...