flash-attention
FlashAttention works on a single GPU, but crashes with accelerate DP on multiple GPUs ("FlashAttention only support fp16 and bf16 data type")
System Info
`Accelerate` version: 0.22.0
Platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.29
Python version: 3.8.10
Numpy version: 1.23.1
PyTorch version (GPU?): 2.0.1+cu117 (True)
PyTorch XPU available: False
PyTorch NPU available: False
System RAM: 1121.81 GB
GPU type: NVIDIA A100-SXM4-80GB
transformers 4.37.2
trl 0.7.11.dev0
flash-attn 2.5.2
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
Reproduction
The following script works as expected on 1 GPU, but when run on multiple GPUs with DP it fails with: out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( RuntimeError: FlashAttention only support fp16 and bf16 data type
import os
import wandb
import torch
from accelerate import Accelerator
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
    TrainingArguments,
)
from trl import DataCollatorForCompletionOnlyLM
from PIL import Image
import sys
project_root = '/'.join(os.path.dirname(__file__).split('/')[:-1])
print(project_root)
sys.path.append(project_root)
from utils.meta_loader import write_meta, read_meta
import transformers
# bench
alpha = 16
rank = 64
batch_size = 2
length = 4096
accumulate_steps = 1
lr = 5e-5
train_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/train")
eval_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/test")
run_name = "llava_debug"
save_dir = "/mnt/localssd/llava_debug"
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # load_in_8bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    # llm_int8_skip_modules=["multi_modal_projector"]
)
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    # "llava-hf/bakLlava-v1-hf",
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map={'': torch.cuda.current_device()},
    torch_dtype=torch.float16,
    use_flash_attention_2=True
)
target_modules = [
    "*language_model.*q_proj",
    "*language_model.*k_proj",
    "*language_model.*v_proj",
    "*language_model.*o_proj",
    "*language_model.*gate_proj",
    "*language_model.*up_proj",
    "*language_model.*down_proj",
    "*language_model.*lm_head",
]
modules_to_save = ["linear_1", "linear_2"]
peft_config = LoraConfig(
    lora_alpha=alpha,
    lora_dropout=0.1,
    r=rank,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    modules_to_save=modules_to_save
)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)
training_arguments = TrainingArguments(
    output_dir=save_dir,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accumulate_steps,
    optim="paged_adamw_32bit",
    save_steps=500,
    logging_steps=10,
    learning_rate=lr,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=100,
    warmup_ratio=0.03,
    # group_by_length=True,
    lr_scheduler_type="constant",
    run_name=run_name,
    evaluation_strategy="steps",
    eval_steps=200,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
    # weight_decay=0.01,
    # dataloader_num_workers=NUM_PROC//2
)
model.config.use_cache = False  # not used for fine-tuning
def test_data_collator(datas):
    # stack per-example tensors into batched tensors with the dtypes the model expects
    result = {}
    input_ids = [torch.Tensor(d['input_ids']) for d in datas]
    attention_mask = [torch.Tensor(d['attention_mask']) for d in datas]
    pixel_values = [torch.Tensor(d['pixel_values']) for d in datas]
    labels = [torch.Tensor(d['labels']) for d in datas]
    result['input_ids'] = torch.concat(input_ids).type(torch.int64)
    result['attention_mask'] = torch.concat(attention_mask).type(torch.int64)
    result['pixel_values'] = torch.concat(pixel_values)
    result['labels'] = torch.concat(labels).type(torch.int64)
    return result
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_arguments,
    data_collator=test_data_collator
)
trainer.train()
### Expected behavior
The behavior should be the same on a single GPU and on multiple GPUs.
I'm not familiar with accelerate or how transformers uses FlashAttention; you'd probably get better help asking on those repos.
I am getting a similar issue without training, with torch nightly on Llama, so I can confirm something's wrong! Might be on our side, but as far as I tested all the inputs' dtypes were bfloat16, and I still got the issue.
Reproducer is here with `attn_implementation="flash_attention_2"` and the corresponding PR on `transformers`.
- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.27.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
flash_attn==2.5.3 + torch nightly (so roughly 2.3)
>>> from flash_attn import flash_attn_func
>>> import torch
>>> print(torch.__version__)
2.3.0.dev20240208+cu121
>>> flash_attn_func(torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), 1, softmax_scale=1, causal=True)
....
File ~/miniconda3/envs/py310/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
52 q,
53 k,
54 v,
55 None,
56 alibi_slopes,
57 dropout_p,
58 softmax_scale,
59 causal,
60 window_size[0],
61 window_size[1],
62 return_softmax,
63 None,
64 )
65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type
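For reference, here is a minimal sketch of the kind of transformers-level reproducer being referred to; the checkpoint, prompt, and generation settings are assumptions, not taken from the original report:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint; any Llama-family model with FlashAttention-2 support should behave similarly.
model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # flash-attn requires fp16 or bf16 inputs
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Per the comments below, a snippet along these lines works on torch 2.2 but hits the same dtype error on the nightly build.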
this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is
The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
The error happens before that check, but it seems to be torch nightly: the transformers snippet works with torch 2.2 (vs. getting `FlashAttention only support fp16 and bf16 data type` with nightly), so that's more reliable.
(With my snippet on torch 2.2 I instead get `RuntimeError: q must be on CUDA`, so a different error.)
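For illustration, a direct flash_attn_func call that satisfies the constraints above (fp16/bf16 dtype, tensors on CUDA, shape (batch, seqlen, nheads, headdim)) would look like the minimal sketch below; the shapes are arbitrary and an Ampere-or-newer GPU is assumed:

```python
import torch
from flash_attn import flash_attn_func

# q, k, v must be fp16/bf16, live on CUDA, and have shape (batch, seqlen, nheads, headdim)
batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```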
I can't run the reproducer right now because StaticCache is not in transformers 4.37.2 (the latest stable version).
The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
Yeah, flash attention uses (batch, seqlen, nheads, headdim) to represent inputs; however, in many frameworks (Triton, for example) we have reasons to use (batch, nheads, seqlen, headdim) for easier layout arrangement.
Actually they are equivalent under this mapping:
def permute(self, x: torch.Tensor) -> torch.Tensor:
    # (batch, seqlen, nheads * headdim) -> (batch, nheads, seqlen, headdim)
    new_x_shape = x.size()[:-1] + (self.nheads, self.headdim)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)
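(For illustration, going from that (batch, nheads, seqlen, headdim) layout back to the (batch, seqlen, nheads, headdim) layout flash-attn expects is a single transpose; the sketch below uses assumed example shapes.)

```python
import torch

batch, nheads, seqlen, headdim = 2, 8, 128, 64
# tensor in the (batch, nheads, seqlen, headdim) layout used by e.g. Triton kernels
q_bnsh = torch.randn(batch, nheads, seqlen, headdim, dtype=torch.bfloat16, device="cuda")
# flash-attn layout: (batch, seqlen, nheads, headdim); .contiguous() keeps stride(-1) == 1
q_flash = q_bnsh.transpose(1, 2).contiguous()
```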
But it is weird that the error (I have tested with the latest version) says "FlashAttention only support fp16 and bf16 data type".
// mha_fwd https://github.com/Dao-AILab/flash-attention/blob/6bbc532388e61185a92e2a563126739967b4c8c5/csrc/flash_attn/flash_api.cpp#L339-L339
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
            "FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
    TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
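The same constraints can be checked on the Python side before calling into the extension; here is a minimal sketch (the function name is made up for illustration):

```python
import torch

def check_flash_attn_inputs(q: torch.Tensor) -> None:
    # Mirror the mha_fwd TORCH_CHECKs quoted above.
    major, _minor = torch.cuda.get_device_capability(q.device)
    if major < 8:
        raise RuntimeError("FlashAttention only supports Ampere GPUs or newer.")
    if q.dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError(
            f"FlashAttention only supports fp16 and bf16 inputs, got {q.dtype}"
        )
```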
I have checked the repo; we need to update the C++ templates to support more dtypes (I have experience with near-memory chip op libraries). Currently I have to do these otherwise-unnecessary casts to help teams use flash attention v2:
if q.dtype == torch.float32:
    q = q.to(torch.float16, non_blocking=True)
    k = k.to(torch.float16, non_blocking=True)
    v = v.to(torch.float16, non_blocking=True)
elif q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
    capability = torch.cuda.get_device_capability()
    if capability[0] <= 8:
        raise RuntimeError("Flash attention for FP8 (needs Hopper TE support) is currently only supported for compute capability >= 9.0")
    else:
        # TODO (yiakwy) : add FP8 support
        raise NotImplementedError

output = flash_attn_func(q, k, v, dropout_p=self.dropout.p, causal=is_causal)
output = revert_mold_flash_attn_input(output)

if output_attentions:
    raise Exception("Does not support output attention weights inside flash attention.")

if output.dtype != torch.float32:
    # TODO (yiakwy) : add support for fp16 and bf16
    # if the output dtype is not FP32 (by default flash attention produces FP16 output), we need to cast it back
    output = output.to(torch.float32, non_blocking=True)
So we need to update the error message, right?
I confirm that flash-attn==2.5.6 doesn't work with torch==2.3.0a0+40ec155e58.nv24.3 nightly, even though the inputs are indeed in torch.bfloat16 format! I rolled back to torch 2.2 stable, reinstalled flash-attn, and now it works.