Fewer matrix multiplications, same results, should we consider adopting it?

Open · pandaupc opened this issue 1 year ago · 23 comments

In the inner loop of FlashAttention-2, every update of O requires a multiplication with V. I adopted a different approach: for each Q block, I first compute the complete attention scores and only then multiply them by V, so each Q block is multiplied with V only once. I compared the results of my implementation with those of standard attention and FlashAttention and found them consistent.
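A minimal NumPy sketch of the idea as described, for one query block with no causal mask (CPU NumPy instead of CUDA PyTorch; the function names are hypothetical, not the author's): the running max and denominator are updated per key block, earlier unnormalized blocks are rescaled to the new max, and V is multiplied only once at the end.

```python
import numpy as np

def attention_reference(q, k, v):
    """Plain softmax attention (no causal mask), used only for comparison."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def one_v_matmul_block(q_blk, k, v, kv_block):
    """Online softmax over key blocks; a single multiplication by V at the end."""
    scale = 1.0 / np.sqrt(q_blk.shape[-1])
    m = np.full((q_blk.shape[0], 1), -np.inf)   # running row max
    l = np.zeros((q_blk.shape[0], 1))           # running denominator
    f_blocks = []                               # unnormalized weights, one per key block
    for start in range(0, k.shape[0], kv_block):
        s = q_blk @ k[start:start + kv_block].T * scale
        m_local = s.max(axis=-1, keepdims=True)
        p = np.exp(s - m_local)
        m_new = np.maximum(m, m_local)
        # rescale every earlier block to the new max
        f_blocks = [f * np.exp(m - m_new) for f in f_blocks]
        f_blocks.append(p * np.exp(m_local - m_new))
        l = l * np.exp(m - m_new) + p.sum(axis=-1, keepdims=True) * np.exp(m_local - m_new)
        m = m_new
    f = np.concatenate(f_blocks, axis=-1)       # shape (q_block, seq_len)
    return (f / l) @ v                          # the single multiplication by V

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
out = np.concatenate([one_v_matmul_block(q[i:i + 4], k, v, kv_block=4) for i in range(0, 8, 4)])
print(np.allclose(out, attention_reference(q, k, v)))  # True
```

Note that the intermediate `f` spans the whole key length for the query block, which is what the rest of the thread discusses.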

pandaupc, May 08 '24 10:05

I don't quite understand your algorithm, can you add some pseudo code?

tridao, May 08 '24 22:05

For the mathematical derivation I used induction, which simplified the logic, and the implementation follows the same recurrence. It is written in PyTorch with some simplifications: the Q block size equals the KV block size, and only causal attention is implemented.

The mathematical derivation is as follows: [three images of the derivation]
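The derivation images are not reproduced here. As a hedged reconstruction from the code that follows (not the author's own write-up), the recurrence for query block i and key block j ≤ i, with S_ij = scale · Q_i K_j^T and the EPSILON term dropped, is:

```latex
\begin{aligned}
\tilde m_{ij} &= \operatorname{rowmax}(S_{ij}), \quad
\tilde P_{ij} = \exp(S_{ij} - \tilde m_{ij}), \quad
\tilde \ell_{ij} = \operatorname{rowsum}(\tilde P_{ij}) \\
m_i^{(j)} &= \max\bigl(m_i^{(j-1)}, \tilde m_{ij}\bigr), \qquad m_i^{(-1)} = -\infty \\
\ell_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}}\, \ell_i^{(j-1)} + e^{\tilde m_{ij} - m_i^{(j)}}\, \tilde\ell_{ij} \\
F_{ik}^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}}\, F_{ik}^{(j-1)} \quad (k < j), \qquad
F_{ij}^{(j)} = e^{\tilde m_{ij} - m_i^{(j)}}\, \tilde P_{ij} \\
O_i &= \frac{F_i^{(i)}}{\ell_i^{(i)}}\, V, \qquad
F_i^{(i)} = \bigl[F_{i0}^{(i)}, \dots, F_{ii}^{(i)}, 0, \dots, 0\bigr]
\end{aligned}
```

Induction step: if F_ik^(j-1) = exp(S_ik - m_i^(j-1)) for all k < j, the update gives F_ik^(j) = exp(S_ik - m_i^(j)) for all k ≤ j, and ℓ_i^(j) is the row sum of those same terms, so F_i^(i) / ℓ_i^(i) is exactly the causal softmax of row block i.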

The PyTorch code is as follows:

import numpy as np
import torch

# BLOCK_SIZE, NEG_INF, EPSILON, matrix_with_mask and matrix_without_mask
# come from the author's surrounding code and are not shown here.

def new_flash_attention_causal_forward(Q, K, V):
    # Q: batch_size * num_attention * seq_len * hidden_size
    # l: batch_size * num_attention * seq_len
    B, N, S, H = Q.shape
    # numerator: unnormalized attention weights, causal mask pre-applied
    f = torch.ones(size=(B, N, S, S)).to(device='cuda')
    causal_mask = torch.tril(torch.ones(S, S), 0).to(device='cuda')
    f = torch.where(causal_mask == 0, 0, f)
    # denominator
    l = torch.zeros(size=(B, N, S, 1)).to(device='cuda')
    # global (running) row maximum
    m = (torch.ones(size=(B, N, S, 1)) * NEG_INF).to(device='cuda')
    O = torch.zeros_like(Q, requires_grad=True)

    Q_BLOCK_SIZE = BLOCK_SIZE
    K_BLOCK_SIZE = BLOCK_SIZE
    Q_BLOCK = torch.split(Q, Q_BLOCK_SIZE, dim=2)
    K_BLOCK = torch.split(K, K_BLOCK_SIZE, dim=2)
    O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))

    l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
    m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
    f_BLOCKS = list(torch.split(f, Q_BLOCK_SIZE, dim=2))

    scale = 1 / np.sqrt(Q.shape[-1])
    Tq = len(Q_BLOCK)

    for i in range(Tq):
        # B * N * Q_BLOCK * hidden_size
        Qi = Q_BLOCK[i]
        # current running maximum
        mi = m_BLOCKS[i]
        li = l_BLOCKS[i]
        # B * N * Q_BLOCK * seq_len
        fi = f_BLOCKS[i]
        fi_BLOCK = list(torch.split(fi, Q_BLOCK_SIZE, dim=-1))
        # only visit key blocks up to and including the diagonal block
        for j in range(i + 1):
            Kj = K_BLOCK[j]
            # B * N * Q_BLOCK * Q_BLOCK
            if j < i:
                # blocks left of the diagonal need no mask
                QKij = matrix_without_mask(Qi * scale, Kj)
            else:
                # the diagonal block needs the causal mask
                QKij = matrix_with_mask(Qi * scale, Kj)
            m_local, _ = torch.max(QKij, dim=-1, keepdim=True)
            QKij_local = torch.exp(QKij - m_local)
            li_local = torch.sum(QKij_local, dim=-1, keepdim=True) + EPSILON
            if j == 0:
                # initialize
                fi_BLOCK[0] = QKij_local
                mi = m_local
                li = li_local
            else:
                # iteratively update numerator and denominator
                # B * N * Q_BLOCK * 1
                mi_new = torch.maximum(m_local, mi)
                li_new = torch.exp(mi - mi_new) * li + torch.exp(m_local - mi_new) * li_local
                # rescale all earlier blocks to the new maximum
                for k in range(j):
                    fi_BLOCK[k] = fi_BLOCK[k] * torch.exp(mi - mi_new)
                fi_BLOCK[j] = QKij_local * torch.exp(m_local - mi_new)
                mi = mi_new
                li = li_new
        fi = torch.cat(fi_BLOCK, dim=-1)
        local_attn = fi / li
        # the single multiplication by V for this query block
        O_BLOCKS[i] = torch.matmul(local_attn, V)
    O = torch.cat(O_BLOCKS, dim=2)
    return O

Note: the multiplication by V can also be done block by block to further reduce memory requirements.
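Because P @ V is the sum over key blocks j of P[:, block j] @ V[block j], the single `local_attn @ V` product can be accumulated one V block at a time, so only a (rows × block) slice of P and a (block × d) slice of V are touched per step. A NumPy sketch (function name hypothetical):

```python
import numpy as np

def blocked_pv(p, v, block):
    """Compute p @ v by accumulating one (rows x block) @ (block x d) product at a time."""
    out = np.zeros((p.shape[0], v.shape[1]))
    for start in range(0, v.shape[0], block):
        out += p[:, start:start + block] @ v[start:start + block]
    return out

rng = np.random.default_rng(1)
p, v = rng.random((4, 12)), rng.standard_normal((12, 8))
print(np.allclose(blocked_pv(p, v, block=4), p @ v))  # True
```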

pandaupc, May 09 '24 02:05

I compared the standard softmax, flash_attention_mask, and my implementation, and the results were consistent.

pandaupc, May 09 '24 02:05

Then you'd need to write down the attention matrix (you call it f) of size (batch, nheads, seqlen, seqlen). Memory will be O(seqlen^2) and I don't think it'll be better than the standard way to compute attention.
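To put a number on the O(seqlen^2) concern: materializing f as a dense (batch, nheads, seqlen, seqlen) tensor costs batch × nheads × seqlen² × bytes per element. A quick calculation with hypothetical sizes (fp32, since `torch.ones` defaults to float32):

```python
def dense_scores_bytes(batch, heads, seqlen, bytes_per_elem=4):
    """Bytes needed to hold a dense (batch, heads, seqlen, seqlen) tensor such as f."""
    return batch * heads * seqlen * seqlen * bytes_per_elem

# hypothetical sizes: 1 sequence, 32 heads, 8192 tokens, fp32
print(dense_scores_bytes(1, 32, 8192) / 2**30)  # 8.0 (GiB)
```

Doubling seqlen quadruples this figure.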

tridao, May 09 '24 04:05

Each loop iteration only needs a small block, and only one block needs to be held in SRAM at a time.

pandaupc, May 09 '24 05:05

The amount of memory reads / writes to global memory will be O(seqlen^2) if I understand correctly. Then it's not very different from calling softmax on one row block of QK^T and then calling matmul. You're just computing the softmax with online softmax instead of normal softmax.
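The baseline described here, an ordinary softmax on each row block of QK^T followed by one matmul with V, can be sketched in NumPy (non-causal; function name hypothetical). The online-softmax variant ends up with the same normalized row block; it only differs in how the row max and row sum are obtained.

```python
import numpy as np

def row_block_softmax_attention(q, k, v, q_block):
    """Per query row block: S_i = Q_i K^T, ordinary softmax over the full row, then @ V."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    outs = []
    for start in range(0, q.shape[0], q_block):
        s = q[start:start + q_block] @ k.T * scale            # (q_block, seq_len) row block
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        outs.append((p / p.sum(axis=-1, keepdims=True)) @ v)  # one matmul with V per row block
    return np.concatenate(outs)

rng = np.random.default_rng(2)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
s_full = q @ k.T / np.sqrt(16)
p_full = np.exp(s_full - s_full.max(axis=-1, keepdims=True))
reference = (p_full / p_full.sum(axis=-1, keepdims=True)) @ v
print(np.allclose(row_block_softmax_attention(q, k, v, q_block=4), reference))  # True
```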

tridao, May 09 '24 06:05

The instances of f within each loop are independent and are only used within the loop itself; therefore, initializing f inside the loop does not affect the final result.

pandaupc, May 09 '24 15:05

I removed the initialization of f from outside the loop and instead initialized it within each iteration. This change reduces the memory complexity to O(seq_len).

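import numpy as np
import torch

def get_lower_triangular_block(N, block_size, i):
    """
    Directly generate the i-th block: rows i*block_size .. (i+1)*block_size - 1
    of an N x N lower-triangular matrix of ones, i.e. a (block_size x N) block.
    Parameters:
    - N: number of columns of the full matrix
    - block_size: number of rows in the block
    - i: index of the desired block (counting from 0)
    """
    # compute the block's start and end rows
    start_row = i * block_size
    end_row = min((i + 1) * block_size, N)

    # create the block
    block_height = end_row - start_row
    block = torch.zeros(block_height, N)

    # fill the lower-triangular part with 1, diagonal included
    for row in range(block_height):
        max_col = start_row + row + 1  # include the diagonal
        max_col = min(max_col, N)  # do not exceed the number of columns
        block[row, :max_col] = 1

    return block.to(device='cuda')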

def new_flash_attention_causal_forward(Q, K, V):
    # BLOCK_SIZE, NEG_INF, EPSILON, matrix_with_mask and matrix_without_mask
    # come from the author's surrounding code and are not shown here.
    # Q: batch_size * num_attention * seq_len * hidden_size
    # l: batch_size * num_attention * seq_len
    B, N, S, H = Q.shape
    # denominator
    l = torch.zeros(size=(B, N, S, 1)).to(device='cuda')
    # global (running) row maximum
    m = (torch.ones(size=(B, N, S, 1)) * NEG_INF).to(device='cuda')
    O = torch.zeros_like(Q, requires_grad=True)

    Q_BLOCK_SIZE = BLOCK_SIZE
    K_BLOCK_SIZE = BLOCK_SIZE
    Q_BLOCK = torch.split(Q, Q_BLOCK_SIZE, dim=2)
    K_BLOCK = torch.split(K, K_BLOCK_SIZE, dim=2)
    O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))

    l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
    m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

    scale = 1 / np.sqrt(Q.shape[-1])
    Tq = len(Q_BLOCK)

    for i in range(Tq):
        # B * N * Q_BLOCK * hidden_size
        Qi = Q_BLOCK[i]
        # current running maximum
        mi = m_BLOCKS[i]
        li = l_BLOCKS[i]
        # B * N * Q_BLOCK * seq_len, now allocated per query block;
        # Qi.shape[2] (not BLOCK_SIZE) keeps a shorter last block aligned with its mask
        fi = torch.ones(size=(B, N, Qi.shape[2], S)).to(device='cuda')
        causal_mask = get_lower_triangular_block(S, BLOCK_SIZE, i)
        fi = torch.where(causal_mask == 0, 0, fi)
        fi_BLOCK = list(torch.split(fi, Q_BLOCK_SIZE, dim=-1))
        # only visit key blocks up to and including the diagonal block
        for j in range(i + 1):
            Kj = K_BLOCK[j]
            # B * N * Q_BLOCK * Q_BLOCK
            if j < i:
                # blocks left of the diagonal need no mask
                QKij = matrix_without_mask(Qi * scale, Kj)
            else:
                # the diagonal block needs the causal mask
                QKij = matrix_with_mask(Qi * scale, Kj)
            m_local, _ = torch.max(QKij, dim=-1, keepdim=True)
            QKij_local = torch.exp(QKij - m_local)
            li_local = torch.sum(QKij_local, dim=-1, keepdim=True) + EPSILON
            if j == 0:
                # initialize
                fi_BLOCK[0] = QKij_local
                mi = m_local
                li = li_local
            else:
                # iteratively update numerator and denominator
                # B * N * Q_BLOCK * 1
                mi_new = torch.maximum(m_local, mi)
                li_new = torch.exp(mi - mi_new) * li + torch.exp(m_local - mi_new) * li_local
                # rescale all earlier blocks to the new maximum
                for k in range(j):
                    fi_BLOCK[k] = fi_BLOCK[k] * torch.exp(mi - mi_new)
                fi_BLOCK[j] = QKij_local * torch.exp(m_local - mi_new)
                mi = mi_new
                li = li_new
        fi = torch.cat(fi_BLOCK, dim=-1)
        local_attn = fi / li
        # the single multiplication by V for this query block
        O_BLOCKS[i] = torch.matmul(local_attn, V)
    O = torch.cat(O_BLOCKS, dim=2)
    return O
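The Python row loop in `get_lower_triangular_block` can be replaced by a single broadcast comparison: entry (r, c) of the block is 1 exactly when column c ≤ absolute row r. A NumPy sketch (hypothetical name; the same expression works with `torch.arange` in PyTorch):

```python
import numpy as np

def lower_triangular_block_vectorized(n, block_size, i):
    """Rows [i*block_size, min((i+1)*block_size, n)) of an n x n lower-triangular
    matrix of ones (diagonal included), built with one broadcast comparison."""
    start = i * block_size
    end = min((i + 1) * block_size, n)
    rows = np.arange(start, end)[:, None]   # absolute row index of each block row
    cols = np.arange(n)[None, :]
    return (cols <= rows).astype(np.float32)

print(lower_triangular_block_vectorized(5, 2, 1))
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]]
```

As in the original, the last block is shorter when seq_len is not a multiple of block_size.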

pandaupc, May 09 '24 15:05

You'd still need to write fi of each block to global memory / HBM before calling matmul with V. So the total number of bytes written to global memory / HBM is quadratic (even though the maximum amount of memory required is smaller). An analogy: if you write a number to memory 1000 times, the number of bytes written is 1000, but the total amount of memory required is 1.
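The same distinction in numbers, for one batch element and head: peak memory is one (block_size, seq_len) fi, while the cumulative bytes written, if every fi goes to HBM once, add up to seq_len² elements. A sketch with hypothetical sizes (fp32):

```python
def fi_bytes(seqlen, block_size, bytes_per_elem=4):
    """Peak bytes held by one fi vs. cumulative bytes if every fi is written to HBM once."""
    n_blocks = -(-seqlen // block_size)                    # ceil division
    per_iteration = block_size * seqlen * bytes_per_elem   # one (block_size, seqlen) fi
    return per_iteration, n_blocks * per_iteration

peak, total = fi_bytes(seqlen=8192, block_size=128)
print(peak, total)  # 4194304 268435456
```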

tridao, May 09 '24 18:05

The variable fi is used within the loop and can be deleted after its use, so I believe there is no quadratic memory requirement. It does not need to be stored in HBM.

I want to emphasize that it is not necessary to multiply by V every time; multiplying just once at the end will yield the same result.

pandaupc, May 09 '24 19:05

Will the total amount of memory written during the execution of the flash-attention algorithm be less?

pandaupc, May 09 '24 19:05

There are 2 separate things: (1) the total amount of memory required, and (2) the total number of bytes written to memory. I agree that with your approach (1) would be subquadratic, but I think (2) would be quadratic. Memory access takes time.

You can see Theorem 2 in the FlashAttention paper. For different values of M (size of SRAM) the number of bytes written could be much less than N^2.

[Screenshot: Theorem 2 of the FlashAttention paper]
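Theorem 2 of the FlashAttention paper states that, for sequence length N, head dimension d and SRAM size M with d ≤ M ≤ Nd, standard attention needs Θ(Nd + N²) HBM accesses while FlashAttention needs Θ(N²d²/M). Dropping constants (so only the ratio means anything), a quick comparison using the paper's "typical" M of around 100K and a hypothetical N:

```python
def hbm_accesses(n, d, m):
    """Leading-order HBM access counts from Theorem 2 (constants dropped)."""
    assert d <= m <= n * d, "Theorem 2 assumes d <= M <= N*d"
    standard = n * d + n * n
    flash = n * n * d * d / m
    return standard, flash

std, fa = hbm_accesses(n=4096, d=64, m=100_000)
print(round(std / fa, 1))  # 24.8
```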

tridao, May 09 '24 19:05

Could you explain which step in my method would lead to increased HBM read and write operations? I didn't fully understand.

pandaupc, May 09 '24 19:05

This line: O_BLOCKS[i] = torch.matmul(local_attn,V). You'd need local_attn to be in HBM? And local_attn has shape (block_size, seqlen) for each iteration, which means in each iteration you're writing (block_size * seqlen) bytes to HBM. Across all iterations you'd write (seqlen * seqlen) bytes to HBM. Did I understand your algorithm correctly?

tridao, May 09 '24 20:05

I understand your point. I believe local_attn does not need to be written to HBM, since it won't be used afterwards. I think updating fi directly in SRAM could avoid frequent writes to HBM.

pandaupc, May 09 '24 20:05

Do you have enough space in SRAM to hold a tensor of size (block_size, seqlen)? Maybe if seqlen is not too long.
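A rough capacity check: the longest seqlen for which one (block_size, seqlen) tile fits is SRAM bytes / (block_size × bytes per element). The SRAM budget below is an assumption (164 KiB, the configurable maximum shared memory per SM on an A100), and the real limit is lower because the Q, K and V tiles and the running statistics must fit in the same space.

```python
def max_seqlen_in_sram(sram_bytes, block_size, bytes_per_elem):
    """Largest seqlen for which one (block_size, seqlen) tile fits in sram_bytes,
    ignoring every other tile that must share the same SRAM."""
    return sram_bytes // (block_size * bytes_per_elem)

SRAM = 164 * 1024  # assumed shared memory per SM
print(max_seqlen_in_sram(SRAM, 64, 4), max_seqlen_in_sram(SRAM, 64, 2))  # 656 1312
```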

tridao, May 09 '24 20:05

Each iteration's fi is independent, and if we disregard the mask, fi can be initialized to 1 in each iteration. There is no need to read from or write to HBM.

pandaupc, May 09 '24 20:05

When the sequence length increases, SRAM indeed cannot hold the entire block_size * seq_len tile. However, the current FlashAttention would run into the same issue. What I still don't understand is which of my operations would cause more HBM reads/writes than FlashAttention does.

pandaupc, May 09 '24 20:05

Of course, the operation local_attn @ V can also be broken down into a loop over blocks.

pandaupc, May 09 '24 20:05

I'd recommend writing out the algorithm and annotating which tensors live in SRAM and which are written to HBM, then checking (1) that you have enough SRAM space and (2) the total number of bytes read from / written to HBM.

Similar to the FlashAttention-2 algo.

[Screenshot: the FlashAttention-2 algorithm]
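One way to start the suggested accounting, as a rough model rather than a measurement: count elements moved between HBM and SRAM for one (batch, head), non-causal, assuming Q is read and O written once, all of K and V are streamed once per query block, and (for the spilled case) each fi is written to HBM and read back for the fi @ V product. The function and its cost model are hypothetical.

```python
def hbm_traffic_elements(seqlen, head_dim, block_size, spill_fi):
    """Rough element count moved between HBM and SRAM for one (batch, head)."""
    n_q_blocks = -(-seqlen // block_size)           # ceil division
    traffic = seqlen * head_dim                     # read every Q block once
    traffic += n_q_blocks * 2 * seqlen * head_dim   # stream all of K and V per Q block
    traffic += seqlen * head_dim                    # write every O block once
    if spill_fi:
        traffic += 2 * seqlen * seqlen              # write fi, then read it back
    return traffic

print(hbm_traffic_elements(4096, 64, 128, spill_fi=False),
      hbm_traffic_elements(4096, 64, 128, spill_fi=True))  # 17301504 50855936
```

Under this model the only difference is the 2 × seqlen² term, which disappears only if fi never leaves SRAM.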

tridao, May 09 '24 20:05

Alright, I will delve deeper into analyzing data exchanges between HBM and SRAM today, and I hope you can also help me analyze which of my operations might lead to increased memory read/write operations. Additionally, I am starting to write CUDA code and am not very experienced, so I would appreciate your further analysis. Thank you.

pandaupc, May 09 '24 20:05

I think I understand your point now. fi has size seq_len × block_size, which would cause frequent data transfers between HBM and SRAM. I will continue to optimize my algorithm and try to carry out the computation on block_size × block_size tiles. Thank you.
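For reference, working on block_size × block_size tiles is essentially what the FlashAttention-2 inner loop does: instead of keeping every unnormalized score block, it keeps a (q_block × d) output accumulator, rescales it when the running max changes, and normalizes once at the end. A simplified NumPy sketch of that idea for one query block (non-causal, CPU; not the CUDA kernel, and the function name is hypothetical):

```python
import numpy as np

def fa2_style_block(q_blk, k, v, kv_block):
    """One query block using only (q_block x kv_block) score tiles and a
    (q_block x d) output accumulator that is rescaled and updated per V block."""
    scale = 1.0 / np.sqrt(q_blk.shape[-1])
    m = np.full((q_blk.shape[0], 1), -np.inf)    # running row max
    l = np.zeros((q_blk.shape[0], 1))            # running denominator
    o = np.zeros((q_blk.shape[0], v.shape[1]))   # unnormalized output
    for start in range(0, k.shape[0], kv_block):
        s = q_blk @ k[start:start + kv_block].T * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        alpha = np.exp(m - m_new)                # rescales state computed under the old max
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ v[start:start + kv_block]
        m = m_new
    return o / l                                 # normalize once at the end

rng = np.random.default_rng(3)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
s_full = q @ k.T / np.sqrt(16)
p_full = np.exp(s_full - s_full.max(axis=-1, keepdims=True))
reference = (p_full / p_full.sum(axis=-1, keepdims=True)) @ v
out = np.concatenate([fa2_style_block(q[i:i + 4], k, v, kv_block=4) for i in range(0, 8, 4)])
print(np.allclose(out, reference))  # True
```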

pandaupc, May 09 '24 20:05

Thank you for your guidance. After analysis, I found that if the computation of fi is explicitly performed, the memory complexity is block_size×seq_len. However, when the seq_len is of a moderate size (i.e., SRAM can accommodate O(block_size×seq_len)), using my method does not increase the read and write operations between HBM and SRAM, but it can reduce the computation load (as V only needs to be multiplied once).

By controlling the length of the sequence, an appropriate implementation can be chosen.
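A back-of-the-envelope count of the "V only needs to be multiplied once" point, for one query block in a non-causal setting with seq_len divisible by block_size (function name hypothetical). Multiplying by V once or block by block performs the same number of multiply-adds; the two approaches differ in rescaling work, where the `for k in range(j)` loop rescales j earlier fi blocks at step j while a FlashAttention-2-style loop rescales the (block × d) O accumulator once per step.

```python
def pv_work(seqlen, head_dim, block_size):
    """Multiply counts for one query block, non-causal, seqlen divisible by block_size."""
    n = seqlen // block_size
    # P @ V: one (b x S) @ (S x d) product vs. n products of (b x b) @ (b x d)
    pv_once = block_size * seqlen * head_dim
    pv_per_block = n * block_size * block_size * head_dim
    # rescaling: j earlier (b x b) fi blocks at step j vs. one (b x d) O accumulator per step
    fi_rescale = sum(j * block_size * block_size for j in range(n))
    o_rescale = n * block_size * head_dim
    return pv_once, pv_per_block, fi_rescale, o_rescale

print(pv_work(4096, 64, 128))  # (33554432, 33554432, 8126464, 262144)
```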

pandaupc, May 10 '24 00:05