
Run SDPA with DTensor

Open · tianyu-l opened this pull request on Mar 30 '24 · 4 comments

Stack from ghstack (oldest at bottom):

  • -> #180
  • #285
  • #161
  • #172

This PR gets rid of the manual adjustment of the number of heads in attention layers by using DTensor outputs of wq, wk, wv, so that SDPA is aware of the sharding.
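
Roughly, the idea looks like the following sketch (illustrative only, not the actual torchtitan code; the module, dimensions, and 8-way TP mesh are assumptions, and it would need to run under torchrun with a process group initialized):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class Attention(nn.Module):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bs, seqlen, _ = x.shape
        # No per-rank n_heads adjustment: q/k/v stay DTensors sharded on the
        # head dimension, so the reshape uses the global n_heads.
        q = self.wq(x).view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.wo(out.transpose(1, 2).reshape(bs, seqlen, -1))


tp_mesh = init_device_mesh("cuda", (8,))  # assumed 8-way tensor parallelism
attn = Attention(dim=4096, n_heads=32)

# use_local_output=False keeps the wq/wk/wv outputs as DTensors instead of
# converting them back to plain local torch.Tensors, so SDPA and the reshapes
# above see the sharding.
attn = parallelize_module(
    attn,
    tp_mesh,
    {
        "wq": ColwiseParallel(use_local_output=False),
        "wk": ColwiseParallel(use_local_output=False),
        "wv": ColwiseParallel(use_local_output=False),
        "wo": RowwiseParallel(),
    },
)
```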

tianyu-l, Mar 30 '24

Just curious, is this going to land soon, or does it have some risk or unfinished business?

Also, it looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems to have changed (attention.wo and attention_norm).

wconstab, Apr 30 '24

> Just curious, is this going to land soon, or does it have some risk or unfinished business?
>
> Also, it looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems to have changed (attention.wo and attention_norm).

It hasn't been landed because there is a very strange bug (#267) associated with (but seemingly not caused by) multiplication using DTensor. It would be triggered in the rotary embedding computation if this PR lands. I will work on the bug soon, since it will also benefit PP (if I understand correctly). @wconstab

tianyu-l, Apr 30 '24

> It would be triggered in the rotary embedding computation if this PR lands.

Oh, is this related to dispatching for complex numbers, by any chance?

wconstab, Apr 30 '24

> Oh, is this related to dispatching for complex numbers, by any chance?

@wconstab Possibly; we don't know. The aten.mul op returns bad results when its inputs are raw torch.Tensors (desugared from DTensors), and the bug is only present in the backward pass. Do you know whom I should ask for help?
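
For context, the multiplication in question is the Llama-style rotary embedding, which views q/k as complex numbers and multiplies them by precomputed complex rotations, roughly as in the sketch below (illustrative; freqs_cis is assumed to already be broadcastable to the complex-viewed q/k shapes, and this snippet is not confirmed to reproduce #267):

```python
import torch


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Reinterpret the last dim as (pairs, 2) and view as complex, then multiply
    # by the complex rotations; this is the complex aten.mul whose backward
    # appears to misbehave when DTensor is involved.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```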

tianyu-l, Apr 30 '24