ring-flash-attention
Ring attention implementation with flash attention
In the backward function of ring-attn, rng_state does not use the value from the forward pass; instead, None is passed in directly. Does this indicate that ring-attn does not support dropout?
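For context, here is a toy sketch (plain PyTorch, not the flash_attn API) of why the saved RNG state matters when dropout_p > 0: the backward pass must replay exactly the same dropout mask that the forward pass used, which is only possible if the RNG state captured in forward is handed back to backward.

```
import torch

x = torch.randn(4, 4)

state = torch.get_rng_state()              # conceptually what rng_state captures
mask_fwd = (torch.rand_like(x) > 0.1)      # dropout mask used in forward

torch.set_rng_state(state)                 # replay: same state -> same mask
mask_bwd = (torch.rand_like(x) > 0.1)

print(torch.equal(mask_fwd, mask_bwd))     # True
# If backward received rng_state=None instead, it could not reproduce the
# forward's dropout pattern, so dropout_p > 0 would not be supported soundly.
```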
Normally we could use other long-context methods, such as DeepSpeed Ulysses, to avoid implementing this.
Hi~ @zhuzilin I'm trying to integrate [BPT](https://arxiv.org/abs/2305.19370) into ring flash attention: splitting q/k/v by chunk_size and running attention over smaller chunks locally. Following the forward and backward in ring_flash_attn.py, I implemented `blockwise_flash_attn_forward` and `blockwise_flash_attn_backward`. The forward currently matches in precision, but the backward has errors. What might be wrong with the backward implementation? Here is my implementation:

```
def blockwise_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_chunk_size: int,
    k_chunk_size: int,
    softmax_scale,
    dropout_p=0,
    causal=True,
    return_softmax=True,
):
    assert ...
```
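Not a diagnosis, but one common thing to check in a blockwise backward: the per-block dq/dk/dv returned by the flash-attention kernel are bf16/fp16, so summing them across chunks in that dtype loses precision. A minimal sketch of keeping the accumulators in fp32 (buffer names are illustrative):

```
import torch

def accumulate_block_grads(dq_acc, dk_acc, dv_acc, dq_blk, dk_blk, dv_blk):
    # dq_acc / dk_acc / dv_acc are fp32 buffers allocated once per backward;
    # each per-block gradient is upcast before being added, and the result is
    # cast back to the model dtype only after all chunks have been processed.
    dq_acc += dq_blk.float()
    dk_acc += dk_blk.float()
    dv_acc += dv_blk.float()
    return dq_acc, dk_acc, dv_acc
```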
Precision issue
There are some arithmetic errors with the current implementation. The reason is probably that flash attention returns bf16 values for each block, so we cannot accumulate the...
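For reference, a minimal sketch (not the repo's own helper) of combining two per-block outputs in fp32 using their log-sum-exp values; this is the kind of reduction where keeping only the bf16 block outputs loses precision. Assumed shapes: `out*` is `(..., seqlen, dim)`, `lse*` is `(..., seqlen)`.

```
import torch

def merge_block_outputs(out1, lse1, out2, lse2):
    # Combine two partial attention outputs computed over different KV blocks.
    # Upcasting to fp32 avoids losing the low bits of the bf16 block outputs
    # during the reduction.
    out1, out2 = out1.float(), out2.float()
    lse = torch.logaddexp(lse1, lse2)                 # combined log-sum-exp
    out = (
        torch.exp(lse1 - lse).unsqueeze(-1) * out1
        + torch.exp(lse2 - lse).unsqueeze(-1) * out2
    )
    return out, lse
```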
Thanks for sharing this excellent implementation of ring attention. Here are my test results on 2*A100 (with NVLink). Judging from the results, the memory usage of ring attention (`ring_flash_attn_qkvpacked_func`) seems...
I tried to measure the time spent waiting on the `reqs` returned by `batch_isend_irecv()`. Interestingly, this time seems to be independent of sequence length and negligible in total. Could be that on...
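A rough sketch of how such a measurement could be set up (the `send_rank` / `recv_rank` names are placeholders; an initialized NCCL process group is assumed). Note that with NCCL, `req.wait()` only synchronizes the current CUDA stream rather than blocking the host until the transfer finishes, so without an explicit `torch.cuda.synchronize()` the measured time can look independent of the message size.

```
import time
import torch
import torch.distributed as dist

def timed_ring_exchange(send_tensor, recv_tensor, send_rank, recv_rank):
    # Post the paired send/recv for one ring step and time how long we block
    # on the requests returned by batch_isend_irecv.
    ops = [
        dist.P2POp(dist.isend, send_tensor, send_rank),
        dist.P2POp(dist.irecv, recv_tensor, recv_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)

    start = time.perf_counter()
    for req in reqs:
        req.wait()                 # with NCCL this only syncs the stream
    torch.cuda.synchronize()       # force actual completion before stopping the clock
    return time.perf_counter() - start
```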
What is the minimum required flash-attention version?
Hey, loving the work on ring flash attention! I'm contacting you because our community, cuda-mode, is working on a CUDA/PyTorch version of ring attention, so feel free to join the...
Ring attention is essentially a distributed version of flash attention. In flash attention V2, the softmax denominator is maintained, but when updating `out` it seems only the running max is applied, not the denominator, in order to reduce computation, right? After Q has been computed against the full ring of KV, you just divide by the global softmax denominator at the end. So can the distributed FA in the author's ring attention implementation be understood as the V1 version of FA?
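For reference, a minimal single-row numerical sketch of the update the question describes: keep a running max and an unnormalized accumulator across KV blocks, track the denominator separately, and divide by the global denominator only once at the end (the four chunks here stand in for four ring steps). This is only an illustration of the bookkeeping, not the repo's kernel code.

```
import torch

torch.manual_seed(0)
q = torch.randn(1, 8)                      # one query row
k = torch.randn(16, 8)
v = torch.randn(16, 8)

# Reference: full softmax attention for this row.
ref = torch.softmax(q @ k.T, dim=-1) @ v

m = torch.tensor(float("-inf"))            # running max
d = torch.tensor(0.0)                      # running denominator
acc = torch.zeros(1, 8)                    # unnormalized output accumulator

for kb, vb in zip(k.chunk(4), v.chunk(4)): # 4 KV blocks, e.g. 4 ring steps
    s = q @ kb.T
    m_new = torch.maximum(m, s.max())
    scale = torch.exp(m - m_new)           # rescale previous partial sums
    d = d * scale + torch.exp(s - m_new).sum()
    acc = acc * scale + torch.exp(s - m_new) @ vb
    m = m_new

out = acc / d                              # single division at the very end
print(torch.allclose(out, ref, atol=1e-5)) # True
```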