HSDP causes loss instability
I have a codebase forked from torchtitan with minor changes. FSDP trains very well with minimal instability, but HSDP on the same codebase exhibits loss spikes.
Is there any reason for this that you folks can think of? Note that I have implemented gradient accumulation in my fork, though without changing any sharding behavior (it just accumulates gradients over a larger batch size).
If it's not an HSDP bug (is it?), here are some things I'd look at:
- For gradient accumulation, are you doing a sum or an average over the gradients? To make it equivalent to the backward of a mean cross-entropy loss on a larger batch size, you'll need to average (see the sketch after this list).
- You can compare gradient accumulation vs. no gradient accumulation at a small scale and check whether the numerics match, e.g. batch size 10 without gradient accumulation vs. batch size 1 with gradient accumulation for 10 iterations.
- Check the data loading behavior -- are you loading the same data for each global batch with FSDP vs. HSDP? That said, even if the data loading behaviors differ, as long as the data is randomly distributed I wouldn't expect the loss behavior to be very different.
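To make the averaging point concrete, here's a minimal sketch of the kind of small-scale check I mean. This is not torchtitan code; the toy model, batch sizes, and tolerance are made up for illustration. It verifies that accumulating micro-batch losses divided by the number of accumulation steps reproduces the gradient of a single mean-loss backward on the full batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the real model -- purely illustrative.
model = nn.Linear(16, 4)
loss_fn = nn.CrossEntropyLoss()  # 'mean' reduction over the batch
data = torch.randn(10, 16)
target = torch.randint(0, 4, (10,))

# Reference: a single backward on the full batch of 10 (mean loss).
model.zero_grad()
loss_fn(model(data), target).backward()
ref_grads = [p.grad.clone() for p in model.parameters()]

# Accumulation: 10 micro-batches of size 1. Dividing each micro-loss by the
# number of accumulation steps makes the summed gradients equal the
# mean-loss gradient above.
accum_steps = 10
model.zero_grad()
for i in range(accum_steps):
    micro_loss = loss_fn(model(data[i:i + 1]), target[i:i + 1]) / accum_steps
    micro_loss.backward()  # gradients accumulate into p.grad

for ref, p in zip(ref_grads, model.parameters()):
    print(torch.allclose(ref, p.grad, atol=1e-6))  # expect True for each param
```

The division by `accum_steps` is what makes the two match: if you just sum the unscaled micro-losses, the accumulated gradient ends up roughly `accum_steps` times larger than the mean-loss gradient, which effectively scales up the learning rate and can look like instability.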
cc: @weifengpy
> but HSDP on the same codebase exhibits loss spikes
Curious what the spikes look like? Maybe a plot of HSDP vs. FSDP would help. I know there is a spike right after warmup, but I'd like to see whether there are spikes here and there all along the training.
I noticed a spike in the train loss at step 150 (warmup ends at step 200) with the current default config for llama3.1 8b at a sequence length of 8192.
https://github.com/pytorch/torchtitan/blob/781ec0d0187e69935449b5f98b8f60af0dc5091f/torchtitan/models/llama3/train_configs/llama3_8b.toml
This spike disappears with a lower sequence length of 4096, as shown in this graph. The dataset used was a 10B-token subset of fineweb-edu. I don't know whether this helps the discussion here.