comparing HF vs FA2 llama2 models
Hi, I'm looking over the optimizations in the trainer here, and trying to port them to transformers.trainer.Trainer for use with llama2.
I put together this simple script to view the differences between the two:
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import llama_config_to_gpt2_config
from transformers import LlamaConfig, LlamaForCausalLM
MODEL = "meta-llama/Llama-2-7b-chat-hf"
config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(MODEL))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
print(config)
model = GPTLMHeadModel(config, device="meta")
print(model)
model = LlamaForCausalLM.from_pretrained(MODEL, device_map="meta")
print(model)
and I see:
GPTLMHeadModel(
(transformer): GPTModel(
(embeddings): GPT2Embeddings(
(word_embeddings): Embedding(32000, 4096)
)
(layers): ModuleList(
(0-31): 32 x Block(
(mixer): MHA(
(rotary_emb): RotaryEmbedding()
(Wqkv): FusedDense(in_features=4096, out_features=12288, bias=False)
(inner_attn): FlashSelfAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(inner_cross_attn): FlashCrossAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(out_proj): FusedDense(in_features=4096, out_features=4096, bias=False)
)
(dropout1): Dropout(p=0.0, inplace=False)
(drop_path1): StochasticDepth(p=0.0, mode=row)
(norm1): RMSNorm()
(mlp): GatedMlp(
(fc1): Linear(in_features=4096, out_features=22016, bias=False)
(fc2): Linear(in_features=11008, out_features=4096, bias=False)
)
(dropout2): Dropout(p=0.0, inplace=False)
(drop_path2): StochasticDepth(p=0.0, mode=row)
(norm2): RMSNorm()
)
)
(drop_f): Dropout(p=0.0, inplace=False)
(ln_f): RMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
I'm trying to understand the differences here.
It appears there's an extra RMSNorm inserted between the attn and mlp? Was this intentional?
It also looks like GatedMlp only has two linear layers, but the first is double sized. What's going on there?
I saw there's no fused version of GatedMlp yet. Is there any reason to use it over LlamaMLP if I'm not doing tensor parallel?
We have code to convert weights from Meta and HF to be compatible with the implementation in this repo. There is a test here to verify that the models implemented in this repo match the HF implementation.
input_layernorm -> norm1, post_attention_layernorm -> norm2. mlp.gate_proj and mlp.up_proj are combined into one matrix.
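A rough sketch of the MLP half of that mapping (illustrative only: the authoritative conversion lives in flash_attn.models.llama and the linked test, and the order of the two concatenated halves has to match how GatedMlp chunks fc1's output):

import torch

def combine_llama_mlp_weights(gate_proj_w: torch.Tensor, up_proj_w: torch.Tensor) -> torch.Tensor:
    """Sketch: build GatedMlp's fc1 weight from HF's gate_proj and up_proj.

    HF LlamaMLP computes down_proj(silu(gate_proj(x)) * up_proj(x)); GatedMlp does
    the same with a single fc1 of out_features = 2 * intermediate_size (hence the
    22016 = 2 * 11008 above) whose output is split into two halves.
    The order below assumes the activated ("gate") half comes second.
    """
    return torch.cat([up_proj_w, gate_proj_w], dim=0)  # (2 * intermediate, hidden)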
Thanks @tridao, and sorry for the dumb questions.
I see now that the order of the modules in the print output is not relevant, as the code dictates how they're wired together.
My goal is to use huggingface/transformers.trainer.Trainer on llama2, but using the optimizations found here.
I realized that some of the optimizations are standalone and can be integrated into the transformers llama model directly, for example xentropy and rmsnorm: https://github.com/huggingface/transformers/compare/main...tmm1:transformers:llama-flash
However, other optimizations such as rotary_emb require more structural changes and are simplest to use with Block and MHA directly.
So I've tried to take the GPTLMHeadModel and feed it into transformers.Trainer directly (using #479). There are a few other minor incompatibilities (missing model.device, unimplemented model.gradient_checkpointing_enable(), etc.), but after working through those, I end up with this confusing exception:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../transformers/src/transformers/trainer.py:1555: in train
return inner_training_loop(
../transformers/src/transformers/trainer.py:1837: in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
../transformers/src/transformers/trainer.py:2693: in training_step
self.accelerator.backward(loss)
/home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/accelerate/accelerator.py:1902: in backward
loss.backward(**kwargs)
/home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/torch/_tensor.py:487: in backward
torch.autograd.backward(
/home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/torch/autograd/__init__.py:193: in backward
grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
outputs = (tensor([[[ 0.2080, 0.0364, 0.2754, ..., 1.4062, 1.9844, 0.7070],
[-8.0000, -3.2188, -2.5000, ..., -5.... [-8.7500, -9.1875, 3.5938, ..., -6.8125, -6.8438, -5.2188]]],
device='cuda:0', grad_fn=<DivBackward0>),)
grads = (None,), is_grads_batched = False
def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
new_grads: List[_OptionalTensor] = []
for out, grad in zip(outputs, grads):
if isinstance(grad, torch.Tensor):
first_grad = grad if not is_grads_batched else grad[0]
if not torch.is_same_size(out, first_grad):
out_shape, grad_shape = _calculate_shape(out, first_grad, is_grads_batched)
if is_grads_batched:
raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
"dimension of each grad_output as the batch dimension. "
"The sizes of the remaining dimensions are expected to match "
"the shape of corresponding output, but a mismatch "
"was detected: grad_output["
+ str(grads.index(grad)) + "] has a shape of "
+ str(grad_shape) + " and output["
+ str(outputs.index(out)) + "] has a shape of "
+ str(out_shape) + ". "
"If you only want some tensors in `grad_output` to be considered "
"batched, consider using vmap.")
else:
raise RuntimeError("Mismatch in shape: grad_output["
+ str(grads.index(grad)) + "] has a shape of "
+ str(grad_shape) + " and output["
+ str(outputs.index(out)) + "] has a shape of "
+ str(out_shape) + ".")
if out.dtype.is_complex != grad.dtype.is_complex:
raise RuntimeError("For complex Tensors, both grad_output and output"
" are required to have the same dtype."
" Mismatch in dtype: grad_output["
+ str(grads.index(grad)) + "] has a dtype of "
+ str(grad.dtype) + " and output["
+ str(outputs.index(out)) + "] has a dtype of "
+ str(out.dtype) + ".")
new_grads.append(grad)
elif grad is None:
if out.requires_grad:
if out.numel() != 1:
> raise RuntimeError("grad can be implicitly created only for scalar outputs")
E RuntimeError: grad can be implicitly created only for scalar outputs
Does this make any sense to you, or do you have ideas for where I can go next to investigate?
My guess is that it's because our GPTLMHeadModel doesn't return a loss; it returns the output, which has size (batch, seqlen, vocab_size). You'd need a separate loss function (e.g. CrossEntropy).
FWIW I have this model def'n integrated into our training code based on huggingface trainer. I can confirm that you need to override compute_loss (ideally using the fused cross entropy loss implemented in this codebase).
Also for sequence parallel, you may want to apply allreduce_sequence_parallel_grad after backward.
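A minimal sketch of what such an override could look like, assuming the model returns full-size logits rather than a loss as described above (the key names, the forward signature used, and the label-shifting mirror the HF causal LM convention and are illustrative, not the exact code used here):

from flash_attn.losses.cross_entropy import CrossEntropyLoss
from transformers import Trainer

class CausalLMTrainer(Trainer):
    """Hypothetical Trainer subclass that computes the LM loss itself."""

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs["labels"]
        # GPTLMHeadModel returns logits of shape (batch, seqlen, vocab_size), not a loss.
        outputs = model(input_ids=inputs["input_ids"])
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        # Shift so that tokens < n predict token n, as the HF causal LM models do.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(inplace_backward=True)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))
        return (loss, outputs) if return_outputs else loss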
Could you share what the override for compute_loss would look like to use this?
I've been comparing torch.nn.CrossEntropyLoss vs flash_attn.losses.cross_entropy.CrossEntropyLoss, but am not able to measure any memory or speed difference between the two.
Aha I just saw this:
partial(CrossEntropyLoss, inplace_backward=True)
I'm still not able to measure any difference. I'm using the HF trainer and model with this change:
import transformers
from functools import partial
from flash_attn.losses.cross_entropy import CrossEntropyLoss
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(CrossEntropyLoss, inplace_backward=True)
It all depends on how much time / memory the cross entropy is taking. You can benchmark to see how much CE is taking. If it's taking like 1-2% of the time then it doesn't matter. For small models and/or large vocab size you'll see speedup and memory saving.
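One rough way to benchmark just the cross-entropy op in isolation (a sketch assuming a CUDA device and llama-2-7b-like sizes at 2048 context, batch size 1; it measures peak memory only, not wall-clock time):

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss

vocab_size, tokens = 32000, 2048
labels = torch.randint(0, vocab_size, (tokens,), device="cuda")

for name, loss_fct in [
    ("torch", torch.nn.CrossEntropyLoss()),
    ("flash_attn", FlashCrossEntropyLoss(inplace_backward=True)),
]:
    # Fresh logits each round, since inplace_backward reuses the logits storage for grads.
    logits = torch.randn(tokens, vocab_size, device="cuda",
                         dtype=torch.bfloat16, requires_grad=True)
    torch.cuda.reset_peak_memory_stats()
    loss = loss_fct(logits, labels)
    loss.backward()
    torch.cuda.synchronize()
    print(name, f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")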
I discovered my patching code wasn't running for some silly reason.
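For what it's worth, a quick hypothetical check that the monkeypatch is active (modeling_llama resolves the CrossEntropyLoss name at call time, so replacing the module attribute works as long as it happens before the first forward pass):

import torch
import transformers

patched = transformers.models.llama.modeling_llama.CrossEntropyLoss
assert patched is not torch.nn.CrossEntropyLoss, "xentropy patch did not take effect"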
I'm interested in quantifying the differences between these implementations, especially when it comes to VRAM usage.
I used the VRAM line profiler I'm working on here, and measured the following difference:
- Train VRAM peak: 1.94464 GB
+ Train VRAM peak: 1.78449 GB
Line # Hits Mem Per Hit % Mem Line Contents
==============================================================
- 852 2 262.1 131.1 15.1 loss = loss_fct(shift_logits, shift_labels)
+ 852 2 loss = loss_fct(shift_logits, shift_labels)
...
- 2693 2 318.8 159.4 15.5 self.accelerator.backward(loss)
+ 2693 2 408.9 204.5 21.7 self.accelerator.backward(loss)
This is a very limited test (2 steps w/ llama2-7b @ 2048 ctx), but it shows the slight benefit of using the custom xentropy kernel.
Is this with batch size 1? My back-of-the-envelope calculation: the logits have size (batch, seqlen, vocab_size), taking 2 bytes each (e.g. training with bf16). Our xentropy kernel avoids storing an extra copy, so we save (2 * batch * seqlen * vocab_size) bytes. With llama 7b and batch=1, seqlen=2048, vocab_size = 30k, this is 123MB. With a larger batch size the memory saving is larger (but maybe you can't run with a larger batch size because of the GPU memory limit).
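Spelling that arithmetic out:

# One extra bf16 copy of the logits costs 2 bytes per element.
batch, seqlen, vocab_size = 1, 2048, 30_000
extra_bytes = 2 * batch * seqlen * vocab_size
print(f"{extra_bytes / 1e6:.0f} MB")  # -> 123 MB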
Yeah, that was with batch_size=1.
I made some more measurements @ ctx=4096:
| cfg | mem |
|---|---|
| bs=1 xentropy=false | 3.59699 GB |
| bs=1 xentropy=true | 3.28644 GB |
| bs=2 xentropy=false | 7.27851 GB |
| bs=2 xentropy=true | 6.65741 GB |
Yup seems to check out with my calculation.
Now I applied the rmsnorm kernel on top, as follows:
import transformers
from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):
    """Patched LlamaRMSNorm backed by the flash_attn fused RMSNorm kernel."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps=eps)

transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
and the results are further improved for ctx=4096:
| bs | xentropy | rms | mem |
|---|---|---|---|
| 1 | true | false | 3.28644 GB |
| 1 | true | true | 3.20246 GB |
| 2 | true | false | 6.65741 GB |
| 2 | true | true | 5.95893 GB |
| 4 | false | false | 13.9354 GB |
| 4 | true | false | 12.0428 GB |
| 4 | true | true | 11.8348 GB |
| 6 | false | false | 18.3586 GB |
| 6 | true | true | 17.6281 GB |
I'm working on the rotary kernel next, but am not quite sure if I'm handling the cos/sin correctly:
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, MSELoss
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.ops.rms_norm import RMSNorm
+from flash_attn.layers.rotary import apply_rotary_emb
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
@@ -324,7 +325,14 @@ class LlamaAttention(nn.Module):
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ cos = cos.squeeze(1).squeeze(0)[position_ids].squeeze(0)
+ sin = sin.squeeze(1).squeeze(0)[position_ids].squeeze(0)
+ # cos, sin: (seqlen_rotary, rotary_dim / 2)
+ cos = cos[..., :cos.shape[-1] * 2:2]
+ sin = sin[..., :sin.shape[-1] * 2:2]
+ # fused rope expects (batch_size, seqlen, nheads, headdim)
+ query_states = apply_rotary_emb(query_states.transpose(1, 2), cos, sin, inplace=True).transpose(1, 2)
+ key_states = apply_rotary_emb(key_states.transpose(1, 2), cos, sin, inplace=True).transpose(1, 2)
if past_key_value is not None:
# reuse k, v, self_attention
Why not just convert the HF weights to use the Llama implementation in this repo? https://github.com/Dao-AILab/flash-attention/blob/a86442f0f35c135c8ed8d7af760b1bd6a832ec07/tests/models/test_llama.py#L65
You can also see how we use rotary in MHA here.
Thanks for the pointer.
I know I can convert the weights and use the trainer here, but I'm interested in features that transformers offers out of the box, such as loading weights in 4bit/8bit and using PEFT techniques such as LoRA, QLoRA and IA3. I'm also just trying to understand how this stuff works better, so translating between the two implementations is helpful as a learning exercise.
I'll read through the MHA implementation and see what I can figure out.
Okay I see I should probably be using the flash_attn.layers.rotary.RotaryEmbedding module instead of trying to call apply_rotary_emb directly.
On the transformers side there are several variations:
class LlamaRotaryEmbedding(torch.nn.Module):
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
I'll see what it takes to get these variations into the FA2 rotary module first.
Yeah we should probably add those as arguments to the RotaryEmbedding module (scaling factor, scaling_method="ntk" or scaling_method="standard"). Is "standard" the right name or is there another name?
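For context, a minimal library-agnostic sketch of what the linear ("position interpolation") variant changes relative to standard RoPE (this is just the math, not the API of either library):

import torch

def rotary_cos_sin(seqlen, rotary_dim, base=10000.0, scaling_factor=1.0):
    # Standard RoPE: freqs[t, i] = t * base**(-2i / rotary_dim).
    # Linear ("position interpolation") scaling stretches positions by 1 / scaling_factor;
    # dynamic NTK scaling instead grows `base` once the sequence length exceeds
    # the trained context length.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(seqlen).float() / scaling_factor
    freqs = torch.outer(t, inv_freq)  # (seqlen, rotary_dim / 2)
    return freqs.cos(), freqs.sin()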
Have you been able to look at this? I was also wondering how one would use the fused MLP layers with huggingface.
You can use this subclass with the HF trainer: https://github.com/Dao-AILab/flash-attention/pull/486
Can it be used in 4bit and peft like hf models?
Yes, but not with fused MLP because there's no place for peft to hook into the linear layers.
Thanks a lot. Can you give a pointer on how to load that model in 4bit using huggingface?
I don't think we can use that model in 4bit.
Is there any information about training speed comparing HF vs FA2 llama2 models?