
mistral fsdp qlora crashes (cu_seqlens)

Open · lucyknada opened this issue on Mar 10, 2024 · 2 comments

Please check that this issue hasn't been reported before.

  • [X] I searched previous Bug Reports and didn't find any similar reports.

Current behaviour

crashes with:

TypeError: MistralSdpaAttention.forward() got an unexpected keyword argument 'cu_seqlens'
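For context, cu_seqlens is the cumulative-sequence-lengths tensor that the flash-attention varlen kernels use for sample packing; the stock MistralSdpaAttention.forward() has no such parameter, so forwarding it raises the TypeError above. A minimal, purely illustrative Python sketch of that failure mode (the function names below are hypothetical stand-ins, not axolotl's actual code):

def sdpa_forward(hidden_states, attention_mask=None):
    # stand-in for MistralSdpaAttention.forward(), whose signature has no cu_seqlens
    return hidden_states

def packed_forward(attn_forward, hidden_states, cu_seqlens=None):
    # stand-in for a sample-packing patch that always forwards cu_seqlens
    return attn_forward(hidden_states, cu_seqlens=cu_seqlens)

try:
    packed_forward(sdpa_forward, hidden_states=[0.0], cu_seqlens=[0, 4, 8])
except TypeError as exc:
    print(exc)  # sdpa_forward() got an unexpected keyword argument 'cu_seqlens'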

Steps to reproduce

docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ./cache:/cache/huggingface winglian/axolotl:main-py3.11-cu121-2.1.2

Run that docker image, then modify the mistral qlora example to add:

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  # fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer

(it still crashes regardless of which of these options are set)
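Training is then launched inside the container with the standard axolotl entrypoint, e.g. (the config path here is a placeholder for wherever the modified example lives):

accelerate launch -m axolotl.cli.train examples/mistral/qlora.yml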

Config yaml

base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

eval_sample_packing: false
datasets:
  - path: /workspace/axolotl/xxx.json
    type: completion
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
  - full_shard
  # - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  # fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

  • [X] Linux
  • [ ] macOS
  • [ ] Windows

Python Version

3.10, 3.11

axolotl branch-commit

9b6ee83a73d5ffbdc33cfb383a131a08c2b594ff

Acknowledgements

  • [X] My issue title is concise, descriptive, and in title casing.
  • [X] I have searched the existing issues to make sure this bug has not been reported yet.
  • [X] I am using the latest version of axolotl.
  • [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.

lucyknada · Mar 10 '24 20:03

Can you verify that flash attention is installed?

winglian · Mar 10 '24 23:03

# pip show flash-attn
Name: flash-attn
Version: 2.5.5
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: [email protected]
License: 
Location: /root/miniconda3/envs/py3.11/lib/python3.11/site-packages
Requires: einops, ninja, packaging, torch
Required-by:

Seems like it is, at least inside the docker image.
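For a check that goes beyond the package metadata, importing the varlen kernel directly should also tell us whether the extension actually loads inside the container, e.g.:

python -c "from flash_attn import flash_attn_varlen_func; print(flash_attn_varlen_func)"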

lucyknada · Mar 10 '24 23:03