Andrew Gu

159 comments by Andrew Gu

@lw Because this requires changing the model code, I think @tianyu-l left it as a non-merged PR to show how people could change their own (forked) code to enable this.

For fun:
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`, using @yifuwang's hack (0.5 ms sleep), offloading 2 large FFN activations per transformer block...
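For readers following along, here is a rough sketch of per-block activation offloading using `torch.autograd.graph.saved_tensors_hooks`; the class name and size threshold here are made up, and the benchmarked setup additionally overlaps the copies on a side stream (the 0.5 ms sleep hack), which this sketch omits:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


class OffloadLargeActivations(saved_tensors_hooks):
    """Sketch: offload large saved activations to pinned CPU memory in forward
    and copy them back to GPU for backward. Threshold and name are hypothetical."""

    def __init__(self, min_numel: int = 1 << 20):
        def pack(t: torch.Tensor):
            if t.is_cuda and t.numel() >= min_numel:
                cpu_t = torch.empty(t.size(), dtype=t.dtype, layout=t.layout, pin_memory=True)
                cpu_t.copy_(t)  # D2H copy over PCIe
                return ("cpu", cpu_t, t.device)
            return ("gpu", t, None)

        def unpack(packed):
            kind, t, device = packed
            if kind == "cpu":
                return t.to(device, non_blocking=True)  # H2D copy back before use in backward
            return t

        super().__init__(pack, unpack)


# Usage: wrap only the FFN call in each transformer block so that just those
# activations get offloaded:
#     with OffloadLargeActivations():
#         h = block.feed_forward(x)
```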

@yifuwang made the good point that there may be interference with inter-node collectives, since they also use PCIe to send/recv data to/from the NIC and so compete with the D2H/H2D activation copies. The...

For my understanding, is it possible for SDPA to detect that the input key/value shapes are targeting GQA so that we do not need to pass `enable_gqa=True`? In this case,...
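For context, here is a minimal example (with made-up shapes) of the GQA case being discussed, where the key/value tensors carry fewer heads than the query and the caller currently has to state the intent via `enable_gqa=True` (available in PyTorch 2.5+):

```python
import torch
import torch.nn.functional as F

bsz, seqlen, head_dim = 1, 8192, 128
n_heads, n_kv_heads = 32, 8  # GQA: 4 query heads share each KV head

q = torch.randn(bsz, n_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)

# The question above: could SDPA infer GQA from n_heads != n_kv_heads instead
# of requiring this explicit flag?
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
```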

https://github.com/pytorch/torchtitan/pull/382 is probably closer to ideal wrapping. I agree that separately wrapping embeddings and final output linear is more efficient. cc: @tianyu-l if he wants to change it.
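As a sketch of the wrapping being described (module names follow the Llama-style model in torchtitan but are illustrative here), the token embedding and final output linear get their own FSDP2 units alongside the per-block units:

```python
from torch.distributed.fsdp import fully_shard  # FSDP2; older releases expose this under torch.distributed._composable.fsdp


def apply_fsdp(model, mesh):
    # Give the token embedding and the final output linear their own FSDP
    # units so their all-gathers/reduce-scatters are not lumped into the root.
    fully_shard(model.tok_embeddings, mesh=mesh)
    fully_shard(model.output, mesh=mesh)
    # One unit per transformer block, resharded after forward to save memory.
    for block in model.layers.values():
        fully_shard(block, mesh=mesh, reshard_after_forward=True)
    # Root wrap picks up any remaining parameters (e.g., the final norm).
    fully_shard(model, mesh=mesh)
    return model
```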

@tianyu-l I think we can get rid of the `reshard_after_forward=False` for the last transformer block. I think it increases peak memory slightly, and I saw several places copy it from torchtitan...
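For concreteness, the pattern being referred to is roughly the following (a sketch in the style of torchtitan's wrapping loop); the suggestion is to drop the special case and always pass `reshard_after_forward=True`:

```python
from torch.distributed.fsdp import fully_shard  # FSDP2


def wrap_blocks(model, mesh):
    n_layers = len(model.layers)
    for layer_id, block in model.layers.items():
        # Special case: do not reshard the last block after forward, since its
        # backward runs first and would re-all-gather it immediately anyway.
        # It saves one all-gather per step at a small peak-memory cost.
        reshard_after_forward = int(layer_id) < n_layers - 1
        fully_shard(block, mesh=mesh, reshard_after_forward=reshard_after_forward)
```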

I think it was copied from the original reference Llama implementation, which was meant for inference ([code](https://github.com/meta-llama/llama/blob/8fac8befd776bc03242fe7bc2236cdb41b6c609c/llama/model.py#L30)), where `max_batch_size` was used to size the KV cache. We should probably remove...
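For reference, in the linked inference code `max_batch_size` only exists so the attention module can preallocate its KV cache for decoding, which a training loop never touches; roughly (paraphrased sketch, attribute names abbreviated):

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Inference-only: a KV cache sized up front by max_batch_size and
        # max_seq_len. Training never uses this, so carrying max_batch_size
        # around in the training ModelArgs is dead weight.
        self.cache_k = torch.zeros(
            args.max_batch_size, args.max_seq_len, args.n_kv_heads, args.head_dim
        )
        self.cache_v = torch.zeros_like(self.cache_k)
```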

Out of curiosity, what gaps are you seeing with DPO in torchtune (https://github.com/pytorch/torchtune/blob/main/docs/source/recipes/dpo.rst)? E.g., multi-node support? Anything else?

It looks like your code as-is is not supported, since only a strict subset of ranks is entering `new_group`.
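For reference, `torch.distributed.new_group` is a collective over the default process group: every rank in the job must call it with the same arguments, even ranks that will not be members of the new group. A minimal sketch of the supported pattern (the function name here is made up):

```python
import torch.distributed as dist


def make_subgroup(subgroup_ranks):
    # All ranks call new_group, not just the ranks in subgroup_ranks.
    group = dist.new_group(ranks=subgroup_ranks)
    if dist.get_rank() in subgroup_ranks:
        return group
    # Non-member ranks get a placeholder back and should not use the group.
    return None


# e.g., ranks 0-3 form a subgroup, but every rank still enters new_group:
# subgroup = make_subgroup(list(range(4)))
```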