ring-flash-attention
verify causal masking
Hi @zhuzilin, follow up from https://github.com/zhuzilin/ring-flash-attention/issues/15
I just wanted to verify the causal masking. I simply use a loop because I don't have multiple GPUs, but it should behave the same. When I do causal attention using your ring logic, the argmax accuracy is very low, but when I do non-causal, accuracy is almost a perfect 100%. You can check the notebook at https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/flash-ring-attention-causal.ipynb
From what I understand, let's say I have 2 devices and a seqlen of 100k, partitioned into 2 chunks of 100k // 2 = 50k each, so each device holds a 50k slice:

device 0: 50k q0 k0 v0
device 1: 50k q1 k1 v1

So the blockwise attention calculation is:

device 0: q0k0v0 + q0k1v1
device 1: q1k0v0 + q1k1v1
(+) denotes the blockwise attention combination.
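The (+) combination above can be sketched with the standard log-sum-exp merge that flash/ring attention use to combine partial results. A minimal NumPy sketch (my own illustration, ignoring scaling, batching, and heads, not the library's actual code):

```python
import numpy as np

def attn_block(q, k, v):
    """Softmax attention over one KV block; also returns per-row log-sum-exp."""
    s = q @ k.T
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)
    return (p @ v) / l, m + np.log(l)

def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention outputs using their log-sum-exps."""
    lse = np.logaddexp(lse_a, lse_b)
    return np.exp(lse_a - lse) * out_a + np.exp(lse_b - lse) * out_b

rng = np.random.default_rng(0)
q0 = rng.normal(size=(4, 8))                  # device-0 queries
k = rng.normal(size=(8, 8))
v = rng.normal(size=(8, 8))
k0, v0, k1, v1 = k[:4], v[:4], k[4:], v[4:]   # KV split across 2 "devices"

ref, _ = attn_block(q0, k, v)                 # attention over the full KV
o_a, l_a = attn_block(q0, k0, v0)             # q0k0v0
o_b, l_b = attn_block(q0, k1, v1)             # q0k1v1
out = merge(o_a, l_a, o_b, l_b)               # q0k0v0 (+) q0k1v1

assert np.allclose(out, ref)
```

Without the log-sum-exp correction, simply adding or averaging the two partial outputs would be wrong, because each block's softmax is normalized only over its own keys.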
For the causal case, an attention mask is necessary. The original attention mask is [100k, 100k], and we must chunk it properly along the rows into mask0 = [50k, 100k] and mask1 = [50k, 100k], so the blockwise attention calculation becomes:

device 0: (q0k0 * mask0[:, 0:50k])v0 + (q0k1 * mask0[:, 50k:100k])v1
device 1: (q1k0 * mask1[:, 0:50k])v0 + (q1k1 * mask1[:, 50k:100k])v1
You can see this slicing from original https://github.com/forhaoliu/ringattention/blob/main/ringattention/ringattention_pallas_tpu.py#L61
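To sanity-check my understanding of the row-wise mask chunking, here is a small NumPy sketch (again my own illustration, with a toy size of 8 split across 2 "devices"). Note that for device 0 the remote block q0k1 is fully masked under causal attention, so only its local block contributes:

```python
import numpy as np

def attn_block(q, k, v, mask):
    """Attention over one KV block with a boolean visibility mask; returns output and LSE."""
    s = q @ k.T + np.where(mask, 0.0, -np.inf)
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)
    return (p @ v) / l, m + np.log(l)

def merge(out_a, lse_a, out_b, lse_b):
    lse = np.logaddexp(lse_a, lse_b)
    return np.exp(lse_a - lse) * out_a + np.exp(lse_b - lse) * out_b

n, c, d = 8, 4, 8                                # seqlen, chunk size, head dim
rng = np.random.default_rng(0)
q = rng.normal(size=(n, d))
k = rng.normal(size=(n, d))
v = rng.normal(size=(n, d))

causal = np.tril(np.ones((n, n), dtype=bool))    # full [n, n] causal mask
mask0, mask1 = causal[:c], causal[c:]            # row chunks: [c, n] each

ref, _ = attn_block(q, k, v, causal)             # full causal attention

# device 0: remote block is fully masked (mask0[:, c:] is all False), skip it
out0, _ = attn_block(q[:c], k[:c], v[:c], mask0[:, :c])

# device 1: local block uses mask1[:, c:], remote block uses mask1[:, :c]
o_a, l_a = attn_block(q[c:], k[:c], v[:c], mask1[:, :c])
o_b, l_b = attn_block(q[c:], k[c:], v[c:], mask1[:, c:])
out1 = merge(o_a, l_a, o_b, l_b)

assert np.allclose(np.vstack([out0, out1]), ref)
```

If the ring implementation applied the wrong mask slice to a remote block (e.g. the local causal triangle instead of the all-visible slice for device 1's remote block), this check would fail, which may be the kind of mismatch behind the low argmax accuracy.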
Correct me if I'm wrong here, thanks!