
HSDP causes loss instability

Open apkumar opened this issue 10 months ago • 4 comments

I have a codebase forked from torchtitan with minor changes. FSDP trains very well with minimal instability, but HSDP on the same codebase exhibits loss spikes.

Is there some reason for this you folks can think of? Note that I have implemented gradient accumulation in my fork, though without changing any sharding behavior (just to accumulate gradients over a larger batch size).

apkumar avatar Jan 31 '25 03:01 apkumar
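For context, the main thing that changes when switching from FSDP to HSDP is the shape of the device mesh handed to the sharding API. Below is a minimal sketch using PyTorch's fully_shard (FSDP2); the world size, mesh dimensions, and model here are illustrative assumptions, not the actual fork's setup.

```python
# Minimal sketch: selecting FSDP vs. HSDP with PyTorch's fully_shard (FSDP2)
# via the device mesh shape. World size, mesh dims, and the model are
# illustrative assumptions. Run under torchrun, e.g.
#   torchrun --nproc_per_node=8 this_script.py
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # torch.distributed._composable.fsdp in older releases

use_hsdp = True
if use_hsdp:
    # 2D mesh: parameters are sharded within each group of 4 ranks, and
    # gradients are additionally all-reduced across the 2 replica groups.
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
else:
    # 1D mesh: plain FSDP, every rank holds a shard of each parameter.
    mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))

model = nn.TransformerEncoderLayer(d_model=1024, nhead=8)  # placeholder model
fully_shard(model, mesh=mesh)
```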

If it's not an HSDP bug (is it?), here are some things I'd look at:

  1. For gradient accumulation, are you doing a sum or an average over the gradients? To match the backward of a mean cross-entropy loss on the larger batch size, you'll need to average (see the sketch after this list).
  2. You can compare gradient accumulation vs. no gradient accumulation at a small scale and check whether the numerics are stable, e.g. batch size 10 without gradient accumulation vs. batch size 1 with gradients accumulated over 10 micro-batches.
  3. Check the data loading behavior -- are you loading the same data in each global batch with FSDP vs. HSDP? That said, even if the data loading behaviors differ, as long as the data is randomly distributed I wouldn't expect the loss behavior to be very different.
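To make point 1 concrete, here is a minimal sketch of gradient accumulation where each micro-batch loss is divided by the number of micro-batches, so the accumulated gradient matches a mean cross-entropy backward over the full batch. The model, optimizer, data, and accumulation count are placeholders, not the fork's actual code.

```python
# Minimal sketch: gradient accumulation that matches a larger-batch mean loss.
# Model, optimizer, data, and accum_steps are illustrative placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(128, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 10  # micro-batches per optimizer step

def micro_batches():
    # Placeholder data: accum_steps micro-batches of (input, label) pairs.
    for _ in range(accum_steps):
        yield torch.randn(1, 128), torch.randint(0, 10, (1,))

optimizer.zero_grad()
for x, y in micro_batches():
    logits = model(x)
    # Dividing each micro-batch's mean loss by accum_steps makes the summed
    # gradients equal the gradient of the mean loss over the full batch
    # (assuming every micro-batch contributes the same number of samples/tokens).
    loss = F.cross_entropy(logits, y) / accum_steps
    loss.backward()  # .grad accumulates across micro-batches
optimizer.step()
optimizer.zero_grad()
```

This also doubles as the point-2 check: the accumulated run above (batch size 1 over 10 micro-batches) should closely track a single step on batch size 10, up to floating-point non-determinism.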

tianyu-l avatar Jan 31 '25 23:01 tianyu-l

cc: @weifengpy

gnadathur avatar Feb 06 '25 22:02 gnadathur

> but HSDP on the same codebase exhibits loss spikes

Curious what the spikes look like? A plot of HSDP vs. FSDP might help. I know there is a spike after warmup, but I'd like to see whether there are spikes here and there all along the training.

weifengpy avatar Feb 06 '25 22:02 weifengpy
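If it helps to produce that plot, here is a minimal sketch that overlays two training-loss curves; the CSV file names and the step/loss column layout are assumptions about however the FSDP and HSDP runs were logged.

```python
# Minimal sketch: overlay FSDP vs. HSDP training-loss curves.
# File names and the "step"/"loss" CSV columns are assumptions about the logging setup.
import csv
import matplotlib.pyplot as plt

def load_curve(path):
    steps, losses = [], []
    with open(path) as f:
        for row in csv.DictReader(f):
            steps.append(int(row["step"]))
            losses.append(float(row["loss"]))
    return steps, losses

for label, path in [("FSDP", "fsdp_loss.csv"), ("HSDP", "hsdp_loss.csv")]:
    steps, losses = load_curve(path)
    plt.plot(steps, losses, label=label)

plt.xlabel("step")
plt.ylabel("train loss")
plt.legend()
plt.savefig("hsdp_vs_fsdp_loss.png")
```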

I noticed a spike in train loss at step 150 (warmup ends at step 200) with the current default config for Llama 3.1 8B using a sequence length of 8192.

https://github.com/pytorch/torchtitan/blob/781ec0d0187e69935449b5f98b8f60af0dc5091f/torchtitan/models/llama3/train_configs/llama3_8b.toml

This spike disappears with a lower sequence length of 4096, as shown in the graph below. The dataset was a 10B-token subset of fineweb-edu. I don't know if this helps the discussion here.

[Image: train loss curves at sequence lengths 8192 vs. 4096]

K-H-Ismail avatar May 02 '25 18:05 K-H-Ismail