Add an isolated implementation of FlashDiffAttention
This PR implements a FlashDiffAttention class similar to FlashSelfAttention in the original flash-attention repo (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py#L53), so that training frameworks can easily add Diff Transformer support, with or without varlen support.
The main idea is to set num_heads during training to twice that of the original Transformer, so that none of the code related to RoPE needs to be changed.
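For reviewers unfamiliar with the class, the core computation it wraps looks roughly like the sketch below. This is not the exact code in this PR: the even/odd head splitting mirrors the existing multihead_flashdiff_2 reference, it assumes flash_attn_func from flash-attention, and it omits GQA repetition, the sublayer RMSNorm, and the varlen path.

import torch
from flash_attn import flash_attn_func  # flash-attention >= 2.x assumed

def diff_flash_attn_sketch(q, k, v, lambda_full, causal=True):
    """Rough sketch of the differential attention that FlashDiffAttention wraps.

    q, k, v: (batch, seqlen, 2 * num_heads, head_dim) -- note the doubled
    head count, which is why upstream RoPE code needs no changes.
    lambda_full: the re-parameterized scalar (lambda_1 - lambda_2 + lambda_init).
    """
    b, s, two_h, d = q.shape
    h = two_h // 2
    # adjacent head pairs form the two softmax maps of differential attention
    q1, q2 = q.view(b, s, h, 2, d).unbind(dim=3)
    k1, k2 = k.view(b, s, h, 2, d).unbind(dim=3)
    v1, v2 = v.view(b, s, h, 2, d).unbind(dim=3)
    attn1 = torch.cat(
        [flash_attn_func(q1, k1, v1, causal=causal),
         flash_attn_func(q1, k1, v2, causal=causal)], dim=-1)
    attn2 = torch.cat(
        [flash_attn_func(q2, k2, v1, causal=causal),
         flash_attn_func(q2, k2, v2, causal=causal)], dim=-1)
    # (batch, seqlen, num_heads, 2 * head_dim); the sublayer norm and the
    # (1 - lambda_init) scaling applied by the real module are omitted here
    return attn1 - lambda_full * attn2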
A simple test script for the code is:
from dataclasses import dataclass
import torch
import torch.distributed as dist
from flash_attn.layers.rotary import RotaryEmbedding
from einops import rearrange
from multihead_flashdiff_2 import MultiheadFlashDiff2
from flashdiff import FlashDiffAttention
from kernel.rotary import apply_rotary_emb

@dataclass
class Args:
    model_parallel_size: int
    decoder_kv_attention_heads: int


def create_new_impl(origin_impl, head_dim, depth):
    diff_attn_func = FlashDiffAttention(
        head_dim=head_dim, depth=depth, causal=True
    ).to(device, dtype=dtype)
    # make the initialization the same
    diff_attn_func.lambda_q1.data.copy_(origin_impl.lambda_q1.data)
    diff_attn_func.lambda_k1.data.copy_(origin_impl.lambda_k1.data)
    diff_attn_func.lambda_q2.data.copy_(origin_impl.lambda_q2.data)
    diff_attn_func.lambda_k2.data.copy_(origin_impl.lambda_k2.data)
    # diff_attn_func.subln.weight.data.copy_(origin_impl.subln.weight.data)

    def new_impl(x, rel_pos):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len
        q = origin_impl.q_proj(x)
        k = origin_impl.k_proj(x)
        v = origin_impl.v_proj(x)
        # here we no longer need "// 2"
        num_heads = embed_dim // head_dim
        num_kv_heads = k.shape[-1] // head_dim
        q = q.view(bsz, tgt_len, num_heads, head_dim)
        k = k.view(bsz, src_len, num_kv_heads, head_dim)
        v = v.view(bsz, src_len, num_kv_heads, head_dim)
        q = apply_rotary_emb(q, *rel_pos, interleaved=True)
        k = apply_rotary_emb(k, *rel_pos, interleaved=True)
        output = diff_attn_func(q, k, v)
        output = rearrange(output, '... H D -> ... (H D)')
        output = origin_impl.out_proj(output)
        return output

    return new_impl

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda")
    dtype = torch.bfloat16
    args = Args(model_parallel_size=1, decoder_kv_attention_heads=4)

    batch_size = 2
    num_heads = 16
    seq_len = 512
    embed_dim = 2048
    depth = 12
    # in the new implementation, num_heads should be twice the original num_heads
    num_new_heads = num_heads * 2
    head_dim = embed_dim // num_new_heads

    print("initializing modules")
    origin_impl = MultiheadFlashDiff2(args, embed_dim=embed_dim, depth=depth, num_heads=num_heads).to(device, dtype=dtype)
    new_impl = create_new_impl(origin_impl, head_dim, depth)

    print("creating test data")
    rotary_emb = RotaryEmbedding(
        head_dim,
        base=10000.0,
        interleaved=True,
        device=device,
    )
    rotary_emb._update_cos_sin_cache(seq_len, device=device, dtype=torch.bfloat16)
    rel_pos = (rotary_emb._cos_cached, rotary_emb._sin_cached)
    hidden_states = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype)

    print("run origin forward")
    origin_output = origin_impl(hidden_states, rel_pos)

    print("run new forward")
    new_output = new_impl(hidden_states, rel_pos)

    assert torch.allclose(origin_output, new_output, atol=1e-6)
Thank you for your time reviewing this PR.
You could go even closer to plain attention and use it as is, with a doubled head interleave, e.g.:
def alternative_forward(
    self,
    x,
    rel_pos,
    attn_mask=None,
):
    # assumes `import torch.nn.functional as F` at module scope
    bsz, tgt_len, embed_dim = x.size()
    src_len = tgt_len
    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)
    q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
    k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
    v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)
    q = apply_rotary_emb(q, *rel_pos, interleaved=True)
    k = apply_rotary_emb(k, *rel_pos, interleaved=True)
    q = q.transpose(1, 2)
    k = torch.repeat_interleave(k.transpose(1, 2), dim=1, repeats=self.n_rep)
    v = torch.repeat_interleave(v.transpose(1, 2), dim=1, repeats=self.n_rep * 2)
    if attn_mask is None:
        attn_mask = torch.triu(
            torch.zeros([tgt_len, src_len])
            .float()
            .fill_(float("-inf"))
            .type_as(q),
            1 + src_len - tgt_len,
        )
    lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
    lambda_full = lambda_1 - lambda_2 + self.lambda_init
    # one SDPA call over the doubled head dimension
    attn_weights = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_mask, scale=self.scaling)
    # even-indexed heads carry the first attention map, odd-indexed heads the second
    every_other_mask = torch.arange(attn_weights.size(1), device=attn_weights.device) % 2 == 0
    attn = attn_weights[:, every_other_mask, :, :] - lambda_full * attn_weights[:, ~every_other_mask, :, :]
    attn = self.subln(attn)
    attn = attn * (1 - self.lambda_init)
    attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)
    attn = self.out_proj(attn)
    return attn
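For completeness, one possible way to wire this into the test script above would be to bind the function to the existing MultiheadFlashDiff2 instance and compare outputs. This is a hypothetical harness, assuming alternative_forward is defined at module scope next to the test and that torch.nn.functional is imported as F there; SDPA in bf16 will not match the flash-attention path bit-for-bit, so the tolerance is deliberately loose.

import types

# bind the reviewer's sketch as a method of the instance created in the test above
origin_impl.alternative_forward = types.MethodType(alternative_forward, origin_impl)
alt_output = origin_impl.alternative_forward(hidden_states, rel_pos)
# loose tolerance: SDPA vs. flash-attention in bf16 differs in low-order bits
assert torch.allclose(origin_output, alt_output, atol=1e-2, rtol=1e-2)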