OOM With GaLore on Axolotl
Please check that this issue hasn't been reported before.
- [X] I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
Training should start without OOM, as it does with LLaMA Factory.
Current behaviour
Axolotl hits OOM with my config, while LLaMA Factory works fine on the same GPU. With LLaMA Factory I could train at 16-bit, rank 1024, and 8k context without issue; with axolotl I run out of memory even at 8-bit, rank 128, and 4k context.
I have tried:
- Enabling gradient checkpointing
- Disabling sample packing
- Lowering the rank (see the rough memory estimate sketched after this list)
- Enabling use_reentrant
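As a rough sanity check (my own back-of-envelope arithmetic, not anything measured inside axolotl, and assuming fp32 optimizer state shaped like the full gradient in the non-GaLore case), rank 128 should shrink the per-matrix optimizer state substantially, so the GaLore state alone doesn't look like it should fill a 48 GB card:

```python
# Rough estimate only: optimizer-state size for one 4096 x 4096 weight matrix
# (e.g. q_proj in Mistral-7B). Assumes fp32 state, one state tensor shaped like
# the gradient for the full-rank case, and GaLore keeping state only for the
# rank-r projected gradient plus an (m x r) projection matrix.

def full_state_mib(m: int, n: int, bytes_per: int = 4) -> float:
    return m * n * bytes_per / 2**20

def galore_state_mib(m: int, n: int, r: int, bytes_per: int = 4) -> float:
    return (r * n + m * r) * bytes_per / 2**20

m = n = 4096  # q_proj shape in Mistral-7B
print(f"full-rank state : {full_state_mib(m, n):6.1f} MiB")        # ~64 MiB
print(f"GaLore, rank 128: {galore_state_mib(m, n, 128):6.1f} MiB")  # ~4 MiB
```

If that arithmetic is roughly right, activations/gradients or something specific to axolotl's GaLore path look more likely culprits than the projected optimizer state itself.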
Steps to reproduce
- Install GaLore: `pip install galore-torch`
- Run the config posted below
- Make sure the GPU is an A6000 (48 GB)
Config yaml
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: Walmart-the-bag/alpaca-ingen
    type:
      field_instruction: instruction
      field_output: output
      format: "\n### Instruction:\n{instruction}\n### Response:\n"
      no_input_format: "\n### Instruction:\n{instruction}\n### Response:\n"
dataset_prepared_path:
val_set_size: 0.05
output_dir: /notebooks/output
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - q_proj
  - v_proj
  - linear
  - mlp
  - attn
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: galore_adafactor
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 0
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
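For reference on what `optim_args` and `optim_target_modules` are meant to control, here is a minimal sketch of how those values get passed to galore-torch when the optimizer is built by hand. This follows the pattern in the galore-torch README with a tiny stand-in model, not axolotl's actual integration; the module matching below is simplified guesswork, and I'm assuming the Adafactor variant accepts the same per-group keys as the `GaLoreAdamW` shown here:

```python
import torch.nn as nn
from galore_torch import GaLoreAdamW  # pip install galore-torch

# Tiny stand-in model; in the real run this would be the loaded Mistral-7B.
model = nn.ModuleDict({
    "q_proj": nn.Linear(4096, 4096, bias=False),
    "v_proj": nn.Linear(4096, 1024, bias=False),
    "lm_head": nn.Linear(4096, 32000, bias=False),
})

# Crude stand-in for optim_target_modules matching: only 2D weights whose
# name contains one of the target substrings get GaLore projection.
targets = ("q_proj", "v_proj", "linear", "mlp", "attn")
galore_params, regular_params = [], []
for name, p in model.named_parameters():
    if p.ndim == 2 and any(t in name for t in targets):
        galore_params.append(p)
    else:
        regular_params.append(p)

# The per-group values below mirror optim_args from the config above.
param_groups = [
    {"params": regular_params},
    {"params": galore_params,
     "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
optimizer = GaLoreAdamW(param_groups, lr=2e-4)
```

I'm including this only to show which knobs rank/update_proj_gap/scale/proj_type correspond to; as far as I can tell, axolotl builds the optimizer through the HF Trainer's `optim: galore_adafactor` path rather than anything like the code above.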
Possible solution
No response
Which Operating Systems are you using?
- [X] Linux
- [ ] macOS
- [ ] Windows
Python Version
3.11
axolotl branch-commit
main
Acknowledgements
- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.