
Add schedule-free AdamW submission in JAX

Open priyakasimbeg opened this issue 1 year ago • 7 comments

Description

Currently we have been unable to reproduce the schedule-free AdamW results in JAX. There seem to be differences between the optax implementation of schedule-free AdamW and the PyTorch submission.

priyakasimbeg avatar Oct 31 '24 17:10 priyakasimbeg

I can help debug any issues here. Do you have any code you can share? If there are issues with the optax JAX implementation, I want to get it fixed ASAP.

adefazio avatar Nov 01 '24 23:11 adefazio

There are a number of small differences between the behavior of the schedule-free JAX wrapper and the original AlgoPerf submission. Some differences I'm aware of:

  • The bias correction in the submission scales the weight decay at early steps. This is slightly faster for fastMRI but doesn't appear to affect any other workloads in my experiments.
  • Weight decay is applied at y in the JAX version (see the sketch after this list). This decay-at-y variant performs very similarly in my experiments, if not slightly better (when tested in PyTorch). The experiments in the schedule-free paper use this decay-at-y version.
  • There is an r=0.5 weighting in the submission version; it seems to make little if any difference in practice (hard to tell due to noise).

So overall I expect the JAX wrapper version to give results that are just as good on all problems (maybe slightly slower on fastMRI), so if there is a difference it would come from some sort of bug.
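
To make the decay-at-y point concrete, here is a minimal sketch of a schedule-free step (plain SGD base, no Adam preconditioning; names and defaults are illustrative, not the submission's or optax's):

```python
import jax

def schedule_free_step(grad_fn, x, z, t, lr=1e-3, beta=0.9, weight_decay=1e-4):
    # y interpolates between the running average x and the iterate z; the
    # gradient (and, in the decay-at-y variant, the weight decay) is
    # evaluated at y.
    y = jax.tree_util.tree_map(lambda xi, zi: beta * xi + (1 - beta) * zi, x, z)
    g = grad_fn(y)
    # Decay-at-y: the weight-decay term uses y rather than z or x.
    z = jax.tree_util.tree_map(
        lambda zi, gi, yi: zi - lr * (gi + weight_decay * yi), z, g, y)
    # Uniform averaging weight c = 1/(t+1); the r weighting above changes this.
    c = 1.0 / (t + 1.0)
    x = jax.tree_util.tree_map(lambda xi, zi: (1 - c) * xi + c * zi, x, z)
    return x, z
```

Note that for the averaging to be meaningful, z has to start equal to the initial parameters, which turns out to matter further down this thread.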

adefazio avatar Nov 06 '24 20:11 adefazio

Hi Aaron! Thanks for weighing in on this. I seem to have missed your messages on this thread.

We have a slightly modified version based on the optax code here: https://github.com/priyakasimbeg/algorithmic-efficiency/blob/compare_schedule_free/tests/test_algorithms/schedule_free_adamw/jax/schedule_free_optax.py. This code adds the r parameter, and we tested it with r=0.75 in our Google-internal codebase.
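
In case it helps the comparison, here is a sketch of what an r weighting does to the averaging step, assuming the polynomial rule c_{t+1} = (t+1)^r / sum_{k<=t+1} k^r (the exact rule in the submission may differ; names are illustrative):

```python
import jax

def r_weighted_average(x, z, t, weight_sum, r=0.75):
    # Incremental weighted average with weight (t+1)**r. With r = 0 this
    # reduces to the uniform c = 1/(t+1); larger r weights recent iterates
    # z more heavily.
    weight = (t + 1.0) ** r
    weight_sum = weight_sum + weight
    c = weight / weight_sum
    x = jax.tree_util.tree_map(lambda xi, zi: (1 - c) * xi + c * zi, x, z)
    return x, weight_sum
```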

I'm working on a test to compare the PyTorch and JAX implementations side by side on the AlgoPerf GitHub code, but the test is still in progress. I can perhaps run a full training run on some of the workloads. In the meantime, feel free to weigh in again if you spot any other differences.

priyakasimbeg avatar Nov 19 '24 21:11 priyakasimbeg

Ok, I'll take a look and see if I spot any differences.

adefazio avatar Nov 19 '24 21:11 adefazio

It looks like the z buffer may be initialized with zeros: https://github.com/priyakasimbeg/algorithmic-efficiency/blob/5556015054e3dda681e2a25e05a2f217d933453d/tests/test_algorithms/schedule_free_adamw/jax/submission.py#L58C51-L59C1. It needs to be initialized the same as the main parameter buffer. I think this line is a copy-paste error from the JAX version of NAdamW and other methods, where all optimizer state is normally initialized to zero.

Suggestion: you might want to set z on the first call to the main optimizer update; that's what we do in the PyTorch version.
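
A sketch of that lazy-init idea in JAX (the state field names and the SGD placeholder body are illustrative, not the submission's actual update):

```python
import jax

def update_fn(grads, state, params, lr=1e-3):
    # On the first call, seed z from the live parameters rather than trusting
    # whatever init produced (e.g. zeros). Assumes state["z"] and params have
    # the same tree structure and dtypes.
    z = jax.lax.cond(state["t"] == 0, lambda: params, lambda: state["z"])
    # Placeholder update so the sketch runs; the real schedule-free AdamW
    # arithmetic would go here.
    z = jax.tree_util.tree_map(lambda zi, gi: zi - lr * gi, z, grads)
    return {"z": z, "t": state["t"] + 1}
```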

adefazio avatar Nov 19 '24 23:11 adefazio

@priyakasimbeg Let me know if that initialization issue was the problem.

adefazio avatar Nov 22 '24 23:11 adefazio

Hi Aaron, thanks for spotting that! We did the end-to-end training with an internal codebase using the optax implementation: https://github.com/google-deepmind/optax/blob/3ba9822c2a8d5fa7d046180b2574e108094523b4/optax/contrib/_schedule_free.py. It's not immediately obvious to me how this is initialized, but I will investigate and report back here.

priyakasimbeg avatar Nov 25 '24 23:11 priyakasimbeg

Hello @adefazio, just following up on this thread: @priyakasimbeg and I were debugging this issue and found that initializing the z buffer to zeros was indeed the cause. We are now passing model_params as suggested, and it seems to solve the problem.
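
For concreteness, this is the shape of the change (opt_init_fn here is a toy stand-in for the submission's init function, kept just to show where model_params now flows):

```python
import jax
import jax.numpy as jnp

def opt_init_fn(model_params):
    # The z buffer now starts as a copy of the model parameters, not zeros,
    # matching the PyTorch submission's behavior.
    return {"z": jax.tree_util.tree_map(jnp.array, model_params), "t": 0}

model_params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
opt_state = opt_init_fn(model_params)
```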

This is how the validation loss of JAX vs. PyTorch looked when we initialized the z buffer as in the lines above (https://github.com/priyakasimbeg/algorithmic-efficiency/blob/5556015054e3dda681e2a25e05a2f217d933453d/tests/test_algorithms/schedule_free_adamw/jax/submission.py#L58C51-L59C1):

[plot: JAX vs. PyTorch validation loss with zero-initialized z buffer]

And this is how it looks after passing model_params to the opt_init_fn function:

[plot: JAX vs. PyTorch validation loss with z initialized from model_params]

FYI, we are still going to run some more tests and observe the difference.

init-22 avatar Apr 10 '25 18:04 init-22

Awesome, let me know if there are any other issues I can help debug.

adefazio avatar Apr 10 '25 18:04 adefazio

I ran several other workloads comparing the JAX implementation (with the properly initialized z buffer) against the PyTorch implementation.

The results still show differences on some workloads; here are plots of global_step versus log-scaled validation loss:

[plots: global_step vs. log-scaled validation loss for four workloads]

I will look into it more to find where the difference comes from.

wyfEmma avatar Oct 21 '25 21:10 wyfEmma

Interestingly, for ImageNet the PyTorch run is unstable:

[plot: unstable PyTorch validation loss on ImageNet]

wyfEmma avatar Nov 06 '25 15:11 wyfEmma