Fine-tuning the updated Phi-2 with flash-attn-2 produces very high loss (> 2)
System Info
The updated Phi-2 code produces a high loss. I have tried fp16, bf16, DeepSpeed, and FSDP; the result is the same: the loss starts at 2 and keeps going higher. Setting use_flash_attention_2=False or using the old phi-2 modeling file fixes this.
torch==2.1.2 flash-attn==2.4.2 transformers==4.37.0.dev0
Who can help?
No response
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Fine-tune the updated phi-2 model using transformers trainer
Expected behavior
Loss goes down.
I experienced the same thing! Over 3 epochs with the same setup, just the updated code and flash attention, the loss went from 6 to 2. On the old code without flash attention it went from 0.60 to ~0.29. Very strange.
cc @younesbelkada @ArthurZucker
Hi @abacaj, as per @pacman100's guidelines in https://github.com/huggingface/transformers/pull/28142 / https://github.com/huggingface/transformers/pull/28142#issuecomment-1869513914, you need to make sure to load your model in full precision and train with autocast (bf16=True). Also, can you share more insight into how you train your model (do you load the model in bf16/fp16, do you use PEFT, packing, etc.)?
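For reference, that recipe looks roughly like this (a minimal sketch; the output path and train_dataset are placeholders, and it assumes a transformers build that already includes the #28142 patch so FA-2 accepts an fp32-loaded model):

import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Load the weights in full precision; autocast (bf16=True) handles the bf16 compute.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float32,
)

args = TrainingArguments(
    output_dir="phi2-sft",          # placeholder
    bf16=True,                      # train with bf16 autocast
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    num_train_epochs=3,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset: your tokenized dataset
trainer.train()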
Hi @younesbelkada, this is a full fine-tune using the HF trainer. Padding only. The model is loaded in bf16. I tried loading in fp32 but got this error:
ValueError: Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed torch.float32, this might lead to unexpected behaviour.
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.float32,
cache_dir=training_args.cache_dir,
)
Ok, thanks @abacaj for getting back! I think you get that error because the patch #28142 has not been released on PyPI - can you try to build transformers from source?
pip install -U git+https://github.com/huggingface/transformers.git
That should hopefully solve it, let me know if you run into more issues!
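(A quick way to check that the source build is actually the one being imported; a source install of main reports a version ending in .dev0:)

import transformers
print(transformers.__version__)  # e.g. 4.37.0.dev0 for a source build of main
print(transformers.__file__)     # should point at the freshly installed package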
Ok, so I removed the explicit torch_dtype following the comments in your link. The loss is still very high with flash-attn-2 using the phi-2 model.
@abacaj which padding side are you using for training?
I use padding_side="left". Here is how the loss goes with and without FA2 (green line has FA2) using phi-2:
FWIW, changing the padding side doesn't do anything to the loss; it's the same.
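(For context, the padding side under discussion is the tokenizer attribute, set roughly like this; setting a pad token is the usual extra step when the tokenizer doesn't define one:)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
tokenizer.padding_side = "left"                # or "right"; both were tried above
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common choice when no pad token is defined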
I see, as a sanity check, can you share your TrainingArguments?
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.95,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=True,
bf16_full_eval=False,
cache_dir=None,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=src/configs/deepspeed_2_config.json,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=0.0,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=2,
gradient_checkpointing=True,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
include_num_input_tokens_seen=False,
include_tokens_per_second=False,
inference_length=2048,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=5e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=1,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=checkpoints/results/2k-2k-dynamic-5e-5/runs/Jan15_13-42-15_sgpu,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=1.0,
logging_strategy=steps,
lr_scheduler_kwargs={},
lr_scheduler_type=cosine,
max_grad_norm=1.0,
max_steps=-1,
metric_for_best_model=None,
model_max_position_embeddings=2048,
mp_parameters=,
neftune_noise_alpha=None,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_torch,
optim_args=None,
output_dir=checkpoints/results/2k-2k-dynamic-5e-5,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=4,
per_device_train_batch_size=4,
prediction_loss_only=False,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=['tensorboard'],
resume_from_checkpoint=None,
rope_scaling_factor=1.0,
rope_scaling_type=dynamic,
run_name=checkpoints/results/2k-2k-dynamic-5e-5,
save_on_each_node=False,
save_only_model=False,
save_safetensors=True,
save_steps=100.0,
save_strategy=epoch,
save_total_limit=None,
seed=70,
skip_memory_metrics=True,
split_batches=False,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_mps_device=False,
warmup_ratio=0.02,
warmup_steps=0,
weight_decay=0.1
During my testing I used bf16, trust_remote_code, no gradient checkpointing, for SFT, with flash attention. The resulting model was terrible; I knew from the (6 to 2) loss it wasn't going to perform, but during testing it was worse than expected, with very mangled answers. However, when I trained the model with the same arguments, just using the old phi repo code and no flash attention, I got a great model: the loss went from .6 to .29. Both were full fine-tunes. Flash attention is critical for 24GB cards, as without it training runs off shared memory. I can help out more with testing when it's done training in ~30 hours off shared memory 😭. The script I used is in #28381. (Keep in mind the script doesn't reflect me using bf16, however both times I trained the model I did have the compute dtype set to bf16.)
Hello everyone!
Could you all please test using the latest revision on microsoft/phi-2 and report the results? We might have found the issue.
Regards, Gustavo.
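(For anyone testing, pulling the newest revision looks roughly like this; revision can also be pinned to a specific commit hash from the Hub:)

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    revision="main",          # head of the Hub repo, i.e. the latest fix
    trust_remote_code=True,   # pick up the updated remote modeling code
)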
FWIW - the model still comes out significantly worse using FA2. If anyone wants to fine-tune this model, I recommend using it without FA2 for now. Running code benchmarks, with FA2 it scores < 50% on HumanEval; without FA2 (all other hparams identical, including the seed) it scores > 60% on HumanEval.
The first graph is a comparison between using and not using flash attention 2. It seems that the loss doesn't change much with fa2 (yellowish curve).
@abacaj could you please provide a minimal snippet to reproduce your fine-tuning?
We want to investigate it further and attempt to find the root of the problem. We are doing a line-by-line comparison between the new model's code and the previous one.
FWIW - the model still comes out significantly worse using FA2. If anyone wants to fine-tune this model, I recommend using it without FA2 for now. Running code benchmarks, with FA2 it scores < 50% on HumanEval; without FA2 (all other hparams identical, including the seed) it scores > 60% on HumanEval.
I second this. I just woke up and checked training: after 3 epochs with FA2 I went from .54 to ~.40, meanwhile with no FA2 I went from .60 to .30. Both were full fine-tunes. I'm going to train the FA2 checkpoint for an additional epoch to see if it gives me the same loss as without FA2, or to see if it overfits.
EDIT: The loss is off to a terrible start. It went as low as .39, then up to as high as .49. (It's only at .07 of an epoch, but I'm training a checkpoint that has already been trained on this exact dataset for 3 epochs.) Significantly better than before with the softmax scaling issues, but there is still something up.
The loss is acting quite random, in comparison to no FA2, where the loss consistently went down.
SECOND UPDATE: I restarted the training with the same checkpoint and changed the learning rate by an order of magnitude, from 2e-5 to 2e-6, and now the loss is more consistent. It's confusing why this hyperparameter behaves differently when using FA2 versus not using it.
Not perfect, but better.
THIRD UPDATE: I tried retraining the base model with FA2 and the loss isn't going anywhere after 1.5 epochs. It's almost as if the weights aren't being updated at all, or only very marginally; the loss just stays between .5 and .4, but is random at every logging step.
Sorry to comment on this closed issue, but I still have issues with FA2:
- loss is different with FA2 compared to without
- loss is different even between two FA2 runs (with set_seed; this doesn't happen without FA2, where the loss is always exactly the same)
I did another two runs with FA2 and without FA2 (blue line). Testing the models out using vLLM, the model without FA2 (blue line) scores 58% on HumanEval, the model with FA2 scores 49%. I basically stopped using FA2 because the model comes out so much worse and I can't pinpoint why (the runs are identical with the exception of use_flash_2).
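(To be clear, the seeding is just transformers' set_seed, called at the top of the script before anything else is built; something like:)

from transformers import set_seed

# set_seed seeds python, numpy and torch (CPU + CUDA) in one call.
# Run the exact same training script twice with this in place and diff the
# logged losses: without FA2 the curves match step for step, with FA2 they don't.
set_seed(70)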
Hi @ArthurZucker, could you reconsider reopening this issue? I think it's worth reopening, as training with flash attention on phi-2 is still not viable, and the performance gains from it are almost essential. I appreciate it, thank you!
Just wanted to acknowledge that I have the same issue using Flash Attention 2 with phi-2: the training loss hardly decreases with FA2 turned on, and training works pretty well with it turned off.
same question...
We want to investigate it further and attempt to find the root of the problem. We are doing a line-by-line comparison between the new model's code and the previous one.
@gugarosa is there any update on fixing FA2 for this amazing model?
There is a PR that should fix it, but it is hard to merge: #28673
I am also seeing a similar issue where the loss is trending downwards but is quite unstable, and the model seems to learn very slowly. I am running a full fine-tune of the latest Phi-2 model on my dataset.
@ArthurZucker I just started another run after reinstalling transformers with the changes from #28673 to see if it fixes the issue (still using FA2). Will post the loss curve in the next few hours.
[Incorrect] Update 1: loss curve after reinstalling transformers with the changes from #28673. Looks like there is no change.
Update 2: Looks like my new transformers installation didn't include the changes from #28673, so essentially both plots should be the same. I tried reinstalling transformers again with the PR changes and now training is failing:
File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/home/minimalist/work/projects/transformers/src/transformers/models/phi/modeling_phi.py", line 318, in forward
query_states = self.q_proj(hidden_states)
File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/minimalist/miniconda3/envs/axolotl_Feb12/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
Update: it was a torch error, and it's training now, but the loss is the same as before. I lowered my dataset to 1k examples over 3 epochs with a lr of 2e-6 and still the loss is random, never consistently going down.
How are you guys testing this? It does seem to matter whether you do a full fine-tune or a LoRA fine-tune. Using FA2 I could never get the loss to even be consistent with a full fine-tune (with SFT). Right now I am doing DPO of phi-2 with QLoRA, and the loss is not only consistent, it's consistently going down: from .69 to .27 at just a single epoch.
I have not tried SFT with a LoRA, but maybe if we want to use FA2 it's just better to stick with LoRA.
Hi there, now that SDPA has been merged in #29108, you can use FA-2 through the PyTorch interface:
0- Install PyTorch 2.2
1- Make sure to load SDPA attention by passing attn_implementation="sdpa" in from_pretrained
2- Force-dispatch the SDPA kernel to use FA2 as follows:
- trainer.train()
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+ trainer.train()
Perhaps this will fix all instability issues with respect to FA2!
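Putting steps 0-2 together, a minimal sketch (output dir and train_dataset are placeholders; it assumes a transformers build with #29108 merged so the Phi SDPA attention class is available, and note that the flash kernel itself still requires fp16/bf16 inputs and a causal-only mask):

import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Step 1: load Phi-2 with the SDPA attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    attn_implementation="sdpa",
    torch_dtype=torch.float32,   # full-precision load + bf16 autocast, as advised earlier in the thread
)

args = TrainingArguments(output_dir="phi2-sdpa-sft", bf16=True, num_train_epochs=3)     # placeholders
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)                  # train_dataset: your tokenized dataset

# Step 2: force SDPA to dispatch to the FlashAttention-2 kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    trainer.train()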
@younesbelkada Hey! Thanks for your response! (Before starting my training run) I got PyTorch 2.2, pulled the latest commits from transformers, and installed from source. I'm using DPO.py from TRL, and I saw the commit, so I tried to pass --attn_implementation sdpa, but it gave me an "sdpa not currently supported" error. I wish I still had the error up; I'll try it out again once my training run ends in a little less than an hour. However, I had only tried to pass it as a flag, not how you are just now telling me.
Hi @NickWithBotronics!
You need to use transformers from source: pip install -U git+https://github.com/huggingface/transformers