torchtitan
run sdpa with dtensor
Stack from ghstack (oldest at bottom):
- -> #180
- #285
- #161
- #172
This PR gets rid of the manual adjustment of the number of heads in attention layers, by using DTensor outputs of `wq`, `wk`, `wv`, so that SDPA is aware of the distributedness.
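For context, a minimal sketch of what the tensor parallel plan could look like when the projection outputs are kept as DTensors; `tp_mesh`, the module paths, and `apply_tp_to_block` are illustrative assumptions here, not the exact torchtitan plan:

```python
# A minimal sketch, not the actual torchtitan parallelize plan. The key idea is
# use_local_output=False on wq/wk/wv, which keeps their outputs as DTensors so
# SDPA sees the sharded head dimension and no manual n_heads adjustment is needed.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def apply_tp_to_block(block: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    """Apply a tensor-parallel plan to one transformer block (sketch)."""
    layer_plan = {
        # keep the projection outputs as DTensors instead of local shards
        "attention.wq": ColwiseParallel(use_local_output=False),
        "attention.wk": ColwiseParallel(use_local_output=False),
        "attention.wv": ColwiseParallel(use_local_output=False),
        "attention.wo": RowwiseParallel(),
    }
    return parallelize_module(block, tp_mesh, layer_plan)
```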
Just curious, is this gonna land soon, or does it have some risk or unfinished business?
Also, looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems changed (`attention.wo` and `attention_norm`).
> Just curious, is this gonna land soon, or does it have some risk or unfinished business?
> Also, looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems changed (`attention.wo` and `attention_norm`).
It hasn't been landed because there is a very strange bug (#267) associated with (but seemingly not caused by) multiplication using DTensor. It would be triggered in the rotary embedding computation if this PR lands. I will work on the bug soon, since it will also benefit PP (IIUC). @wconstab
> It would be triggered in the rotary embedding computation if this PR lands.
Oh, is this related to dispatching for complex numbers, by any chance?
> Oh, is this related to dispatching for complex numbers, by any chance?
@wconstab Possibly; we don't know. The `aten.mul` op returns bad results when the inputs are raw `torch.Tensor`s (desugared from DTensor), and the bug is only present in the backward pass. Do you know who I should ask for help?
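For reference, the kind of plain-tensor sanity check I would start from is sketched below; the shapes, tolerance, and the real-arithmetic reference are illustrative assumptions, not the actual repro in #267:

```python
# A hedged sketch of a plain-tensor sanity check for the rotary-embedding
# multiplication: compare the backward of the complex-mul formulation against an
# equivalent real-arithmetic rotation. Shapes and tolerances are illustrative.
import torch


def rotary_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # complex-multiplication form, as in the Llama rotary embedding
    x_ = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_ * freqs_cis).flatten(-2)


def rotary_real(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # the same rotation written in real arithmetic, used as a reference
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    a, b = pairs[..., 0], pairs[..., 1]
    return torch.stack((a * cos - b * sin, a * sin + b * cos), dim=-1).flatten(-2)


def grads_match(bs=2, seqlen=8, n_heads=4, head_dim=16, seed=0) -> bool:
    torch.manual_seed(seed)
    x1 = torch.randn(bs, seqlen, n_heads, head_dim, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)
    angle = torch.rand(seqlen, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(angle), angle)[None, :, None, :]

    rotary_complex(x1, freqs_cis).sum().backward()
    rotary_real(x2, freqs_cis.real, freqs_cis.imag).sum().backward()
    return torch.allclose(x1.grad, x2.grad, atol=1e-5)


if __name__ == "__main__":
    print(grads_match())  # expected True on the plain-tensor path
```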