jiaopenglong


Hi~ @zhuzilin I'm trying to integrate [BPT](https://arxiv.org/abs/2305.19370) into ring flash attention: qkv is split by chunk_size, and attention over the smaller chunks is computed locally. Following the forward and backward in ring_flash_attn.py, I implemented `blockwise_flash_attn_forward` and `blockwise_flash_attn_backward`. The forward currently matches in precision, but the backward has an error. I'd like to ask: what problems might there be in my backward implementation? Here is my implementation:

```
def blockwise_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_chunk_size: int,
    k_chunk_size: int,
    softmax_scale,
    dropout_p=0,
    causal=True,
    return_softmax=True,
):
    assert...
```
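For context, here is a minimal pure-PyTorch sketch of the chunked forward I mean (my own simplification: non-causal, no dropout, plain softmax instead of the flash-attn kernels; the names are illustrative). Each query chunk keeps a running log-sum-exp so per-block outputs merge exactly, and the backward has to mirror the same rescaling:

```
import torch

def chunked_attention_reference(q, k, v, q_chunk_size, k_chunk_size, softmax_scale):
    # q, k, v: (batch, seqlen, nheads, headdim); non-causal, no dropout.
    out_chunks = []
    for qs in range(0, q.shape[1], q_chunk_size):
        qc = q[:, qs:qs + q_chunk_size]
        acc = torch.zeros_like(qc)                            # merged output so far
        lse = qc.new_full((*qc.shape[:3], 1), float("-inf"))  # running log-sum-exp per row
        for ks in range(0, k.shape[1], k_chunk_size):
            kc = k[:, ks:ks + k_chunk_size]
            vc = v[:, ks:ks + k_chunk_size]
            scores = torch.einsum("bqhd,bkhd->bqhk", qc, kc) * softmax_scale
            block_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
            block_out = torch.einsum("bqhk,bkhd->bqhd", torch.softmax(scores, dim=-1), vc)
            new_lse = torch.logaddexp(lse, block_lse)
            # rescale the previous accumulator and the new block into the new basis
            acc = acc * torch.exp(lse - new_lse) + block_out * torch.exp(block_lse - new_lse)
            lse = new_lse
        out_chunks.append(acc)
    return torch.cat(out_chunks, dim=1)
```

Comparing this reference against plain full attention (both the outputs and the q/k/v gradients from `.backward()`) reproduces the merge math that the chunked forward and backward both rely on; in particular, the backward recomputes each block's softmax with the final row-wise lse, not the per-block one.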