Llama 2 model divergence with FSDP
System Info
- `transformers` version: 4.37.1
- Platform: Linux-5.10.199-190.747.amzn2.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.8
- Huggingface_hub version: 0.20.2
- Safetensors version: 0.3.3
- Accelerate version: 0.26.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes
Who can help?
When fine-tuning a Llama 2 model with transformers 4.37 and PyTorch FSDP, we observe model divergence compared to transformers 4.31. Fine-tuning with 4.31 works fine, but with 4.37 the loss consistently rises instead of stabilizing when attn_implementation="flash_attention_2" is set, while attn_implementation="sdpa" works fine.
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
The model is initialized as follows:

```python
model = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")
```
Expected behavior
The loss should not go up as training progresses.
cc @younesbelkada I think we have seen something similar recently?
@Teng-xu are you correctly enabling mixed precision through `bf16=True` in `TrainingArguments`?
Yeah bf16 was passed into the training args, and I can verify it is being applied correctly.
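For reference, the relevant arguments look roughly like this (a minimal sketch; apart from `bf16=True`, the values are illustrative and not the exact configuration used):

```python
from transformers import TrainingArguments

# Sketch of the mixed-precision / FSDP settings being discussed;
# output_dir, batch size, learning rate, and the FSDP options are illustrative.
training_args = TrainingArguments(
    output_dir="./llama2-finetune",  # placeholder
    bf16=True,                       # enable bfloat16 mixed precision
    fsdp="full_shard auto_wrap",     # shard params/grads/optimizer state, auto-wrap layers
    fsdp_config={"transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"]},
    per_device_train_batch_size=1,
    learning_rate=1e-5,
)
```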
Just to provide more context on this issue, I am attaching a simple script that reproduces it, along with its output. Note that I am just using a random tensor as the dataset; for consistency, I saved the labels from another training script and load them from a pickle object.
Script:
```python
import pickle

import torch
from transformers import AutoModelForCausalLM

# Placeholder: point this at a local Llama 2 checkpoint.
pretrained_model_weights = "/path/to/llama-2-checkpoint"

# Model init: one copy per attention implementation (flash_attention_2, sdpa, eager).
model1 = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")
model2 = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="sdpa")

# Keep only the first 4 decoder layers to shrink the comparison.
model1.model.layers = model1.model.layers[:4]
model2.model.layers = model2.model.layers[:4]

model1 = model1.type(torch.bfloat16).to("cuda")
model2 = model2.type(torch.bfloat16).to("cuda")

# Dummy input: random token ids of sequence length 4096.
tensor = torch.randint(low=0, high=9, size=(1, 4096), dtype=torch.int32).to("cuda")

# Labels saved from another training script, loaded for consistency across runs.
labels = pickle.load(open("labels.p", "rb")).to("cuda")

# Forward pass through both models.
out1 = model1(input_ids=tensor, attention_mask=None, labels=labels)
loss1 = out1["loss"]
logits1 = out1["logits"]
out2 = model2(input_ids=tensor, attention_mask=None, labels=labels)
loss2 = out2["loss"]
logits2 = out2["logits"]

# Compare model outputs.
if torch.allclose(logits1, logits2, atol=1e-0):
    print("logits equal~~~~~~~~~")
else:
    print("logits not equal~~~~~~~~~~")
print("logits 1:")
print(logits1)
print("logits 2:")
print(logits2)
print("max diff between logits:")
print(torch.max(torch.abs(logits1 - logits2)))

# Backward pass and loss comparison.
loss1.backward()
loss2.backward()
print("loss 1:")
print(loss1)
print("loss 2:")
print(loss2)
if torch.allclose(loss1, loss2):
    print("loss equal~~~~~~~~~")
else:
    print("loss not equal~~~~~~~~~~")
```
Output of script:
```
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:15<00:00, 7.91s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:24<00:00, 12.15s/it]
logits equal~~~~~~~~~
logits 1:
tensor([[[-1.3047, -2.2812, 2.2500, ..., -1.6094, -1.5078, 0.6914],
[-2.1094, -4.6875, 1.2031, ..., -1.1484, -1.7109, -0.4336],
[-1.1719, -4.6562, 0.3516, ..., 0.3301, -0.9727, 0.2852],
...,
[-2.2188, 8.8125, 1.4219, ..., -1.3906, -1.7266, -3.6250],
[-0.9844, 11.0625, 0.7617, ..., -0.4609, 0.0225, -2.7188],
[-1.0234, 10.8750, 0.8125, ..., -0.4395, -0.1641, -2.7656]]],
device='cuda:0', grad_fn=<ToCopyBackward0>)
logits 2:
tensor([[[-1.3047e+00, -2.2812e+00, 2.2500e+00, ..., -1.6094e+00,
-1.5078e+00, 6.9141e-01],
[-2.1094e+00, -4.6875e+00, 1.2031e+00, ..., -1.1484e+00,
-1.7109e+00, -4.3359e-01],
[-1.1719e+00, -4.6562e+00, 3.5156e-01, ..., 3.3008e-01,
-9.7266e-01, 2.8516e-01],
...,
[-2.2188e+00, 8.8125e+00, 1.4297e+00, ..., -1.3984e+00,
-1.7344e+00, -3.6562e+00],
[-9.8047e-01, 1.1062e+01, 7.5391e-01, ..., -4.3945e-01,
-3.9673e-04, -2.7188e+00],
[-1.0391e+00, 1.0875e+01, 8.1641e-01, ..., -4.4922e-01,
-1.7188e-01, -2.7812e+00]]], device='cuda:0',
grad_fn=<ToCopyBackward0>)
max diff between logits:
tensor(0.2500, device='cuda:0', grad_fn=<MaxBackward1>)
loss 1:
tensor(13.4215, device='cuda:0', grad_fn=<NllLossBackward0>)
loss 2:
tensor(13.4206, device='cuda:0', grad_fn=<NllLossBackward0>)
loss not equal~~~~~~~~~~
```
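As an aside, the warnings at the top of the output come from loading the model in float32 and casting afterwards; loading the weights directly in bf16 avoids them. A minimal sketch (the checkpoint path is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

# Loading directly in bfloat16 avoids the "No dtype was provided" warning,
# since Flash Attention 2.0 only supports fp16/bf16.
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llama-2-checkpoint",  # placeholder
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```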
Tagging @pacman100 to take a look.
Hi @rnadimp, thanks for the snippet! I am not surprised to see a relatively small difference between SDPA and FA2. The diff you shared is quite small and acceptable IMO. Note that FA2 does not guarantee numerically identical results against SDPA; in practice, because the kernels are different, there will always be a small difference between the two implementations.
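For reference, a comparison that tolerates bf16 kernel-level differences might look like the sketch below (the tolerances are illustrative, not an official recommendation); `logits1`/`logits2` and `loss1`/`loss2` are the tensors from the script above.

```python
import torch

def close_enough(a: torch.Tensor, b: torch.Tensor, rtol: float = 2e-2, atol: float = 2e-2) -> bool:
    # Upcast to float32 so the comparison itself does not add rounding error,
    # then compare with tolerances suited to bfloat16 rather than expecting
    # bit-identical results from different attention kernels.
    return torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol)

print("logits close:", close_enough(logits1, logits2))
print("loss close:", close_enough(loss1, loss2))
print("max abs diff:", (logits1.float() - logits2.float()).abs().max().item())
```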
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.