[JAX API Update] Remove `jax_spmd_mode` from config

Open Steboss opened this issue 8 months ago • 4 comments

This refers to PR #1136

  • `jax_spmd_mode` is now obsolete, so it is removed from the config.
  • Performance tests were run for fuji-3B-v3-flash-attention; the results match or slightly exceed the previous implementation:

| Metric | This PR | Previous AXLearn implementation |
|---|---|---|
| Tokens/s per GPU | 9288 | 8904 |
| Sequences/s per GPU | 2.26 | 2.17 |
| Average step time | 0.88 | 0.91 |
| TFLOPS per GPU | 218.80 | 209.74 |
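To make the comparison concrete, the relative change of each metric can be computed directly from the table values. The dictionary keys below are illustrative names chosen for this sketch; the table does not state the unit of the step time, but the unit cancels out in a relative change.

```python
# Metrics for fuji-3B-v3-flash-attention, copied from the table.
# Each entry is (this PR, previous AXLearn implementation).
metrics = {
    "tokens_per_sec_per_gpu": (9288, 8904),
    "seqs_per_sec_per_gpu": (2.26, 2.17),
    "avg_step_time": (0.88, 0.91),  # lower is better
    "tflops_per_gpu": (218.80, 209.74),
}


def relative_change_pct(new: float, old: float) -> float:
    """Percent change of `new` relative to `old`, rounded to 2 decimals."""
    return round(100.0 * (new - old) / old, 2)


changes = {name: relative_change_pct(new, old) for name, (new, old) in metrics.items()}

for name, pct in changes.items():
    print(f"{name}: {pct:+.2f}%")
```

The throughput metrics rise by roughly 4% and the average step time drops by about 3%, which is consistent with the claim that this change does not regress performance on this benchmark.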

@apghml, could you please review this? Thank you.

Steboss, May 06 '25 08:05

Does this PR also need the validator? https://github.com/apple/axlearn/issues/1126#issuecomment-2863567411

dmarx, May 19 '25 06:05

Hey @dmarx, this PR does not need the validator. I spotted a few more bugs related to JAX versions but haven't pushed any PRs for them yet. I will open them for reference for @matthew-e-hopkins.

Steboss, May 19 '25 08:05

@Steboss Is this PR ready to merge?

apghml, May 19 '25 16:05

@apghml yes it is :)

Steboss, May 19 '25 19:05

@apghml @dmarx, is it possible to get a log showing why this merge failed? Thank you :)

Steboss, May 22 '25 10:05

I believe they fixed the issue recently. Can you merge the latest main branch and then we can retry?

apghml, May 22 '25 18:05

Also, @Steboss I see many of your recent PRs have been closed. Are you moving them somewhere else? If so, can you share a link? Thanks!

apghml, May 22 '25 18:05

@apghml I've moved all the PRs that dealt with `jax.tree` into a single one: #1207

Steboss, May 22 '25 20:05

That PR seems to be by a different user?

apghml, May 22 '25 22:05

@apghml Sorry, I meant #1206.

Steboss, May 23 '25 08:05