Andrew Gu

Results 159 comments of Andrew Gu

You should be able to pass `device_mesh` as a 2D device mesh to enable HSDP. (You could also pass in a 2-tuple of process groups, but I think the checkpointing...
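A minimal sketch of the 2D-mesh route, assuming `init_device_mesh` from `torch.distributed.device_mesh` and an illustrative 2-node x 8-GPU layout (mesh shape and dim names here are placeholders, not a prescribed config):

```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Assumes torch.distributed is already initialized with 16 ranks.
# Outer dim replicates across nodes; inner dim shards within a node.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

model = FSDP(
    model,  # assumed to be an existing nn.Module
    device_mesh=mesh,  # a 2D mesh enables hybrid sharding (HSDP)
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```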

It depends on your inter-node bandwidth. If that bandwidth is fast, FSDP is probably still better, especially if your model is compute-dense like a transformer. The overall workflow I...

@ScottHoang yes, it will broadcast from global rank 0 to all ranks (including both intra and inter-node process groups): https://github.com/pytorch/pytorch/blob/afaa5fcecb07472a8805902074f4611dc5798f76/torch/distributed/fsdp/_init_utils.py#L632-L635
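Concretely, that broadcast is what the `sync_module_states` flag triggers at wrap time; a minimal sketch:

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# With sync_module_states=True, FSDP broadcasts parameters and buffers from
# global rank 0 to all other ranks at init, so only rank 0 needs to hold the
# loaded/pretrained weights.
model = FSDP(
    model,  # assumed to be an existing nn.Module
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)
```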

I have seen good results compiling the cross entropy loss by itself in torchtitan as well.
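A sketch of what compiling the loss by itself can look like, assuming logits of shape `(batch, seq, vocab)` and integer labels of shape `(batch, seq)` (the function name is illustrative; torchtitan's actual helper may differ):

```
import torch
import torch.nn.functional as F

# Compile only the loss computation, leaving the model (and its output
# projection) outside the compiled region.
def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))

loss_fn = torch.compile(cross_entropy_loss)
```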

If I try to compile both the output linear and cross entropy loss together instead of just compiling the cross entropy loss, I get OOMs at the same batch size.

Llama3-8B with these changes:
```
[rank0]:2024-08-21 08:44:32,865 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:44:32,897 - root - INFO - Compiling each TransformerBlock with...
```
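The "Compiling each TransformerBlock" step above roughly corresponds to swapping each block for its compiled version; a sketch, assuming `model.layers` is an `nn.ModuleDict`/`nn.ModuleList` holding the blocks as named children:

```
import torch

# Compile each TransformerBlock individually rather than the whole model.
for layer_id, block in model.layers.named_children():
    compiled_block = torch.compile(block, fullgraph=True)
    model.layers.register_module(layer_id, compiled_block)
```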

cc: @H-Huang @wconstab I am guessing we should run more steps to see whether the memory keeps increasing. For the loss issue, I think only the loss is reported...

@yifuwang I need to fix PP before this is landable 😢

@H-Huang @wconstab do you have any idea if the output logits being fp32 is a hard requirement for PP? Is there any way we can leave them as bf16?