Context parallel on Turing GPUs?
As the title suggests, is torchtitan CP supported on Turing GPUs?
I got the error `RuntimeError: No available kernel. Aborting execution.` when running the default `run_train.sh` script with CP changed to 2.
I know Turing GPUs don't have FlashAttention support yet, but I read the torchtitan CP blog post here, and it seems like the memory-efficient attention backend should work with CP?
If that is the case, could you share how to enable this backend in torchtitan? I tried wrapping this line with `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):`, but the error persists.
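For reference, this is roughly what I tried, reduced to a standalone repro outside of torchtitan (shapes are just illustrative; bf16 mirrors the default training dtype):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Dummy q/k/v tensors; on my Turing card this still raises
# "No available kernel" with bf16.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Restrict SDPA to the memory-efficient backend only.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```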
Thanks
Not sure if it helps, but have you tried keeping only this line and removing the other two? https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L159
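To spell that out (paraphrasing the linked code from memory, so the exact shape may differ): the idea is to pass only the memory-efficient backend to `sdpa_kernel`, dropping the flash and cuDNN entries, roughly like this:

```python
import contextlib
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical sketch of the edit; verify against the actual code at the
# link above before changing anything. Keep only the memory-efficient
# backend in the list handed to sdpa_kernel, removing
# SDPBackend.FLASH_ATTENTION and SDPBackend.CUDNN_ATTENTION.
with contextlib.ExitStack() as stack:
    stack.enter_context(sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]))
    # ... run the training step under this context ...
```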
Current CP only supports SDPA. This error comes from SDPA, indicating that it cannot find an available kernel. We only support the memory-efficient, flash, and cuDNN attention backends.
Thanks everyone for your timely responses!
I tested the SDPA backends locally, and it seems the memory-efficient attention backend only works with fp16 and fp32 but not bf16, leaving no kernel backend to choose. CP works fine after switching to fp32 training.
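This is roughly how I probed it (a minimal standalone check, not torchtitan code):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Try the memory-efficient backend with each dtype and report which ones
# raise "No available kernel" on this GPU.
for dtype in (torch.float16, torch.float32, torch.bfloat16):
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=dtype)
    try:
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        print(f"{dtype}: ok")
    except RuntimeError as e:
        print(f"{dtype}: {e}")
```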
Quick Q: Does the team have a plan to support fp16 training?
Small feedback: would it be possible to emit warning/debug information about why kernels are not available?
Hmm, I thought bf16 is supported for memory-efficient attention, but I could be wrong. cc @drisspg for help and user feedback
We do emit warning information from SDPA about why kernels can't be chosen; I'm not sure if the current CP implementation is surfacing it. If you want to look at the source code: https://github.com/pytorch/pytorch/blob/dfcd98e684123b0cb0a143d8718b0672c58ec268/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L721
It's a little hard to read, but we do support bf16 on sm80 or newer: https://github.com/pytorch/pytorch/blob/dfcd98e684123b0cb0a143d8718b0672c58ec268/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L733C1-L735C7
It does have float16 support on older hardware.
Thanks, I just realized that the warning is indeed emitted earlier; it's just not part of the error trace, so I didn't notice it.
I understand that the memory-efficient SDPA backend works with fp16. I was wondering whether torchtitan supports (or plans to support) fp16 with loss scaling? The config choice here seems to only allow bf16 and fp32. It would be really nice to have fp16 support for older GPUs.
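For context, by fp16 with loss scaling I mean the stock `torch.amp` pattern, roughly like this (a minimal standalone sketch with a dummy model, not how torchtitan structures its training loop):

```python
import torch

# Dummy model/optimizer just to illustrate the fp16 + GradScaler pattern.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler("cuda")  # requires a recent PyTorch; older versions use torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        loss = model(x).float().pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next step
```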