
WIP: galore optimizer

Open maximegmd opened this issue 3 months ago • 2 comments

Adds support for GaLore optimizers.

Still a WIP, untested.

maximegmd avatar Mar 07 '24 11:03 maximegmd

@maximegmd any chance you could provide an example config file on how to use this?

fakerybakery avatar Mar 07 '24 23:03 fakerybakery

@maximegmd any chance you could provide an example config file on how to use this?

Set the optimizer argument in the axolotl config to one of [galore_adamw, galore_adamw_8bit, galore_adafactor]. galore_adamw_8bit will probably give the largest memory savings.
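
For example, a minimal excerpt of an axolotl YAML config could look roughly like this. This is only a sketch and is untested against this WIP branch; the optim_target_modules key follows the transformers-side naming and may not exist here yet, and the rest of the file is the usual base model / dataset setup:

# hypothetical axolotl config excerpt for GaLore (adjust keys to whatever this PR exposes)
optimizer: galore_adamw_8bit
optim_target_modules:
  - attn
  - mlp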

casper-hansen avatar Mar 08 '24 15:03 casper-hansen

Hi! I tried to upstream these changes into transformers so that you can leverage it directly in axolotl: https://github.com/huggingface/transformers/pull/29588. I am running some quick experiments; so far the training seems quite slow. Here is how I am running the training using GaLore:

import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",                    # GaLore AdamW from the transformers PR
    galore_target_modules=["attn", "mlp"],   # apply GaLore to the attention and MLP weights
                                             # (merged into transformers as `optim_target_modules`)
    gradient_checkpointing=True,
)

# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Randomly initialized from the config (no pretrained weights), moved to GPU 0
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
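
Note that this relies on the external galore-torch package (the transformers integration wraps its optimizers), so it most likely needs a pip install galore-torch, plus reasonably recent datasets and trl, before the script above will run.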

younesbelkada avatar Mar 11 '24 12:03 younesbelkada

Got it working on Gemma-2b!

import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore-new",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    galore_target_modules=["attn", "mlp"],
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=5,
    learning_rate=2e-3,
    save_strategy="no",
    run_name="galore-imdb",
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Randomly initialized from the config (no pretrained weights), moved to GPU 0
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    tokenizer=tokenizer,   # pass the tokenizer loaded above instead of leaving it unused
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()

After ~50 steps:

{'loss': 11.8705, 'grad_norm': 13.43569564819336, 'learning_rate': 0.0019, 'epoch': 0.0}                                                                
{'loss': 9.8208, 'grad_norm': 7.467105865478516, 'learning_rate': 0.0018000000000000002, 'epoch': 0.0}                                                  
{'loss': 8.606, 'grad_norm': 6.2992963790893555, 'learning_rate': 0.0017, 'epoch': 0.0}                                                                 
{'loss': 7.8436, 'grad_norm': 5.3465986251831055, 'learning_rate': 0.0016, 'epoch': 0.0}                                                                
{'loss': 7.6177, 'grad_norm': 6.2392964363098145, 'learning_rate': 0.0015, 'epoch': 0.0}                                                                
{'loss': 7.5346, 'grad_norm': 4.487287998199463, 'learning_rate': 0.0014, 'epoch': 0.0}                                                                 
{'loss': 7.6909, 'grad_norm': 4.615128517150879, 'learning_rate': 0.0013000000000000002, 'epoch': 0.0}                                                  
{'loss': 7.0826, 'grad_norm': 5.807451248168945, 'learning_rate': 0.0012, 'epoch': 0.0}                                                                 
{'loss': 7.1936, 'grad_norm': 3.470165729522705, 'learning_rate': 0.0011, 'epoch': 0.0}                                                                 
{'loss': 7.1926, 'grad_norm': 4.511063575744629, 'learning_rate': 0.001, 'epoch': 0.0}  

Using a single A100 80GB, the loss seems to converge nicely. It is expected that the optimizer takes some time to initialize itself at the start of training.
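
For intuition on why startup takes a while: GaLore derives a low-rank projector from an SVD of each targeted weight's gradient and keeps the Adam statistics in that smaller subspace, so the first optimizer step pays for one SVD per targeted matrix. A rough conceptual sketch of that projection step (not the actual galore-torch code; the rank value and the missing update scaling here are simplifications):

import torch

def galore_like_update(grad: torch.Tensor, rank: int = 128) -> torch.Tensor:
    # The SVD here is the expensive part that runs whenever the optimizer
    # (re)builds its projector, which is why initialization is slow on large models.
    U, _, _ = torch.linalg.svd(grad, full_matrices=False)
    P = U[:, :rank]                # (m, r) low-rank projector
    low_rank_grad = P.T @ grad     # (r, n) gradient the optimizer state actually tracks
    # ... Adam moments would be updated on low_rank_grad here ...
    return P @ low_rank_grad       # project the (toy) update back to the full shape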

younesbelkada avatar Mar 11 '24 13:03 younesbelkada

@younesbelkada I tried your gemma code and faced the following error:

[screenshot of the error attached]

savanth14 avatar Mar 13 '24 07:03 savanth14

Thanks @younesbelkada! I'll open up another PR with just the validation and training args pieces and wait for the upstream integration. Much appreciated!

winglian avatar Mar 16 '24 02:03 winglian

thanks so much @winglian !

younesbelkada avatar Mar 17 '24 10:03 younesbelkada

Superseded by #1409. Thanks for getting this rolling @maximegmd. Props to @younesbelkada for getting this working upstream in transformers.

winglian avatar Mar 19 '24 16:03 winglian