Recommended Approach for Gradient Accumulation
Opening an issue for the discussion linked below.
The proposal that @marcvanzee and I discussed today is:
- Create a minimal example (e.g. MNIST) that uses gradient accumulation, most likely via the Optax implementation (MultiSteps).
- In the example, demonstrate how it works: display and explain the intermediate outputs, and add documentation where needed.
- Based on the experience of creating and running the above, consider opening a discussion with Optax about whether they could update their implementation to address the pain points or, if not, whether we should implement our own. The former would be preferable.
Discussed in https://github.com/google/flax/discussions/2030
Originally posted by sanchit-gandhi, April 6, 2022

I'm working on a training script for a speech model in Flax and would like the community's opinion on the best way to implement gradient accumulation in JAX/Flax. As I understand it, there are two viable options:
- Optax MultiSteps: this wrapper can be combined with another Optax optimizer (e.g. Optax AdamW) so that parameter updates are applied only once every prescribed number of gradient accumulation steps, without any change to the training step itself.
- A custom gradient accumulation training step: for each training step, a superbatch is formed comprising all the batched data for N gradient accumulation steps. Gradients are computed and manually accumulated over the N gradient accumulation steps, and are applied only once afterwards.
Since it is the simpler of the two approaches, I initially tried Optax MultiSteps. Keeping the per-device batch size fixed, I could not increase the number of gradient accumulation steps beyond 1 (equivalent to no gradient accumulation at all!). I therefore implemented gradient accumulation by hand (see here). Again, keeping the per-device batch size fixed, I could not increase the number of gradient accumulation steps beyond 1. Has anyone run into memory-related caveats when using Optax MultiSteps? And for the custom approach, is there a 'standard' way of doing this? Many thanks for all your help!