Ke Wen
Also confirmed that a line like `z = torch.zeros_like(y, dtype=y.dtype)` would burn the dtype into the kwargs:

```
# in forward, code: z = torch.zeros_like(y, dtype=y.dtype)
zeros_like: "f32[2, 4, 3]"...
```
The doc of `torch.zeros_like` says:

```
torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)
```

dtype ([torch.dtype](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype), optional) – the desired data type of returned Tensor. Default: if `None`, defaults to...
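In other words, passing `dtype=y.dtype` explicitly is redundant at eager time but not harmless under export. A minimal sketch of the effect (pure Python; the `Tracer` class here is a hypothetical stand-in for torch.export's graph recorder, not the real API):

```python
# Hypothetical stand-in for an export tracer: it records the kwargs of each
# call. When the caller omits dtype (leaving the default None), no concrete
# dtype lands in the recorded node; when the caller writes dtype=y.dtype,
# the concrete dtype value is "burned" into the recorded kwargs.
class Tracer:
    def __init__(self):
        self.recorded_kwargs = []

    def zeros_like(self, input_dtype, dtype=None):
        # Record only explicitly supplied kwargs, as a tracer would.
        self.recorded_kwargs.append({} if dtype is None else {"dtype": dtype})
        return dtype if dtype is not None else input_dtype

tracer = Tracer()
tracer.zeros_like("f32")                # z = torch.zeros_like(y)
tracer.zeros_like("f32", dtype="f32")   # z = torch.zeros_like(y, dtype=y.dtype)
print(tracer.recorded_kwargs)           # [{}, {'dtype': 'f32'}]
```

Both calls return the same result, but only the second one pins the dtype into the graph.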
Exporting the llama model and printing the stack shows me that the `zeros_like` comes from `scaled_dot_product_attention`:

```
# File: /data/users/kw2501/torchtitan/torchtitan/models/llama/model.py:203 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)...
```
More specifically:
- In `pytorch/aten/src/ATen/native/transformers/attention.cpp`:
  - Then in `convert_boolean_attn_mask`: https://github.com/pytorch/pytorch/blob/a5c93a6899c657832944cd2eeb5069449e28dbea/aten/src/ATen/native/transformers/attention.cpp#L523
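For readers who don't want to chase the link: the gist of that path (a rough Python paraphrase of the linked C++ code, which really operates on `at::Tensor`; names here are simplified) is that a boolean attention mask gets converted into an additive float mask, and the zeros-like buffer is created with the *query's* dtype, which is how a concrete dtype ends up in the trace:

```python
import math

# Rough, hand-written paraphrase of convert_boolean_attn_mask for a 1-D
# boolean mask (illustration only, not the actual attention.cpp code).
def convert_boolean_attn_mask(attn_mask, query_dtype):
    # C++: at::zeros_like(attn_mask, query.dtype())
    # query_dtype is where the concrete dtype would get baked into the graph;
    # 0.0 means "attention allowed" in the additive mask.
    new_mask = [0.0 for _ in attn_mask]
    # C++: new_attn_mask.masked_fill_(attn_mask.logical_not(), -inf)
    # -inf means "attention blocked" once added to the attention scores.
    return [(-math.inf if not keep else v) for keep, v in zip(attn_mask, new_mask)]

print(convert_boolean_attn_mask([True, False, True], "f32"))
# [0.0, -inf, 0.0]
```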
CC: @zhxchen17 @tugsbayasgalan, let me know if you are preparing an improvement to unburn the dtype as well (in addition to the device). We would be thrilled to try that out. CC:...
Meanwhile, @tugsbayasgalan mentioned that pre-dispatch mode is now the default mode of torch.export. That can also work around this issue, since this mode avoids tracing into SDPA.
Sorry, I cannot reproduce the hang on my system (8x A100):

```
$ torchrun --standalone --nproc-per-node 4 pippy_llama.py
Downloading shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00
```
@pytorchbot merge
If we put in special logic for `numel() == 1`, what about the case of `numel() < nranks`?
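To illustrate the concern with a toy example (a naive even-split sketch, not the library's actual sharding code; note the real `torch.chunk` drops empty trailing chunks rather than returning them):

```python
# Hypothetical illustration of the review question: splitting n elements
# across nranks when numel() < nranks leaves some ranks with empty shards,
# a case that a numel() == 1 special path would not cover.
def chunk(elems, nranks):
    # Naive even split with ceil-sized chunks; trailing chunks may be empty.
    size = -(-len(elems) // nranks)  # ceil division
    return [elems[i * size:(i + 1) * size] for i in range(nranks)]

print(chunk([10, 20, 30], 4))
# [[10], [20], [30], []]
```

Rank 3 gets nothing, so any special-casing keyed only on `numel() == 1` would miss this.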