Large gradients when used on Pythia-14m
Question
I'm seeing very large gradients when I use flash attention 2 + bfloat16 on Pythia-14m. I'm loading my models with `transformers`.
Steps:
- Load `EleutherAI/pythia-14m`.
- Run it on a test sentence, e.g., "This is a test sentence to compare attention outputs."
- Run `.loss.backward()` on the output.
- Look at the gradients and compare with flash attention on and off.
Here's sample code for printing out the diff of the gradients against float32 on an example input:
```python
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn.functional import cosine_similarity

DEVICE = torch.device("cuda")


def get_output(model_name, input_ids, use_flash_attention, dtype):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2" if use_flash_attention else "eager",
        torch_dtype=dtype,
    ).to(DEVICE)
    out = model(input_ids=input_ids[:, :-1], labels=input_ids[:, 1:])
    out.loss.backward()
    grads = {k: v.grad for k, v in model.named_parameters()}
    return model, out, grads


def compare_attention_outputs(model_name, input_text):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
    ref_model, ref_out, ref_grads = get_output(
        model_name, input_ids, use_flash_attention=False, dtype=torch.float32
    )
    for dtype in [torch.float16, torch.bfloat16]:
        for use_flash_attention in [False, True]:
            print(f"{'-' * 50}\n{dtype}, flash_attention={use_flash_attention}")
            model, out, _ = get_output(
                model_name,
                input_ids,
                use_flash_attention=use_flash_attention,
                dtype=dtype,
            )
            print("Loss difference:", (out.loss - ref_out.loss).item())
            print(
                "Logits difference / sqrt(dim):",
                (out.logits - ref_out.logits).norm().item() / out.logits.numel() ** 0.5,
            )
            print("Gradient difference norm / sqrt(dim):")
            for name, param in model.named_parameters():
                error = (
                    param.grad - ref_grads[name]
                ).norm().item() / param.numel() ** 0.5
                if error > 0.1:
                    print(name, " " * (50 - len(name)), error)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        help="model name",
        default="EleutherAI/pythia-14m",
    )
    args = parser.parse_args()
    input_text = "This is a test sentence to compare attention outputs."
    compare_attention_outputs(args.model, input_text)
```
For instance, `model.gpt_neox.layers.0.attention.dense.bias.grad[:5]` is `[0.2433, 0.2482, 0.4011, -0.4509, 0.5477]` with float32, `[-0.1240, -0.7500, -0.4961, 0.0400, 1.1797]` with bfloat16 (still pretty wrong, but at least the right order of magnitude), and `[-12.1250, -11.3750, -203.0000, 40.5000, 123.0000]` with bfloat16 + flash attention 2. This doesn't happen with larger models; the gradients on Pythia-410m look more sane.
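That slice can be pulled out with a minimal sketch like the one below, reusing the `get_output` helper from the script above (the `input_ids` are prepared the same way as in `compare_attention_outputs`; note that the keys from `model.named_parameters()` don't have the leading `model.` prefix):

```python
import torch
from transformers import AutoTokenizer

# Prepare the same test input as in compare_attention_outputs.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
input_ids = tokenizer.encode(
    "This is a test sentence to compare attention outputs.", return_tensors="pt"
).to(DEVICE)

key = "gpt_neox.layers.0.attention.dense.bias"

# Reference: eager attention in float32.
_, _, fp32_grads = get_output(
    "EleutherAI/pythia-14m", input_ids, use_flash_attention=False, dtype=torch.float32
)
# Comparison: flash_attention_2 in bfloat16.
_, _, fa2_grads = get_output(
    "EleutherAI/pythia-14m", input_ids, use_flash_attention=True, dtype=torch.bfloat16
)

print(fp32_grads[key][:5])  # eager, float32 reference
print(fa2_grads[key][:5])   # bfloat16 + flash_attention_2
```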
What could be going on here? Is it simply the result of precision errors accumulating?
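A rough sketch for narrowing this down, not part of the script above: compare per-layer hidden states between the float32 eager reference and bfloat16 + flash attention 2 (via the standard `output_hidden_states=True` option) to see whether the discrepancy already builds up during the forward pass rather than only in the backward pass.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = torch.device("cuda")
NAME = "EleutherAI/pythia-14m"

tokenizer = AutoTokenizer.from_pretrained(NAME)
input_ids = tokenizer.encode(
    "This is a test sentence to compare attention outputs.", return_tensors="pt"
).to(DEVICE)

ref = AutoModelForCausalLM.from_pretrained(
    NAME, attn_implementation="eager", torch_dtype=torch.float32
).to(DEVICE)
fa2 = AutoModelForCausalLM.from_pretrained(
    NAME, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(DEVICE)

with torch.no_grad():
    ref_hidden = ref(input_ids, output_hidden_states=True).hidden_states
    fa2_hidden = fa2(input_ids, output_hidden_states=True).hidden_states

# Per-layer forward-pass discrepancy, normalized the same way as in the script above.
for layer, (h_ref, h_fa2) in enumerate(zip(ref_hidden, fa2_hidden)):
    diff = (h_ref - h_fa2.float()).norm().item() / h_ref.numel() ** 0.5
    print(f"layer {layer}: hidden-state diff / sqrt(dim) = {diff:.4f}")
```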
Versions
- Python 3.10.14
- Torch 2.2.2
- CUDA 12.1 (Docker image `pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel`)
- flash-attn 2.5.9.post1
- transformers 4.41.2
- GPU: A6000