HSDP question
Thanks for releasing this great work! I was able to get training running at sequence length 1024 on the Llama 8B model with 24GB GPUs. I'd also like to train at sequence lengths of 2048 and longer on these GPUs, and given the memory constraint I think I need HSDP. I tried setting sharding_group_size to 2 and replica_group_size to 1 in the fsdp config, but I'm still getting OOM. Is there something else I need to do to get hybrid sharding to work?
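
For reference, here's roughly how I understood hybrid sharding to be wired up with plain PyTorch FSDP (a minimal sketch, not this repo's actual code; it assumes PyTorch >= 2.2 and that sharding_group_size / replica_group_size map onto a 2-D device mesh with the HYBRID_SHARD strategy, with illustrative group sizes):

```python
# Minimal HSDP sketch with vanilla PyTorch FSDP (illustrative, not lolcats' actual training code).
# world_size must equal replica_group_size * sharding_group_size.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

replica_group_size = 1   # number of model replicas (data-parallel dimension); example value
sharding_group_size = 2  # GPUs each replica's parameters are sharded across; example value

# 2-D mesh: outer dim replicates the model, inner dim shards it.
mesh = init_device_mesh(
    "cuda",
    (replica_group_size, sharding_group_size),
    mesh_dim_names=("replicate", "shard"),
)

model = torch.nn.Linear(4096, 4096).cuda()  # stand-in for the Llama model

model = FSDP(
    model,
    device_mesh=mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # shard within group, replicate across groups
)
```

Is this roughly what the fsdp config does under the hood, or does enabling hybrid sharding here require additional flags?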