Flash Attention and Open Delta LoRA
Hello @ShengdingHu,
Are you able to confirm whether Flash Attention will be compatible with Open Delta LoRA?
For example:
# (assumed imports: torch, AutoTokenizer / GPTNeoXForCausalLM from transformers,
#  LoraModel / Visualization from opendelta, and RotaryEmbedding / FlashAttentionWrapper
#  from the Flash Attention wrappers linked further down in this thread)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
tokenizer.pad_token = tokenizer.mask_token
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-1.4b")

max_positions = model_args.max_positions
tokenizer.model_max_length = max_positions

# swap the rotary embedding, causal mask, and attention module for the Flash Attention versions
for layer in model.gpt_neox.layers:
    original_emb = layer.attention.rotary_emb
    layer.attention.rotary_emb = RotaryEmbedding(layer.attention.rotary_ndims, max_positions, 10000)
    layer.attention.bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
        1, 1, max_positions, max_positions
    )
    layer.attention = FlashAttentionWrapper(layer.attention, max_seqlen=max_positions)

# patching for the random contiguous tensors bug
for p in model.parameters():
    p.data = p.data.contiguous()  # assign to .data so the change actually takes effect

Visualization(model).structure_graph()

delta_model1 = LoraModel(
    backbone_model=model,
    modified_modules=[
        'attention.attention.query_key_value',
        'mlp.dense_h_to_4h',
    ]
)
delta_model1.freeze_module()
delta_model1.log(delta_ratio=True, trainable_ratio=True, visualization=True)
Thank you for your great work,
Enrico
Sorry, we can't find the library where FlashAttentionWrapper is located. Could you please tell us which library it is?
My apologies. I thought I had linked it previously.
Here is the link to the Huggingface wrappers utilizing Flash Attention: https://github.com/kyleliang919/Long-context-transformers/blob/main/flash_attn_wrappers.py
The wrappers have changed a bit since then, so the relevant one is now the NeoX version:
# (assumed imports: torch, FlashSelfAttention from flash_attn, and GPT-NeoX's
#  apply_rotary_pos_emb from transformers)
class FlashAttentionWrapperWithRotary(torch.nn.Module):
    def __init__(self, attention, max_seqlen=8192):
        super().__init__()
        self.attention = attention
        self.max_seqlen = max_seqlen
        self.flash_self_attention = FlashSelfAttention(causal=True, softmax_scale=1 / self.attention.norm_factor)
        self.dropout_p = 0.0

    def forward(self,
                hidden_states,
                attention_mask,
                head_mask=None,
                layer_past=None,
                use_cache=False,
                output_attentions=False):
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.attention.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.attention.num_attention_heads, 3 * self.attention.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # [batch, seq_len, num_attention_heads, 3 * head_size]
        #   --> 3 x [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.attention.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.attention.head_size : 2 * self.attention.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.attention.head_size :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.attention.rotary_ndims]
        query_pass = query[..., self.attention.rotary_ndims :]
        key_rot = key[..., : self.attention.rotary_ndims]
        key_pass = key[..., self.attention.rotary_ndims :]

        # Compute token offset for rotary embeddings (when decoding)
        seq_len = key.shape[-2]
        offset = 0
        if has_layer_past:
            offset = layer_past[0].shape[-2]
            seq_len += offset
        cos, sin = self.attention.rotary_emb(value, seq_len=seq_len)
        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention with the Flash Attention kernel instead of the stock _attn
        # attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        qkv = torch.concat([query.unsqueeze(2), key.unsqueeze(2), value.unsqueeze(2)], dim=2).permute(0, 3, 2, 1, 4).half()
        attn_output = self.flash_self_attention(qkv)
        attn_weights = None

        # Reshape outputs
        attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.attention.num_attention_heads * self.attention.head_size)
        attn_output = self.attention.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
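For reference, a minimal sketch of how I swap the wrapper in (same pattern as my first snippet, assuming the class above and max_positions are already defined):

for layer in model.gpt_neox.layers:
    # replace the stock GPT-NeoX attention module with the Flash Attention wrapper
    layer.attention = FlashAttentionWrapperWithRotary(layer.attention, max_seqlen=max_positions)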
Thank you,
Enrico
Sorry for not replying in time. Flash Attention is not easily compatible with OpenDelta's LoRA. LoRA essentially decomposes the update to a parameter matrix of the model (such as the query projection) into two low-rank matrices A and B. Flash Attention needs the projection weight to be used as a single whole, but OpenDelta's LoRA keeps A and B as two separate parts attached to the backbone rather than merging them into the original weight. Therefore, the current OpenDelta LoRA cannot directly adapt to Flash Attention. Sorry for not being able to help you; we will conduct further research and try our best to make improvements so that OpenDelta can adapt to Flash Attention in the future.
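To illustrate what we mean (a rough sketch in plain PyTorch with made-up sizes, not OpenDelta's actual API): the low-rank update only becomes visible to a fused kernel once it is merged back into the original weight.

import torch

# Hypothetical sizes, just for illustration (not tied to any particular model).
d_in, d_out, r = 64, 192, 4
scaling = 1.0                  # plays the role of lora_alpha / r

W = torch.randn(d_out, d_in)   # frozen pretrained projection weight
A = torch.randn(r, d_in)       # LoRA "A" factor (down-projection)
B = torch.randn(d_out, r)      # LoRA "B" factor (random here only for the demo;
                               # LoRA normally initializes B to zero)

x = torch.randn(8, d_in)

# What LoRA computes at runtime: the base matmul plus two extra low-rank matmuls.
y_lora = x @ W.t() + scaling * (x @ A.t()) @ B.t()

# What a fused attention kernel would need to see: one merged weight,
# W' = W + scaling * B @ A, so the projection is a single matrix again.
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.t()

assert torch.allclose(y_lora, y_merged, atol=1e-4)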