
OOM with GaLore on Axolotl

m626zNq opened this issue 10 months ago · 13 comments

Please check that this issue hasn't been reported before.

  • [X] I searched previous bug reports and didn't find any similar reports.

Expected Behavior

Training should start without OOM, as it does in LLaMA Factory.

Current behaviour

Training OOMs on axolotl with my config, while LLaMA Factory runs fine on the same GPU. In LLaMA Factory I could train in 16-bit with rank 1024 and 8k context without issue; axolotl won't even run with 8-bit, rank 128, and 4k context before going out of memory.

I have tried the following (sketched as config toggles after this list):

  • Enabling gradient checkpointing
  • Disabling sample packing
  • Lowering rank
  • Enabling use_reentrant
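Roughly, those attempts map to these toggles relative to the config below (the rank shown here is just an example of a lower value, not necessarily the exact one I used):

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
sample_packing: false          # also tried with packing disabled
optim_args:
  rank: 64                     # example of a lower rank than the 128 in the config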

Steps to reproduce

  • Install GaLore: pip install galore-torch
  • Run the config posted below
  • Make sure the GPU is an A6000 (48 GB)
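Concretely, the commands are roughly the following; I launch through axolotl's usual accelerate entry point, and the config file name is just a placeholder:

pip install galore-torch
accelerate launch -m axolotl.cli.train galore-oom.yml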

Config yaml

base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: Walmart-the-bag/alpaca-ingen
    type:
      field_instruction: instruction
      field_output: output
      format: "\n### Instruction:\n{instruction}\n### Response:\n"
      no_input_format: "\n### Instruction:\n{instruction}\n### Response:\n"
dataset_prepared_path:
val_set_size: 0.05
output_dir: /notebooks/output

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - q_proj
  - v_proj
  - linear
  - mlp
  - attn
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: galore_adafactor
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 10
evals_per_epoch: 0
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

  • [X] Linux
  • [ ] macOS
  • [ ] Windows

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

  • [X] My issue title is concise, descriptive, and in title casing.
  • [X] I have searched the existing issues to make sure this bug has not been reported yet.
  • [X] I am using the latest version of axolotl.
  • [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.

m626zNq · Mar 27, 2024