
verify causal masking

Open huseinzol05 opened this issue 4 months ago • 3 comments

Hi @zhuzilin, follow up from https://github.com/zhuzilin/ring-flash-attention/issues/15

I just wanted to verify the causal masking. I simply use a loop because I don't have multiple GPUs, but it should behave the same. When I apply causal masking using your ring logic, the argmax accuracy is super low, but when I do non-causal, the accuracy is almost a perfect 100%. You can check the notebook at https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/flash-ring-attention-causal.ipynb
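Roughly, the check in the notebook looks like this (a minimal sketch; `ring_attention_loop` is a hypothetical name for my loop-based emulation of the ring passes, and the shapes are just illustrative, not the actual notebook config):

```python
import torch
import torch.nn.functional as F

# illustrative shapes, not the actual notebook config
q = torch.randn(1, 8, 1024, 64)            # [batch, heads, seqlen, head_dim]
k, v = torch.randn_like(q), torch.randn_like(q)

# reference: plain full attention with a causal mask
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# hypothetical helper: emulates the ring passes with a Python loop on one device
out = ring_attention_loop(q, k, v, world_size=2, causal=True)

# argmax agreement over the head dimension as a cheap accuracy proxy
acc = (out.argmax(-1) == ref.argmax(-1)).float().mean().item()
print(f"argmax accuracy: {acc:.4f}")
```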

From what I understand, let's say I have 2 devices and a seqlen of 100k, partitioned into 2 chunks, 100k // 2 = 50k each, so:

device 0: q0, k0, v0 (50k seq len each)
device 1: q1, k1, v1 (50k seq len each)

So the blockwise attention calculation is:

device 0: q0k0v0 + q0k1v1
device 1: q1k0v0 + q1k1v1

where (+) denotes blockwise attention accumulation (merging the partial softmax results).
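A minimal sketch of what I mean by that (+) accumulation, using the standard log-sum-exp merge (my own notation, not this repo's API):

```python
import torch

def attn_block(q, k, v, mask=None):
    # attention of one q chunk against one k/v chunk, returning the
    # locally-normalised output and its log-sum-exp
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5      # [sq, sk]
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)        # [sq, 1]
    out = torch.exp(scores - lse) @ v                          # [sq, d]
    return out, lse

def merge(out_a, lse_a, out_b, lse_b):
    # the "+" above: combine two blockwise results into the attention
    # output over the union of their keys
    lse = torch.logaddexp(lse_a, lse_b)
    out = torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b
    return out, lse

# device 0: q0 against (k0, v0) then (k1, v1); device 1 analogous with q1
# out0, lse0 = attn_block(q0, k0, v0)
# out0, lse0 = merge(out0, lse0, *attn_block(q0, k1, v1))
```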

For the causal case an attention mask is necessary. The original attention mask is [100k, 100k], and we must chunk it properly into mask0 = [50k, 100k] (the rows for q0) and mask1 = [50k, 100k] (the rows for q1), so the blockwise attention calculation becomes:

device 0: (q0k0 * mask0[:, 0:50k])v0 + (q0k1 * mask0[:, 50k:100k])v1
device 1: (q1k0 * mask1[:, 0:50k])v0 + (q1k1 * mask1[:, 50k:100k])v1

You can see this slicing in the original implementation: https://github.com/forhaoliu/ringattention/blob/main/ringattention/ringattention_pallas_tpu.py#L61
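For reference, here is a tiny sketch of how I picture that mask chunking on 2 "devices" (boolean mask, True = attend; toy sizes instead of 100k, and the fully-masked q0/k1 block is just skipped, which I believe is what the causal path should effectively do):

```python
import torch

seqlen, world_size = 8, 2                  # toy sizes instead of 100k
chunk = seqlen // world_size
full_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))   # [seqlen, seqlen]

for rank in range(world_size):             # which device owns this q chunk
    mask_r = full_mask[rank * chunk:(rank + 1) * chunk]                # [chunk, seqlen]
    for src in range(world_size):          # which k/v chunk is visiting this ring step
        block_mask = mask_r[:, src * chunk:(src + 1) * chunk]          # [chunk, chunk]
        if not block_mask.any():
            # e.g. device 0 against k1/v1: fully masked under causal, skip it
            continue
        # ... attn_block(q[rank], k[src], v[src], mask=block_mask), then merge as above
        print(rank, src, "lower-triangular" if src == rank else "full")
```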

Correct me if I'm wrong here, thanks!

huseinzol05 · Oct 06 '24 17:10