
Introduce OptimizerArgs and add support for GaLore

Open rasbt opened this issue 1 year ago • 13 comments

The current implementation adds GaLore to the full finetuning script.

Example

# regular
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --data Alpaca2k \
  --train.max_steps 5 

# Training time: 14.13s
# Memory used: 3.44 GB



# with galore
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --data Alpaca2k \
  --train.max_steps 5  \
  --galore.use_galore true

# Training time: 23.59s
# Memory used: 3.44 GB



# with 8bit galore
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --data Alpaca2k \
  --train.max_steps 5  \
  --galore.use_galore true \
  --galore.galore_8bit

# Training time: 17.96s
# Memory used: 2.47 GB

Discuss

We could also add it to LoRA:

  • this would require a check that GaLore is only used when QLoRA is disabled
  • that said, it can apparently be combined with some bitsandbytes (bnb) precision settings; according to the GaLore package, this is supported via GaLoreAdamW8bit
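A minimal sketch of the check from the first bullet. It assumes the optimizer names used later in this PR ("galore_adamw", "galore_adamw_8bit") and a hypothetical `quantize` setting that is `None` when QLoRA is off; strings like "bnb.nf4" only stand in for QLoRA modes. The function name is also hypothetical:

```python
from typing import Optional

# Optimizer names as listed in the optim config later in this thread.
GALORE_OPTIMIZERS = {"galore_adamw", "galore_adamw_8bit"}


def check_galore_qlora(optimizer: str, quantize: Optional[str]) -> None:
    """Raise if a GaLore optimizer is combined with QLoRA quantization.

    `quantize` is a hypothetical setting: None means QLoRA is disabled.
    """
    if optimizer in GALORE_OPTIMIZERS and quantize is not None:
        raise ValueError(
            f"GaLore ({optimizer!r}) is not supported together with QLoRA "
            f"(quantize={quantize!r}); disable quantization or use 'adamw'."
        )


check_galore_qlora("galore_adamw", None)  # OK: QLoRA disabled
check_galore_qlora("adamw", "bnb.nf4")    # OK: regular optimizer with QLoRA
```

If some 8-bit quantization modes turn out to be compatible, the condition could be relaxed to skip those modes instead of rejecting every `quantize` value.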

I specified the GaLore args similarly to how we handle the LoRA args. But since GaLore is more of an add-on to existing methods like full and LoRA finetuning, should we make it part of TrainArgs instead?

We could also consider a dedicated subcommand in the future, as we are doing for QLoRA, e.g.:

litgpt finetune full --config ... 

litgpt finetune lora --config ... 

litgpt finetune qlora --config ... [in progress]

litgpt finetune galore --config ... [maybe in future]

Todos

  • [x] Add Galore for full finetuning
  • [x] Check if default args are good
  • [x] Add docstrings
  • [x] Discuss if we use TrainArgs (see above)
  • [x] Add GaLore for LoRA finetuning (investigate NotImplementedError: Cannot merge the pretrained weights of type torch.float16 and LoRA weights of type torch.float32)
  • [x] Throw an error if GaLore and QLoRA are used at the same time and QLoRA is not 8-bit
  • [x] Should we also allow 8-bit GaLore without QLoRA? I'd say yes. How? galore_8bit = True?
  • [x] Update full and lora config files
  • [x] Add galore for pretraining
  • [x] Consider adding it for adapter and adapter v2
  • [ ] Add tests
  • [x] Restrict to single GPU training
  • [x] Add GaLore package to the acknowledgements section
  • [x] Add documentation
  • [x] Add configs YAMLs and benchmarks
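The single-GPU restriction could be enforced with an early check along these lines. This is a sketch only: `devices` stands in for the number of GPUs the run requests, and the function name is hypothetical:

```python
# Optimizer names as listed in the optim config later in this thread.
GALORE_OPTIMIZERS = {"galore_adamw", "galore_adamw_8bit"}


def check_galore_devices(optimizer: str, devices: int) -> None:
    """Raise if a GaLore optimizer is requested for multi-GPU training.

    `devices` is assumed to be the number of GPUs the run will use.
    """
    if optimizer in GALORE_OPTIMIZERS and devices > 1:
        raise NotImplementedError(
            f"{optimizer!r} is currently restricted to single-GPU training "
            f"(got devices={devices})."
        )


check_galore_devices("galore_adamw", 1)  # OK: single GPU
check_galore_devices("adamw", 4)         # OK: regular AdamW on multiple GPUs
```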

Fixes #1075

rasbt • Mar 25 '24 21:03

After our discussion today, I think we should only enable vanilla GaLore for now and not worry about LoRA support. We can look into LoRA support later if there is high demand. I am getting precision-related errors when trying to use it with LoRA, which likely have something to do with the precision the GaLore optimizer uses under the hood. I expect the GaLore package to evolve over the coming weeks and months, and we can then revisit whether LoRA works without us having to make additional tweaks to the GaLore optimizer etc.

rasbt • Mar 26 '24 16:03

@rasbt How much of an improvement in VRAM consumption did you see with LoRA+GaLore? With any PEFT algorithm, the number of parameters to optimize shouldn't be that significant.
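A rough element count illustrates the point. The sketch below uses the per-matrix optimizer-state sizes described in the GaLore paper (for an m×n weight with m ≤ n: an m×r projection plus two r×n Adam moments) and compares them with full AdamW and with AdamW on LoRA's A and B matrices. The 2560×2560 matrix and the ranks (128 for GaLore, 8 for LoRA) are illustrative assumptions, and only elements are counted, not bytes:

```python
def adamw_state_elems(m: int, n: int) -> int:
    """AdamW keeps two moment buffers with the same shape as the m x n weight."""
    return 2 * m * n


def galore_state_elems(m: int, n: int, rank: int) -> int:
    """GaLore: two rank x max(m, n) moments plus a min(m, n) x rank projection."""
    return 2 * rank * max(m, n) + min(m, n) * rank


def lora_adamw_state_elems(m: int, n: int, rank: int) -> int:
    """LoRA trains B (m x rank) and A (rank x n); AdamW keeps two moments for each."""
    return 2 * (m * rank + rank * n)


m = n = 2560  # illustrative square weight matrix
print(adamw_state_elems(m, n))          # 13107200
print(galore_state_elems(m, n, 128))    # 983040
print(lora_adamw_state_elems(m, n, 8))  # 81920
```

Under these assumptions, LoRA's optimizer state is already about 12x smaller than GaLore's, so adding GaLore on top of LoRA leaves little to save.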

Andrei-Aksionov • Mar 26 '24 17:03

The combination of LoRA + GaLore doesn't really work yet due to precision mismatches when merging the LoRA weights at the end, so the script never reached the line that prints the memory usage. I could comment out the merging and try again, but let's just focus on GaLore for full finetuning first. Like you said, I don't expect a big improvement when it's combined with LoRA.

rasbt • Mar 27 '24 00:03

I changed GaloreArgs to OptimizerArgs, and here are some results for phi-2. What's puzzling is the pretraining result: GaLore shows no memory savings there (1.44 GB in all three runs). I couldn't find the issue yet and may need to investigate more. I also need to update the config files once we've settled on the API.
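For reference, a hypothetical sketch of what such an OptimizerArgs dataclass might look like. The field names and defaults mirror the optim section of the YAML config shown later in this thread; the choice validation and the `betas` helper are illustrative additions, not part of litgpt:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

OPTIMIZER_CHOICES = {"adamw", "galore_adamw", "galore_adamw_8bit"}


@dataclass
class OptimizerArgs:
    """Hypothetical optimizer options; defaults follow the YAML comments in this thread."""

    optimizer: str = "adamw"            # "adamw", "galore_adamw", or "galore_adamw_8bit"
    learning_rate: float = 3e-4
    weight_decay: float = 0.02
    beta1: float = 0.9
    beta2: float = 0.95
    extra_kwargs: Optional[str] = None  # e.g. "rank=8,update_proj_gap=200" for GaLore

    def __post_init__(self) -> None:
        # Fail early on unknown optimizer names instead of deep inside the training loop.
        if self.optimizer not in OPTIMIZER_CHOICES:
            raise ValueError(
                f"optimizer must be one of {sorted(OPTIMIZER_CHOICES)}, got {self.optimizer!r}"
            )

    @property
    def betas(self) -> Tuple[float, float]:
        # AdamW-style optimizers take the two betas as a single tuple.
        return (self.beta1, self.beta2)


args = OptimizerArgs(optimizer="galore_adamw", extra_kwargs="rank=8")
print(args.betas)  # (0.9, 0.95)
```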

Full

AdamW

litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 32.76s
# Memory used: 55.84 GB

GaLore

litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 128.55s
# Memory used: 36.14 GB

GaLore 8-bit

litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 128.68s
# Memory used: 33.81 GB
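To make the trade-off explicit, a quick calculation on the phi-2 full-finetuning numbers reported above (nothing new is measured here):

```python
# Numbers copied from the phi-2 full-finetuning runs above: (seconds, GB).
runs = {
    "adamw": (32.76, 55.84),
    "galore_adamw": (128.55, 36.14),
    "galore_adamw_8bit": (128.68, 33.81),
}

base_time, base_mem = runs["adamw"]
summary = {}
for name, (time_s, mem_gb) in runs.items():
    saving_pct = 100 * (base_mem - mem_gb) / base_mem  # memory saved vs. AdamW
    slowdown = time_s / base_time                       # training-time ratio vs. AdamW
    summary[name] = (saving_pct, slowdown)
    print(f"{name:>18}: {saving_pct:4.1f}% less memory, {slowdown:.2f}x time")
```

Over these 5-step runs, GaLore reduces memory by roughly 35 to 40% at close to 4x the training time.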

LoRA

AdamW

litgpt finetune lora \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 36.43s
# Memory used: 18.56 GB

GaLore

litgpt finetune lora \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 25.98s
# Memory used: 18.56 GB

GaLore 8-bit

litgpt finetune lora \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 26.01s
# Memory used: 18.54 GB

Adapter

AdamW

litgpt finetune adapter \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 31.16s
# Memory used: 17.94 GB

GaLore

litgpt finetune adapter \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 24.81s
# Memory used: 17.94 GB

GaLore 8-bit

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 26.36s
# Memory used: 20.10 GB

Adapter v2

AdamW

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 26.35s
# Memory used: 20.11 GB

GaLore

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 26.31s
# Memory used: 20.11 GB

GaLore 8-bit

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 26.26s
# Memory used: 20.10 GB

Pretrain (Pythia 14M)

AdamW

litgpt pretrain \
  --model_name pythia-14m \
  --tokenizer_dir checkpoints/EleutherAI/pythia-14m/ \
  --data TextFiles \
  --data.train_data_path "custom_texts" \
  --train.max_tokens 100_000

# Training time: 34.07s
# Memory used: 1.44 GB

GaLore

litgpt pretrain \
  --model_name pythia-14m \
  --tokenizer_dir checkpoints/EleutherAI/pythia-14m/ \
  --data TextFiles \
  --data.train_data_path "custom_texts" \
  --train.max_tokens 100_000 \
  --optim.optimizer "galore_adamw"

# Training time: 25.31s
# Memory used: 1.44 GB

GaLore 8-bit

litgpt pretrain \
  --model_name pythia-14m \
  --tokenizer_dir checkpoints/EleutherAI/pythia-14m/ \
  --data TextFiles \
  --data.train_data_path "custom_texts" \
  --train.max_tokens 100_000 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 25.31s
# Memory used: 1.44 GB

rasbt • May 03 '24 22:05

I tried many things, even replacing every instance of torch's AdamW with GaLore's to make sure it's actually used, but for some reason I cannot see any difference in memory usage when pretraining. Mind-boggling.

rasbt • May 06 '24 23:05

I replaced the hardcoded GaLore arguments with a general extra_kwargs option so it can be used for other optimizers as well. This way it adds less clutter to the CLI.

So, what's new is that we now have optimizer arguments. For example, this adds the following to the config files:

# Optimizer-related arguments
optim: 
  # Which optimizer to use. Possible choices: "adamw", "galore_adamw", "galore_adamw_8bit". (type: Optional[str], default: "adamw")
  optimizer: "adamw"

  #   (type: float, default: 0.0003)
  learning_rate: 0.0002

  #   (type: float, default: 0.02)
  weight_decay: 0.0

  #   (type: float, default: 0.9)
  beta1: 0.9

  #   (type: float, default: 0.95)
  beta2: 0.95

  # Additional optimizer keyword arguments, for example, "rank=8,update_proj_gap=200" for GaLore. (type: Optional[str], default: None)
  extra_kwargs:
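A minimal sketch of how the extra_kwargs string could be turned into keyword arguments for the optimizer. The function name is hypothetical, and `scale` and `proj_type` are used here as further GaLore options (as listed in the galore-torch README) only to show how non-integer values are handled:

```python
import ast
from typing import Any, Dict, Optional


def parse_extra_kwargs(spec: Optional[str]) -> Dict[str, Any]:
    """Turn a string like "rank=8,update_proj_gap=200" into keyword arguments.

    Values are parsed as Python literals when possible (ints, floats, bools,
    None, quoted strings); anything else is kept as a plain string.
    """
    if not spec:
        return {}
    kwargs: Dict[str, Any] = {}
    for item in spec.split(","):
        key, sep, raw = item.partition("=")
        key, raw = key.strip(), raw.strip()
        if not key or not sep:
            raise ValueError(f"Expected 'key=value', got {item!r}")
        try:
            kwargs[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            kwargs[key] = raw  # e.g. proj_type=std stays the string "std"
    return kwargs


print(parse_extra_kwargs("rank=8,update_proj_gap=200,scale=0.25,proj_type=std"))
# {'rank': 8, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}
```

A comma-separated format like this cannot express values that themselves contain commas (e.g. tuples), which is one limitation of the string-based approach.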

What do you think about this approach and interface @carmocca @lantiga @awaelchli ?

rasbt • May 09 '24 18:05

The jsonargparse-y way of doing this would be to specify which Optimizer class you want and let the parser pull out the arguments of said class. That is exactly how the data is selected and parsed, for example.

carmocca • May 10 '24 12:05

OMG, I made it way more complicated than it needed to be 🤦‍♂️. Thanks for the hint. Now I know.

rasbt • May 10 '24 16:05

After trying this, I realize that this may not be cleanly possible because optimizers require params as a positional argument, so we would have to wrap the optimizer in our own optimizer class. The other problem is the GaLore optimizer, which needs the params to be split into regular params and GaLore params before they are passed in. It gets ugly real quick.

We could probably use this jsonargparse approach for native PyTorch optimizers, but I don't think it will be easy to support GaLore this way without hacks.

I can make a PR with just PyTorch optimizer support, and then we can decide which route we want to go: only supporting PyTorch optimizers, or revisiting the implementation here with our own extra_kwargs parsing.
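The parameter split mentioned above can be sketched in plain Python. The group format (a regular group plus a GaLore group that carries options such as `rank` and `update_proj_gap`) is modeled on the usage example in the galore-torch README; the selection rule (2D weights whose names contain "attn" or "mlp"), the function name, and the parameter names are assumptions for illustration. In a real script, the entries would be tensors from `model.named_parameters()` rather than names and shapes:

```python
from typing import Any, Dict, List, Sequence, Tuple


def split_param_groups(
    named_shapes: Sequence[Tuple[str, Tuple[int, ...]]],
    galore_kwargs: Dict[str, Any],
    target_keys: Sequence[str] = ("attn", "mlp"),
) -> List[Dict[str, Any]]:
    """Build GaLore-style parameter groups from (name, shape) pairs.

    2D weights whose names contain one of `target_keys` go into the GaLore
    group (which carries the GaLore options); everything else stays regular.
    """
    galore, regular = [], []
    for name, shape in named_shapes:
        if len(shape) == 2 and any(key in name for key in target_keys):
            galore.append(name)
        else:
            regular.append(name)
    return [{"params": regular}, {"params": galore, **galore_kwargs}]


# Hypothetical parameter names and shapes for a tiny GPT-style model.
named_shapes = [
    ("transformer.wte.weight", (50304, 128)),
    ("transformer.h.0.norm_1.weight", (128,)),
    ("transformer.h.0.attn.attn.weight", (384, 128)),
    ("transformer.h.0.mlp.fc.weight", (512, 128)),
    ("transformer.h.0.mlp.fc.bias", (512,)),
]
groups = split_param_groups(named_shapes, {"rank": 8, "update_proj_gap": 200})
print(groups[1])
```

Because these groups can only be built once the model exists, the optimizer cannot be instantiated at argument-parsing time, which is the crux of the problem discussed here.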

rasbt • May 10 '24 17:05

Yes, we cannot have jsonargparse instantiate the class directly for that reason.

But you can still tell it to add all the arguments of a class (or classes) into a group of args, basically getting you OptimizerArgs automatically for that class. Then those args can be used to instantiate the real optimizer instance later in the script.

The PyTorch Lightning CLI implementation works that way: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/cli.py#L154-L177
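The underlying idea, reduced to the standard library: collect a class's keyword arguments at parse time, skip `params`, and build the instance later once the parameters exist. This is not how jsonargparse is implemented, only a sketch of that split; `ToyOptimizer` and `init_args_from_signature` are hypothetical, and the class object is stored directly instead of a dotted `class_path` string:

```python
import inspect
from typing import Any, Dict, Sequence


class ToyOptimizer:
    """Stand-in for an optimizer: `params` is positional, the rest are hyperparameters."""

    def __init__(self, params, lr: float = 1e-3, weight_decay: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


def init_args_from_signature(cls: type, skip: Sequence[str] = ("self", "params")) -> Dict[str, Any]:
    """Collect the keyword arguments (with defaults) that a CLI could expose for `cls`."""
    args: Dict[str, Any] = {}
    for name, p in inspect.signature(cls.__init__).parameters.items():
        if name in skip or p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
            continue
        args[name] = None if p.default is p.empty else p.default
    return args


# "Parse" time: only record the class and its hyperparameters, no instance yet.
config = {"class_path": ToyOptimizer, "init_args": init_args_from_signature(ToyOptimizer)}
config["init_args"]["lr"] = 3e-4  # e.g. overridden from the command line

# Training-script time: the parameters now exist, so the optimizer can be built.
params = [0.5, -1.0]
opt = config["class_path"](params, **config["init_args"])
print(opt.lr, opt.weight_decay)  # 0.0003 0.0
```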

carmocca • May 10 '24 17:05

Argh, I am still struggling with this. For example,

litgpt finetune full --optimizer.help torch.optim.AdamW

works without problems, but even if I don't do anything else, jsonargparse already tries to instantiate it via

litgpt finetune full ... --optimizer torch.optim.AdamW

before I can pass it to anything else. Not sure how to avoid that. I think I need to study jsonargparse a bit more, because right now I feel like I am hacking things together ...

rasbt • May 10 '24 18:05

You can start by understanding this minimal example:

import torch
import jsonargparse

parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"})
args = parser.parse_args()
print(args)

python example.py --optimizer Adam
Namespace(optimizer=Namespace(class_path='torch.optim.Adam', init_args=Namespace(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None)))

carmocca • May 10 '24 18:05

And here's how you would use the above to instantiate the optimizer:

from typing import Any, Tuple, Dict, Union

def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.

    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


model = torch.nn.Linear(1, 1)
optimizer = instantiate_class(model.parameters(), init=args["optimizer"])
print(optimizer)

We define instantiate_class for the PyTorch Lightning CLI here: https://github.com/Lightning-AI/pytorch-lightning/blob/90d04b5b86f37994cdceccc6de32f0e93b1cc7f0/src/lightning/pytorch/cli.py#L752-L769

carmocca • May 10 '24 19:05