
Large gradients when used on Pythia-14m

Open tomtseng opened this issue 7 months ago • 0 comments

Question

I'm seeing very large gradients when I use flash attention 2 + bfloat16 on Pythia-14m. I'm loading my models with transformers.

Steps:

  • Load EleutherAI/pythia-14m
  • Run it on a test sentence, e.g., "This is a test sentence to compare attention outputs."
  • Run .loss.backward() on the output
  • Look at the gradients and compare them with flash attention on vs. off.

Here's sample code that prints the gradient differences against a float32 reference on an example input:

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = torch.device("cuda")


def get_output(model_name, input_ids, use_flash_attention, dtype):
    # Load the model in the requested dtype with either flash attention 2 or the
    # eager attention implementation, run forward/backward, and collect gradients.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2" if use_flash_attention else "eager",
        torch_dtype=dtype,
    ).to(DEVICE)

    out = model(input_ids=input_ids[:, :-1], labels=input_ids[:, 1:])
    out.loss.backward()
    grads = {k: v.grad for k, v in model.named_parameters()}
    return model, out, grads


def compare_attention_outputs(model_name, input_text):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)

    # Reference run: float32 with eager attention.
    ref_model, ref_out, ref_grads = get_output(
        model_name, input_ids, use_flash_attention=False, dtype=torch.float32
    )
    for dtype in [torch.float16, torch.bfloat16]:
        for use_flash_attention in [False, True]:
            print(f"{'-' * 50}\n{dtype}, flash_attention={use_flash_attention}")
            model, out, _ = get_output(
                model_name,
                input_ids,
                use_flash_attention=use_flash_attention,
                dtype=dtype,
            )

            print("Loss difference:", (out.loss - ref_out.loss).item())
            print(
                "Logits difference / sqrt(dim):",
                (out.logits - ref_out.logits).norm().item() / out.logits.numel() ** 0.5,
            )

            print("Gradient difference norm / sqrt(dim):")
            for name, param in model.named_parameters():
                error = (
                    param.grad - ref_grads[name]
                ).norm().item() / param.numel() ** 0.5
                if error > 0.1:
                    print(name, " " * (50 - len(name)), error)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        help="model name",
        default="EleutherAI/pythia-14m",
    )
    args = parser.parse_args()

    input_text = "This is a test sentence to compare attention outputs."
    compare_attention_outputs(args.model, input_text)

For instance, model.gpt_neox.layers.0.attention.dense.bias.grad[:5] is [0.2433, 0.2482, 0.4011, -0.4509, 0.5477] with float32, [-0.1240, -0.7500, -0.4961, 0.0400, 1.1797] with bfloat16 (still pretty wrong but at least the right order of magnitude), and [-12.1250, -11.3750, -203.0000, 40.5000, 123.0000] with bfloat16 + flash attention 2. This doesn't happen with larger models—the gradients on Pythia-410m look more sane.
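
(In case it's useful, here's a minimal sketch of how those values can be pulled out of the grads dict returned by get_output above; it assumes input_ids has been tokenized the same way as in compare_attention_outputs.)

_, _, grads = get_output(
    "EleutherAI/pythia-14m",
    input_ids,
    use_flash_attention=True,
    dtype=torch.bfloat16,
)
# Keys of grads are the names from model.named_parameters().
print(grads["gpt_neox.layers.0.attention.dense.bias"][:5])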

What could be going on here? Is it simply the result of precision errors accumulating?
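
One way I can think of to check where the error creeps in (just a sketch, not part of the script above; it assumes the GPTNeoX layout Pythia uses, i.e. model.gpt_neox.layers[i].attention, and that each attention module returns a tuple whose first element is the attention output) is to hook every attention block and diff its forward output between the float32/eager and bfloat16/flash runs:

def capture_attention_outputs(model):
    # Record every attention block's forward output so two runs can be
    # compared layer by layer.
    captured = {}

    def make_hook(idx):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            captured[idx] = out.detach().float()
        return hook

    handles = [
        layer.attention.register_forward_hook(make_hook(i))
        for i, layer in enumerate(model.gpt_neox.layers)
    ]
    return captured, handles

If the per-layer differences are already large at layer 0, it's not just accumulation; if they only grow with depth (and then again through the backward pass), accumulation seems more plausible.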

Versions

  • Python 3.10.14
  • Torch 2.2.2
  • CUDA 12.1 (Docker image pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel)
  • flash-attn 2.5.9.post1
  • transformers 4.41.2
  • A6000 GPU
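
For convenience, a small snippet (not part of my repro script) that prints the same environment info at runtime:

import torch
import transformers
import flash_attn

print("torch", torch.__version__, "cuda", torch.version.cuda)
print("flash-attn", flash_attn.__version__)
print("transformers", transformers.__version__)
print("gpu", torch.cuda.get_device_name(0))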

tomtseng · Jul 12 '24 00:07