How do we turn off Flash attention in LLaVA-NeXT?
Since my server does not have an Ampere (or newer) GPU, I have been trying to disable FlashAttention.
First, I simply copied the train_xformers.py and llama_xformers_attn_monkey_patch.py files into my directory so that I can train with xformers instead of train_mem.py on LLaVA-NeXT as well.
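For reference, the launcher I copied boils down to applying the monkey patch before the trainer module is imported, roughly like this (paraphrasing the LLaVA version from memory, so details may differ from my local copy):

# train_xformers.py (sketch): swap the LLaMA attention forward for the
# xformers memory-efficient implementation before transformers builds the model.
from llava.train.llama_xformers_attn_monkey_patch import (
    replace_llama_attn_with_xformers_attn,
)

replace_llama_attn_with_xformers_attn()

from llava.train.train import train  # imported after patching on purpose

if __name__ == "__main__":
    train()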
Second, I removed the 'attn_implementation' argument to disable FlashAttention entirely (see the sketch just below for why I suspect this is not enough).
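My understanding is that with recent transformers the attention backend is fixed when the model is loaded, and the LLaVA-NeXT training arguments appear to default attn_implementation to "flash_attention_2", so omitting the flag may simply keep that default. A minimal sketch of what forcing a non-flash backend at load time would look like (placeholder model name; I have not verified this against the exact LLaVA-NeXT loading code):

from transformers import AutoModelForCausalLM

# Both "sdpa" (PyTorch scaled_dot_product_attention) and "eager" avoid the
# flash-attn CUDA kernels, so neither requires an Ampere GPU.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model name
    attn_implementation="sdpa",
)

So passing --attn_implementation sdpa explicitly, rather than removing the flag, might behave differently, but I have not been able to confirm this.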
However, neither change prevents the "RuntimeError: FlashAttention only supports Ampere GPUs or newer." error.
Does anyone know what else I can try to disable FlashAttention?
I will share my shell script for fine-tuning with SigLIP (A4) below.
LLM_VERSION="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
############### Pretrain ################
PROMPT_VERSION=plain
BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
CKPT_PATH=$LLM_VERSION
deepspeed llava/train/train_xformers.py \
--lora_enable True --lora_r 16 --lora_alpha 256 --mm_projector_lr 2e-5 \
--deepspeed scripts/zero3_offload_new.json \
--model_name_or_path ${CKPT_PATH} \
--version ${PROMPT_VERSION} \
--data_path ./playground/floorplan_vqa_1000.json \
--image_folder /home/work/testdataset1/LLaVA/playground/data/floorplan_data/ \
--pretrain_mm_mlp_adapter="/home/work/testdataset1/LLaVA-NeXT/checkpoints/projectors/llavanext-google_siglip-so400m-patch14-384-mylesgoose_Meta-Llama-3.1-8B-Instruct-goose-abliterated-mlp2x_gelu-pretrain_blip558k_plain/checkpoint-1500/mm_projector.bin" \
--mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
--mm_vision_tower_lr=2e-6 \
--vision_tower ${VISION_MODEL_VERSION} \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--group_by_modality_length True \
--image_aspect_ratio anyres \
--image_grid_pinpoints "[(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]" \
--mm_patch_merge_type spatial_unpad \
--fp16 True \
--bf16 False \
--output_dir "./checkpoints/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre" \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 2 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 1024 \
--gradient_checkpointing True \
--dataloader_num_workers 0 \
--lazy_preprocess True \
--report_to wandb \
--torch_compile False \
--torch_compile_backend "inductor" \
--dataloader_drop_last True \
--run_name llavanext-siglip-400m-Meta-Llama-3.1-8B-pretrain_blip558k_plain
Error message:
/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
  return fn(*args, **kwargs)

Traceback (most recent call last):
  File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_xformers.py", line 13, in <module>
    train()
  File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_5img.py", line 1672, in train
    trainer.train()
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
    return self.base_model(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/work/testdataset1/LLaVA-NeXT/llava/model/language_model/llava_llama.py", line 109, in forward
    return super().forward(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1139, in forward
    outputs = self.model(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 930, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
    outputs = run_function(*args)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 500, in forward
    attn_output = _flash_attention_forward(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 214, in _flash_attention_forward
    attn_output = flash_attn_func(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
    return FlashAttnFunc.apply(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.

The same traceback is raised on every rank (rank 0 and rank 1).
File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_xformers.py", line 13, in train() File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_5img.py", line 1672, in train trainer.train() File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train return inner_training_loop( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop tr_loss_step = self.training_step(model, inputs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step loss = self.compute_loss(model, inputs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss outputs = model(**inputs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn ret_val = func(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward loss = self.module(*inputs, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward return self.base_model( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward return self.model.forward(*args, **kwargs) File "/home/work/testdataset1/LLaVA-NeXT/llava/model/language_model/llava_llama.py", line 109, in forward return super().forward( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1139, in forward outputs = self.model( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 930, in forward layer_outputs = self._gradient_checkpointing_func( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner return disable_fn(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn return fn(*args, **kwargs) File 
"/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint return CheckpointFunction.apply(function, preserve, *args) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply return super().apply(*args, **kwargs) # type: ignore[misc] File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward outputs = run_function(*args) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 677, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 500, in forward attn_output = _flash_attention_forward( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 214, in _flash_attention_forward attn_output = flash_attn_func( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func return FlashAttnFunc.apply( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply return super().apply(*args, **kwargs) # type: ignore[misc] File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( RuntimeError: FlashAttention only supports Ampere GPUs or newer. 
[rank0]: Traceback (most recent call last): [rank0]: File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_xformers.py", line 13, in [rank0]: train() [rank0]: File "/home/work/testdataset1/LLaVA-NeXT/llava/train/train_5img.py", line 1672, in train [rank0]: trainer.train() [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train [rank0]: return inner_training_loop( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop [rank0]: tr_loss_step = self.training_step(model, inputs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step [rank0]: loss = self.compute_loss(model, inputs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss [rank0]: outputs = model(**inputs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn [rank0]: ret_val = func(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward [rank0]: loss = self.module(*inputs, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl [rank0]: result = forward_call(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward [rank0]: return self.base_model( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl [rank0]: result = forward_call(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward [rank0]: return self.model.forward(*args, **kwargs) [rank0]: File "/home/work/testdataset1/LLaVA-NeXT/llava/model/language_model/llava_llama.py", line 109, in forward [rank0]: return super().forward( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1139, in forward [rank0]: outputs = self.model( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl [rank0]: result = forward_call(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 
930, in forward [rank0]: layer_outputs = self._gradient_checkpointing_func( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner [rank0]: return disable_fn(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn [rank0]: return fn(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint [rank0]: return CheckpointFunction.apply(function, preserve, *args) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply [rank0]: return super().apply(*args, **kwargs) # type: ignore[misc] [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward [rank0]: outputs = run_function(*args) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl [rank0]: result = forward_call(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 677, in forward [rank0]: hidden_states, self_attn_weights, present_key_value = self.self_attn( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl [rank0]: result = forward_call(*args, **kwargs) [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 500, in forward [rank0]: attn_output = _flash_attention_forward( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 214, in _flash_attention_forward [rank0]: attn_output = flash_attn_func( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func [rank0]: return FlashAttnFunc.apply( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply [rank0]: return super().apply(*args, **kwargs) # type: ignore[misc] [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward [rank0]: out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( [rank0]: File "/home/work/anaconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward [rank0]: out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( [rank0]: RuntimeError: FlashAttention only supports Ampere GPUs or newer.
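In case it is useful, a quick way to check whether a GPU meets the flash-attn requirement (compute capability 8.0 or higher, i.e. Ampere):

import torch

# Ampere starts at compute capability 8.0; flash_attn refuses anything older.
major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability {major}.{minor}; flash-attn usable: {major >= 8}")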
I send my appreciation to "mylesgoose" for providing me with his pretrained model.