Andrew Gu
You should be able to pass `device_mesh` as a 2D device mesh to enable HSDP. (You could also pass in a 2-tuple of process groups, but I think the checkpointing...
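For example, a minimal sketch of passing a 2D mesh (the topology numbers and the model are placeholders; assumes a PyTorch version where FSDP accepts `device_mesh`):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Placeholder topology: replicate across nodes, shard within a node.
num_nodes, gpus_per_node = 2, 8
mesh_2d = init_device_mesh(
    "cuda", (num_nodes, gpus_per_node), mesh_dim_names=("replicate", "shard")
)

model = torch.nn.Linear(1024, 1024, device="cuda")  # stand-in model
fsdp_model = FSDP(
    model,
    device_mesh=mesh_2d,  # 2D mesh -> hybrid sharding (HSDP)
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```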
It depends on your inter-node bandwidth. If it is fast, FSDP is probably still better, especially if your model is compute-dense like a transformer. The overall workflow I...
@ScottHoang yes, it will broadcast from global rank 0 to all ranks (spanning both the intra- and inter-node process groups): https://github.com/pytorch/pytorch/blob/afaa5fcecb07472a8805902074f4611dc5798f76/torch/distributed/fsdp/_init_utils.py#L632-L635
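Conceptually, that sync amounts to broadcasting every parameter and buffer from global rank 0 over the world group; a minimal sketch of the semantics (illustrative only, since FSDP's internal version is coalesced in the linked helper):

```python
import torch
import torch.distributed as dist

def broadcast_module_states(module: torch.nn.Module) -> None:
    # Illustrative only: send global rank 0's weights to every other rank
    # over the default (world) process group. FSDP does a coalesced version
    # of this when sync_module_states=True.
    for tensor in list(module.parameters()) + list(module.buffers()):
        dist.broadcast(tensor.detach(), src=0)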
I have seen good results compiling the cross-entropy loss by itself in torchtitan as well.
If I compile the output linear and the cross-entropy loss together, instead of just the cross-entropy loss, I get OOMs at the same batch size.
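For reference, a minimal sketch of compiling just the loss (the function name and the `.float()` upcast are illustrative, not necessarily torchtitan's exact wiring):

```python
import torch
import torch.nn.functional as F

# Compile only the loss computation; the model itself stays eager.
@torch.compile
def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, seq, vocab) logits and (batch, seq) labels, upcasting
    # logits to fp32 for a numerically stable loss.
    return F.cross_entropy(logits.flatten(0, 1).float(), labels.flatten(0, 1))
```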
Llama3-8B

With these changes:

```
[rank0]:2024-08-21 08:44:32,865 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:44:32,897 - root - INFO - Compiling each TransformerBlock with...
```
cc: @H-Huang @wconstab I am guessing we should run more steps to see whether the memory keeps increasing. For the loss issue, I think only loss is reported...
@yifuwang I need to fix PP before this is landable 😢
@H-Huang @wconstab do you have any idea whether the output logits being fp32 is a hard requirement for PP? Is there any way we can leave them as bf16?
Lost the local branch; going to reopen.