
FlashAttention works on a single GPU but crashes with Accelerate DP on multiple GPUs (FlashAttention only support fp16 and bf16 data type)

Andcircle opened this issue 1 year ago • 8 comments

### System Info

- `Accelerate` version: 0.22.0
- Platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.29
- Python version: 3.8.10
- Numpy version: 1.23.1
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1121.81 GB
- GPU type: NVIDIA A100-SXM4-80GB

- transformers: 4.37.2
- trl: 0.7.11.dev0
- flash-attn: 2.5.2

```
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
```

### Reproduction

The following script works as expected on 1 GPU, but when run on multiple GPUs with DP it fails inside `flash_attn_cuda.varlen_fwd(` with `RuntimeError: FlashAttention only support fp16 and bf16 data type`.

```python
import os
import wandb

import torch
from accelerate import Accelerator
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)

from trl import DataCollatorForCompletionOnlyLM

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

import sys
project_root = '/'.join(os.path.dirname(__file__).split('/')[:-1])
print(project_root)
sys.path.append(project_root)
from utils.meta_loader import write_meta, read_meta

import transformers

# bench
alpha = 16
rank = 64
batch_size = 2
length = 4096
accumulate_steps = 1
lr = 5e-5

train_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/train")
eval_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/test")

run_name = "llava_debug"
save_dir = "/mnt/localssd/llava_debug"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # load_in_8bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    # llm_int8_skip_modules=["multi_modal_projector"]
)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    # "llava-hf/bakLlava-v1-hf",
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map={'': torch.cuda.current_device()},
    torch_dtype=torch.float16,
    use_flash_attention_2=True
)

target_modules = [
    "*language_model.*q_proj",
    "*language_model.*k_proj",
    "*language_model.*v_proj",
    "*language_model.*o_proj",
    "*language_model.*gate_proj",
    "*language_model.*up_proj",
    "*language_model.*down_proj",
    "*language_model.*lm_head"]

modules_to_save = ["linear_1", "linear_2"]

peft_config = LoraConfig(
    lora_alpha=alpha,
    lora_dropout=0.1,
    r=rank,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    modules_to_save=modules_to_save
)

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)

training_arguments = TrainingArguments(
    output_dir=save_dir,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accumulate_steps,
    optim="paged_adamw_32bit",
    save_steps=500,
    logging_steps=10,
    learning_rate=lr,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=100,
    warmup_ratio=0.03,
    # group_by_length=True,
    lr_scheduler_type="constant",
    run_name=run_name,
    evaluation_strategy="steps",
    eval_steps=200,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
    # weight_decay=0.01,
    # dataloader_num_workers=NUM_PROC//2
)


model.config.use_cache = False  # not used for fine-tuning

def test_data_collator(datas):
    result = {}
    input_ids = [torch.Tensor(d['input_ids']) for d in datas]
    attention_mask = [torch.Tensor(d['attention_mask']) for d in datas]
    pixel_values = [torch.Tensor(d['pixel_values']) for d in datas]
    labels = [torch.Tensor(d['labels']) for d in datas]

    result['input_ids'] = torch.concat(input_ids).type(torch.int64)
    result['attention_mask'] = torch.concat(attention_mask).type(torch.int64)
    result['pixel_values'] = torch.concat(pixel_values)
    result['labels'] = torch.concat(labels).type(torch.int64)
    return result


trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_arguments,
    data_collator=test_data_collator
)

trainer.train()
```
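Before comparing single-GPU and multi-GPU runs, it may help to check which dtypes the model's parameters actually end up in after `prepare_model_for_kbit_training` and `get_peft_model`. Below is a minimal sketch with a hypothetical helper, `param_dtype_summary`, that is not part of any library. On the real model, quantized weights show up under whatever storage dtype bitsandbytes uses. The tiny two-layer model in the usage lines is only a stand-in.

```python
from collections import Counter

import torch
from torch import nn


def param_dtype_summary(model: nn.Module) -> Counter:
    """Count parameter elements per dtype (hypothetical debugging helper)."""
    counts = Counter()
    for p in model.parameters():
        counts[p.dtype] += p.numel()
    return counts


# Stand-in for the script's model: an fp16 Linear next to an fp32 LayerNorm.
tiny = nn.Sequential(nn.Linear(4, 4).half(), nn.LayerNorm(4))
print(param_dtype_summary(tiny))
```

In the script above it would be called as `print(param_dtype_summary(model))` right after `get_peft_model`.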

### Expected behavior

The behavior should be the same on a single GPU and on multiple GPUs.

Andcircle, Feb 10 '24 07:02

I'm not familiar with accelerate or how transformers uses FlashAttention; you'd probably get better help asking on those repos.

tridao, Feb 10 '24 08:02

I am getting a similar issue with torch nightly on Llama, even without training, so I can confirm something's wrong! It might be on our side, but as far as I tested, all the inputs' dtypes were bfloat16 and I still got the error. The reproducer is here, with attn_implementation="flash_attention_2", along with the corresponding PR on transformers.

- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.27.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
- flash_attn: 2.5.3 + torch nightly (so ~2.3)
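One way to double-check a claim like "all the inputs' dtypes were bfloat16" is to record the dtypes that actually reach each attention module at run time. The sketch below does that with PyTorch forward pre-hooks (`register_forward_pre_hook(..., with_kwargs=True)`, available since PyTorch 2.0). `attach_dtype_logger` is a hypothetical helper. Matching modules whose name ends in `self_attn` is an assumption based on the attribute name transformers uses for Llama attention layers. Integer or boolean tensors such as masks are recorded as well.

```python
import torch
from torch import nn


def attach_dtype_logger(model: nn.Module, name_suffix: str = "self_attn"):
    """Record the dtypes of the tensor inputs of every submodule whose name
    ends with `name_suffix`. Returns (log, handles); call .remove() on each
    handle when done. Hypothetical debugging helper."""
    log = {}
    handles = []
    for name, module in model.named_modules():
        if not name.endswith(name_suffix):
            continue

        # `name=name` binds the current module name for this hook.
        def hook(mod, args, kwargs, name=name):
            # Collect tensors passed positionally or by keyword.
            tensors = [a for a in args if torch.is_tensor(a)]
            tensors += [v for v in kwargs.values() if torch.is_tensor(v)]
            log.setdefault(name, []).append([t.dtype for t in tensors])
            # Returning None leaves the module's inputs unchanged.

        handles.append(module.register_forward_pre_hook(hook, with_kwargs=True))
    return log, handles


# Tiny stand-in model with a submodule named "self_attn".
model = nn.ModuleDict({"self_attn": nn.Linear(4, 4)})
log, handles = attach_dtype_logger(model)
model["self_attn"](torch.randn(2, 4))
print(log)  # {'self_attn': [[torch.float32]]}
for h in handles:
    h.remove()
```

In a multi-GPU run each process registers its own hooks and keeps its own log.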

ArthurZucker, Feb 12 '24 07:02

```
>>> from flash_attn import flash_attn_func
>>> import torch
>>> print(torch.__version__)
2.3.0.dev20240208+cu121
>>> flash_attn_func(torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), 1, softmax_scale=1, causal=True)

....

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
     49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
     52     q,
     53     k,
     54     v,
     55     None,
     56     alibi_slopes,
     57     dropout_p,
     58     softmax_scale,
     59     causal,
     60     window_size[0],
     61     window_size[1],
     62     return_softmax,
     63     None,
     64 )
     65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type
```

This doesn't work for me again; it might be because I have. cc @tridao, not sure how relevant this is.

ArthurZucker, Feb 12 '24 08:02

> this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is

The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
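A minimal sketch that turns those requirements into explicit checks before calling the kernel. `flash_attn_input_problems` is a hypothetical helper, not part of flash-attn. It only checks device, dtype and number of dimensions, because shape alone cannot tell (batch, seqlen, nheads, headdim) apart from (batch, nheads, seqlen, headdim). Applied to the tensors from the snippet above, it reports that they are 2-D and on the CPU, while their bf16 dtype is fine.

```python
import torch

# Dtypes accepted by the dtype check in flash_api.cpp quoted later in this thread.
SUPPORTED_DTYPES = (torch.float16, torch.bfloat16)


def flash_attn_input_problems(q, k, v):
    """Return reasons (q, k, v) would be rejected, based only on the
    requirements stated in this thread. Hypothetical helper."""
    problems = []
    for name, t in (("q", q), ("k", k), ("v", v)):
        if not t.is_cuda:
            problems.append(f"{name} is on {t.device}, expected a CUDA device")
        if t.dtype not in SUPPORTED_DTYPES:
            problems.append(f"{name} has dtype {t.dtype}, expected fp16 or bf16")
        if t.dim() != 4:
            problems.append(
                f"{name} has shape {tuple(t.shape)}, expected (batch, seqlen, nheads, headdim)"
            )
    return problems


# The tensors from the snippet above: bf16, but 2-D and on the CPU.
x = torch.ones((2, 3), dtype=torch.bfloat16)
for p in flash_attn_input_problems(x, x, x):
    print(p)
```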

tridao, Feb 12 '24 08:02

The error happens before that, but it seems to be torch nightly: the transformers snippet works with torch 2.2 (versus getting "FlashAttention only support fp16 and bf16 data type" with nightly), so that one is more reliable. With my own snippet on torch 2.2 I get `RuntimeError: q must be on CUDA`, which is a different error.

ArthurZucker, Feb 12 '24 08:02

> I am getting a similar issue without training with torch nightly on Llama so can confirm something's wrong! Might be on our side, but as far as I tested all the inputs's dtypes were bfloat16, still got the issue. Reproducer is here with attn_implementation="flash_attention_2" and the corresponding PR on transformers.


I can't run the reproducer right now because StaticCache is not in transformers 4.37.2 (the latest stable version).

tridao, Feb 12 '24 08:02

> > this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is
>
> The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).

Yeah, FlashAttention uses (batch, seqlen, nheads, headdim) to represent its inputs; however, in a lot of software (Triton, for example) we have reasons to use (batch, nheads, seqlen, headdim) because it makes the layout easier to arrange.

Actually they are equivalent with this mapping:

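```python
import torch

def permute(x: torch.Tensor, nheads: int, headdim: int) -> torch.Tensor:
    # (batch, seqlen, nheads * headdim) -> (batch, seqlen, nheads, headdim)
    new_x_shape = x.size()[:-1] + (nheads, headdim)
    x = x.view(new_x_shape)
    # -> (batch, nheads, seqlen, headdim)
    return x.permute(0, 2, 1, 3)
```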
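If a tensor is already in the (batch, nheads, seqlen, headdim) layout, FlashAttention's (batch, seqlen, nheads, headdim) layout is one transpose of dims 1 and 2 away. A minimal sketch with illustrative function names: the transpose returns a view whose last dimension still has stride 1, which is what the `maybe_contiguous` lambda in the traceback above checks before deciding whether to copy.

```python
import torch


def to_flash_layout(x: torch.Tensor) -> torch.Tensor:
    """(batch, nheads, seqlen, headdim) -> (batch, seqlen, nheads, headdim)."""
    return x.transpose(1, 2)


def from_flash_layout(x: torch.Tensor) -> torch.Tensor:
    """(batch, seqlen, nheads, headdim) -> (batch, nheads, seqlen, headdim)."""
    return x.transpose(1, 2)


batch, nheads, seqlen, headdim = 2, 3, 5, 4
x_bhsd = torch.arange(batch * nheads * seqlen * headdim).reshape(batch, nheads, seqlen, headdim)
x_bshd = to_flash_layout(x_bhsd)
print(x_bshd.shape)       # torch.Size([2, 5, 3, 4])
print(x_bshd.stride(-1))  # 1: the last dimension stays contiguous
```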

But it is weird that the error (I have tested this in the latest version) says "FlashAttention only support fp16 and bf16 data type".

```cpp
// mha_fwd: https://github.com/Dao-AILab/flash-attention/blob/6bbc532388e61185a92e2a563126739967b4c8c5/csrc/flash_attn/flash_api.cpp#L339-L339
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
            "FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
    TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
```
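To see what the quoted checks imply, here is a Python mirror of them. It is a sketch: `mha_fwd_precheck` is hypothetical and covers only these lines, not the rest of `mha_fwd`. On an A100 (compute capability 8.0, as in the original report) the architecture check passes, so this message can only come from a `q` whose dtype is neither fp16 nor bf16; the number of GPUs does not enter these checks. By the same token, this mirror cannot explain the torch-nightly reports in this thread, where the inputs were already bf16.

```python
import torch


def mha_fwd_precheck(q_dtype: torch.dtype, major: int, minor: int):
    """Python mirror of the TORCH_CHECKs quoted above (sketch only).

    Returns the message the kernel would raise, or None if these particular
    checks pass. (major, minor) is the compute capability, e.g. from
    torch.cuda.get_device_capability().
    """
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    if not (is_sm90 or is_sm8x):
        return "FlashAttention only supports Ampere GPUs or newer."
    if q_dtype not in (torch.float16, torch.bfloat16):
        return "FlashAttention only support fp16 and bf16 data type"
    # The separate bf16 architecture check cannot fail once the check above passed.
    return None


# A100 is compute capability 8.0.
print(mha_fwd_precheck(torch.bfloat16, 8, 0))  # None
print(mha_fwd_precheck(torch.float32, 8, 0))   # FlashAttention only support fp16 and bf16 data type
```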

I have checked the repo: we need to update the C++ templates to support more dtypes (I have experience with op libraries for near-memory chips). For now, I have to do these otherwise unnecessary casts to help teams use FlashAttention v2:

```python
# Fragment from an attention module's forward(): q, k, v, self.dropout, is_causal,
# output_attentions and revert_mold_flash_attn_input come from the surrounding code.
if q.dtype == torch.float32:
    q = q.to(torch.float16, non_blocking=True)
    k = k.to(torch.float16, non_blocking=True)
    v = v.to(torch.float16, non_blocking=True)
elif q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
    capability = torch.cuda.get_device_capability()
    if capability[0] <= 8:
        raise RuntimeError("Flash attention for FP8 (needs Hopper TE support) is currently only supported for compute capability >= 9.0")
    else:
        # TODO (yiakwy) : add FP8 support
        raise NotImplementedError("FP8 flash attention is not implemented yet")

output = flash_attn_func(q, k, v, dropout_p=self.dropout.p, causal=is_causal)
output = revert_mold_flash_attn_input(output)

if output_attentions:
    raise Exception("Does not support output attention weights inside flash attention.")

if output.dtype != torch.float32:
    # TODO (yiakwy) : add support of fp16 and bf16
    # FlashAttention returns the input dtype (FP16 after the cast above), so cast it back to FP32
    output = output.to(torch.float32, non_blocking=True)
```
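The same cast-in, cast-out idea as a self-contained sketch. Two assumptions differ from the fragment above: the output is cast back to whatever dtype the caller passed in rather than always to fp32, and the kernel is passed in as `attn_fn` (in practice something like `flash_attn_func`; here a stand-in) so the sketch runs without a GPU. `call_in_half_precision` and `fake_kernel` are hypothetical names.

```python
import torch

HALF_DTYPES = (torch.float16, torch.bfloat16)


def call_in_half_precision(attn_fn, q, k, v, compute_dtype=torch.bfloat16, **kwargs):
    """Cast non-half q/k/v to `compute_dtype`, call `attn_fn`, and cast the
    result back to the caller's dtype. Assumes q, k, v share one dtype."""
    if compute_dtype not in HALF_DTYPES:
        raise ValueError(f"compute_dtype must be fp16 or bf16, got {compute_dtype}")
    orig_dtype = q.dtype
    if orig_dtype not in HALF_DTYPES:
        q, k, v = (t.to(compute_dtype) for t in (q, k, v))
    out = attn_fn(q, k, v, **kwargs)
    # Restore the caller's dtype instead of hard-coding fp32.
    return out.to(orig_dtype)


def fake_kernel(q, k, v, causal=False):
    # Stand-in that enforces the same dtype rule as the real kernel.
    if q.dtype not in HALF_DTYPES:
        raise RuntimeError("FlashAttention only support fp16 and bf16 data type")
    return q  # placeholder output with q's shape and dtype


q = k = v = torch.randn(2, 16, 4, 8)  # fp32, (batch, seqlen, nheads, headdim)
out = call_in_half_precision(fake_kernel, q, k, v, causal=True)
print(out.dtype)  # torch.float32
```

bf16 is the default here because it keeps fp32's exponent range, so large fp32 activations are less likely to overflow than with fp16 (whose largest finite value is 65504); the price is fewer mantissa bits.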

So we need to update the error message, right?

I can confirm that flash-attn==2.5.6 doesn't work with the torch==2.3.0a0+40ec155e58.nv24.3 nightly, even though the inputs are indeed torch.bfloat16! I rolled back to torch 2.2 stable, reinstalled flash-attn, and now it works.
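Both reports in this thread that failed with bf16 inputs used a torch pre-release (`2.3.0.dev20240208+cu121` and `2.3.0a0+40ec155e58.nv24.3`), while the working setups used stable torch 2.2, with flash-attn reinstalled in the last case. A hypothetical guard that flags such builds up front could look like this. The regex is a simplified PEP 440 check, not a full parser, and the stable version string in the usage loop is illustrative.

```python
import re
import warnings

# Pre-release markers from PEP 440 (aN, bN, rcN, devN), simplified.
_PRERELEASE = re.compile(r"(a|b|rc|dev)\d*$")


def is_prerelease(version: str) -> bool:
    """Simplified check for pre-release/dev version strings."""
    public = version.split("+", 1)[0]  # drop the local part (+cu121, +<hash>...)
    return _PRERELEASE.search(public) is not None


def warn_if_prerelease(version: str) -> None:
    """Hypothetical guard, e.g. warn_if_prerelease(torch.__version__)."""
    if is_prerelease(version):
        warnings.warn(
            f"torch {version} is a pre-release build; if flash-attn misbehaves, "
            "try a stable torch and reinstall flash-attn against it.",
            RuntimeWarning,
        )


for v in ("2.3.0.dev20240208+cu121", "2.3.0a0+40ec155e58.nv24.3", "2.2.0+cu121"):
    print(v, is_prerelease(v))  # True, True, False
```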

thepowerfuldeez, Apr 02 '24 16:04