Ke Wen


Also confirmed that a line like this `z = torch.zeros_like(y, dtype=y.dtype)` would burn the dtype into the kwargs:

```
# in forward, code: z = torch.zeros_like(y, dtype=y.dtype)
zeros_like: "f32[2, 4, 3]"...
```
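For reference, a minimal repro sketch (the module name and input shapes here are made up) that shows the explicitly passed dtype being recorded as a constant on the `zeros_like` node:

```python
# Minimal sketch (hypothetical module/shapes): passing dtype explicitly makes
# torch.export record it as a literal kwarg on the zeros_like call.
import torch

class M(torch.nn.Module):
    def forward(self, y):
        return torch.zeros_like(y, dtype=y.dtype)

ep = torch.export.export(M(), (torch.randn(2, 4, 3),))
print(ep.graph)  # the zeros_like call carries dtype=torch.float32 baked in
```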

The doc of `torch.zeros_like` says:

```
torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)
```

dtype ([torch.dtype](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype), optional) – the desired data type of returned Tensor. Default: if `None`, defaults to...
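In other words, omitting `dtype` already gives the input's dtype, so passing it explicitly is redundant and only bakes a constant into the trace. A quick check:

```python
import torch

y = torch.randn(2, 4, 3, dtype=torch.float16)
# Both calls return float16 zeros; the explicit dtype adds nothing.
assert torch.zeros_like(y).dtype == torch.zeros_like(y, dtype=y.dtype).dtype == torch.float16
```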

Exporting the llama model and printing the stack shows me that the `zeros_like` comes from `scaled_dot_product_attention`:

```
# File: /data/users/kw2501/torchtitan/torchtitan/models/llama/model.py:203 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
...
```

More specifically:

- In `pytorch/aten/src/ATen/native/transformers/attention.cpp`: ![Screenshot 2024-05-10 at 5 29 10 PM](https://github.com/pytorch/PiPPy/assets/6676466/346fb99d-f91b-4cf2-8e5f-45fb9cb71736)
- Then in `convert_boolean_attn_mask`: ![Screenshot 2024-05-10 at 5 30 26 PM](https://github.com/pytorch/PiPPy/assets/6676466/a64172b1-f5c4-4178-b999-0d42ba1dffc3)

https://github.com/pytorch/pytorch/blob/a5c93a6899c657832944cd2eeb5069449e28dbea/aten/src/ATen/native/transformers/attention.cpp#L523
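For anyone who can't view the screenshots, here is a rough Python-level sketch (not the actual C++ code) of what `convert_boolean_attn_mask` does; the key point is that the query's dtype is passed explicitly to `zeros_like`, which is why it shows up as a burned-in constant in the trace:

```python
import torch

def convert_boolean_attn_mask_sketch(attn_mask: torch.Tensor, query_dtype: torch.dtype):
    # Boolean masks are converted into an additive float mask in the query's dtype:
    # 0 where attention is allowed, -inf where it is masked out.
    if attn_mask.dtype == torch.bool:
        new_mask = torch.zeros_like(attn_mask, dtype=query_dtype)  # dtype passed explicitly
        return new_mask.masked_fill(attn_mask.logical_not(), float("-inf"))
    return attn_mask
```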

CC @zhxchen17 @tugsbayasgalan: let me know if you are preparing an improvement to unburn the dtype as well (in addition to the device). We would be thrilled to try that out. CC:...

Meanwhile, @tugsbayasgalan mentioned that pre-dispatch mode is now the default mode of torch.export. That can also work around this issue, since the new mode avoids tracing into SDPA.
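If I understand that default correctly, a sketch of what this looks like (module and shapes below are made up) would be:

```python
# Sketch assuming a recent PyTorch where torch.export.export traces in
# pre-dispatch mode by default: scaled_dot_product_attention should stay as a
# single op in the graph instead of decomposing into zeros_like/masked_fill/etc.
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

x = torch.randn(1, 8, 16, 64)
ep = torch.export.export(Attn(), (x, x, x))
print(ep.graph)  # expect one scaled_dot_product_attention node, no zeros_like
```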

Sorry, I cannot reproduce the hang on my system (8x A100).

```
$ torchrun --standalone --nproc-per-node 4 pippy_llama.py
Downloading shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00
```

If we put in special logic for `numel() == 1`, what about the case of `numel() < nranks`?