Fix broken accumulate_grad_batches argument in v5 trainer
While trying to fine-tune some of the RWKV-7-Pile models, I found that the accumulate_grad_batches argument passed to the main trainer script was buggy:
- Upon exiting a run and restarting from the last checkpoint, the step count logged to W&B jumped significantly
- The learning rate schedule was broken: the learning rate barely changed, then dropped sharply after restarting from a checkpoint
Both bugs occur because the training code does not account for gradient accumulation when calculating the total tokens processed and the step count to resume from.
To fix this, I modified the trainer code so that the step logged to W&B is the actual optimization step rather than the gradient accumulation (micro) step, by dividing args.epoch_steps by args.accumulate_grad_batches in the calculation of real_step. This has no effect when args.accumulate_grad_batches is 1, which is the default. I also scaled real_tokens and warmup_tokens in the learning rate schedule by the number of gradient accumulation batches, so that each optimization step is credited with the correct number of tokens. All other code is left unchanged, and the training progress bar still counts micro-steps. A sketch of the idea is shown below.
In my testing, this fixes resuming from a checkpoint, and the learning rate now scales properly. Below is a W&B comparison of a training run before the change and one after, both using accumulate_grad_batches=8 (though on different data).
Before (stopped and resumed around 15k steps):
After (stopped and resumed at ~200 and ~400 steps; the LR does scale properly with my_exit_tokens, though this is not visible in this image):
NOTE: While this should not be a breaking change in theory, I would still highly recommend testing on a multi-GPU setup, as I only had access to a single local GPU while testing.
How to Test:
Do a training run in which the --accumulate_grad_batches argument is set to a number greater than 1; check that the learning rate schedule works properly and that resuming from a checkpoint does not cause step gaps in the loss curve.