numerical issue when running SDPA with DTensor
The issue comes from the backward computation of `aten.mul` of two complex numbers coming from DTensors: the result is `b + ai` when it should be `a + bi`. Not sure why it happens -- by the time the aten op runs, the input tensors have been de-sugared and should have nothing to do with DTensor.
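To make the swap concrete: with `L = sum(view_as_real(view_as_complex(x) * c))` and `c = p + qi`, the analytic gradient w.r.t. the real pair `(a, b)` packed in `x` is `(p + q, p - q)`, i.e. `conj(c) * (1 + 1j)` written out as a real pair; per the description above, the DTensor path returns those two components swapped. A minimal single-process sanity check on plain tensors (values and shapes here are chosen arbitrarily to match the repro below, they are not from the original report):

```python
import torch

# x packs (a, b) as the real/imag pair of one complex number; c = p + q*i.
x = torch.tensor([[[1.0, 2.0]]], requires_grad=True)  # shape (1, 1, 2)
c = torch.tensor([[3.0 + 4.0j]])                      # shape (1, 1), complex64

y = torch.view_as_real(torch.view_as_complex(x) * c)
y.sum().backward()

# For L = sum(real + imag), the analytic gradient is (p + q, p - q),
# i.e. conj(c) * (1 + 1j) viewed as a real pair.
p, q = c.real.item(), c.imag.item()
expected = torch.tensor([[[p + q, p - q]]])
torch.testing.assert_close(x.grad, expected)  # passes on plain tensors
# Per the bug description, the DTensor path instead returns the swapped pair (p - q, p + q).
```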
To reproduce, add the following test to pytorch/test/distributed/tensor/parallel/test_tp_examples.py:
```python
@with_comms
def test_apply_rotary_embedding(self):
    device_mesh = self.build_device_mesh()

    def apply_rotary_emb(xq, freqs_cis):
        xq_ = torch.view_as_complex(xq)
        xq_out = torch.view_as_real(xq_ * freqs_cis)
        return xq_out

    with CommDebugMode():
        # xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
        # freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
        # xq_out = apply_rotary_emb(xq, freqs_cis)
        # xq_out.sum().backward()

        xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
        freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
        xq_dt = distribute_tensor(xq, device_mesh, (Replicate(),))
        freqs_cis_dt = distribute_tensor(freqs_cis, device_mesh, (Replicate(),))
        xq_out_dt = apply_rotary_emb(xq_dt, freqs_cis_dt)
        xq_out_dt.sum().backward()
```
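One way to surface the mismatch from the repro above is to also run the plain-tensor path on the same inputs and compare gradients directly. A rough sketch, continuing after the `with` block; this assumes `distribute_tensor` returns a new leaf DTensor (so the gradient accumulates on `xq_dt` rather than on `xq`) and uses `full_tensor()` to materialize it. Neither assertion is part of the original repro:

```python
# Plain-tensor reference run on the same inputs (mirrors the commented-out lines above).
xq_ref = xq.detach().clone().requires_grad_(True)
xq_out_ref = apply_rotary_emb(xq_ref, freqs_cis)
xq_out_ref.sum().backward()

# Assumption: the DTensor gradient lands on xq_dt (the new leaf), not on xq.
grad_dt = xq_dt.grad.full_tensor()

# Expected to fail while the bug is present: the real/imag pair comes back swapped.
torch.testing.assert_close(grad_dt, xq_ref.grad)
```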
A solution is proposed in https://github.com/pytorch/pytorch/issues/130646
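For completeness, a possible interim workaround (this is not the fix proposed in the linked issue) is to drop to local tensors around the complex multiplication so that `aten.mul` never runs on DTensor inputs. The sketch below assumes both operands are fully replicated, as in the repro above, and that `DTensor` is importable from `torch.distributed._tensor`; `to_local()`/`from_local()` are used so autograd still flows through:

```python
import torch
from torch.distributed._tensor import DTensor

def apply_rotary_emb_local(xq_dt, freqs_cis_dt):
    # Valid only when both operands are Replicate()-placed: every rank holds the
    # full tensors, so the local complex multiplication equals the global one.
    xq_local = xq_dt.to_local()
    freqs_local = freqs_cis_dt.to_local()
    out_local = torch.view_as_real(torch.view_as_complex(xq_local) * freqs_local)
    # Wrap the local result back into a replicated DTensor; gradients flow
    # through to_local()/from_local().
    return DTensor.from_local(out_local, xq_dt.device_mesh, xq_dt.placements)
```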