Why use two streams for context parallel
Hi, I see in https://github.com/NVIDIA/TransformerEngine/blob/29e8bfc99d803770ad82ae9351db63673bc34f69/transformer_engine/pytorch/attention.py#L624 that you use two CUDA streams to mitigate "wave quantization" in flash attention. Could you clarify what "wave quantization" means here? I assumed quantization referred to numerical precision, but flash attention already runs in fp16/bf16.
Two streams help overlap communication and computation: the second stream can start processing the next chunk of data as soon as it arrives, while the first stream is still working on the previous chunk.
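To see why that overlap matters, here is a back-of-the-envelope cost model (hypothetical per-chunk timings, not measurements from TransformerEngine): with one stream each chunk's transfer and computation serialize, while with two streams only the first transfer is exposed and thereafter the slower of the two dominates.

```python
def sequential_time(n_chunks, comm, compute):
    # One stream: each chunk's communication and computation serialize.
    return n_chunks * (comm + compute)

def overlapped_time(n_chunks, comm, compute):
    # Two streams: while chunk i is being computed, chunk i+1's data is
    # received, so after the first transfer only the slower phase matters.
    return comm + n_chunks * max(comm, compute)

# Hypothetical numbers: 8 chunks, 1 ms comm, 3 ms compute per chunk.
print(sequential_time(8, 1.0, 3.0))  # 32.0 ms
print(overlapped_time(8, 1.0, 3.0))  # 25.0 ms
```

When per-chunk compute is shorter than the transfer, the overlapped total degenerates to roughly `comm`-bound, and the extra stream buys little; this model ignores stream-management overhead, which is one plausible reason two streams can come out slower in practice.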
In fact, I repeatedly see longer runtimes when using two streams. Note that "wave quantization" refers to GPU scheduling, not numerical precision: thread blocks execute in waves across the SMs, and a partially filled last wave wastes throughput. It is defined here: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
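A small sketch of the wave-quantization arithmetic (the 108-blocks-per-wave figure is a hypothetical example, e.g. one block per SM on an A100-class GPU, not a value taken from this thread):

```python
import math

def waves(n_blocks, blocks_per_wave):
    # Thread blocks launch in waves of at most blocks_per_wave at a time.
    return math.ceil(n_blocks / blocks_per_wave)

def last_wave_utilization(n_blocks, blocks_per_wave):
    # Fraction of the final wave's slots that are actually occupied.
    rem = n_blocks % blocks_per_wave
    return 1.0 if rem == 0 else rem / blocks_per_wave

# Hypothetical GPU running 108 blocks per wave:
print(waves(108, 108), last_wave_utilization(108, 108))  # 1 1.0
print(waves(109, 108), last_wave_utilization(109, 108))  # 2 waves, last <1% full
```

One block over the wave size nearly doubles the kernel's runtime while the second wave sits almost entirely idle. A plausible reading of the two-stream design is that launching part of the work on a second stream lets its blocks fill those otherwise idle SMs in a partial wave, though whether that wins in practice depends on the actual kernel sizes.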