[JAX API Update] Remove `jax_spmd_mode` from config
This refers to PR #1136
`jax_spmd_mode` is now obsolete. Performance tests have been run for
fuji-3B-v3-flash-attention, and the results still match the previous implementation:
| Metrics | This PR implementation | Previous AXLearn implementation |
|---|---|---|
| Tokens per sec per GPU | 9288 | 8904 |
| Seqs per sec per GPU | 2.26 | 2.17 |
| Average time step | 0.88 | 0.91 |
| TFLOPS per sec per GPU | 218.80 | 209.74 |
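
As a quick sanity check on the table above, the relative change of each metric can be computed directly from the reported numbers. This is only a sketch: the dictionary and function names are illustrative and not part of AXLearn or JAX, and the values are simply copied from the table.

```python
# Values copied from the table above: (this PR, previous implementation).
# Names here are illustrative only; nothing below is an AXLearn or JAX API.
METRICS = {
    "tokens_per_sec_per_gpu": (9288.0, 8904.0),
    "seqs_per_sec_per_gpu": (2.26, 2.17),
    "average_time_step": (0.88, 0.91),
    "tflops_per_gpu": (218.80, 209.74),
}


def relative_change_pct(new: float, old: float) -> float:
    """Percentage change of `new` relative to `old`."""
    return (new - old) / old * 100.0


changes = {
    name: relative_change_pct(new, old)
    for name, (new, old) in METRICS.items()
}

for name, pct in changes.items():
    print(f"{name}: {pct:+.2f}%")
```

On these numbers the throughput metrics come out roughly 4% higher and the average step time roughly 3% lower, so removing the option does not appear to regress performance on this configuration.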
@apghml could you review this please? Thank you.
Does this PR also need the validator? https://github.com/apple/axlearn/issues/1126#issuecomment-2863567411
Hey @dmarx, this PR does not need the validator. I spotted a few more bugs related to JAX versions, but I haven't pushed any PRs for them yet. I will open them for reference for @matthew-e-hopkins.
@Steboss Is this PR ready to merge?
@apghml yes it is :)
@apghml @dmarx is it possible to get a log showing why this merge failed? Thank you :)
I believe they fixed the issue recently. Can you merge the latest main branch and then we can retry?
Also, @Steboss I see many of your recent PRs have been closed. Are you moving them somewhere else? If so, can you share a link? Thanks!
@apghml I've moved all the PRs that dealt with jax.tree into a single one, namely #1207
That PR seems to be by a different user?
@apghml #1206 sorry