litgpt

OptimizerArgs

Open • rasbt opened this issue 9 months ago • 4 comments

This PR unbundles the OptimizerArgs approach from GaLore in #1192.

Todos

  • [ ] OptimizerArgs for full finetuning

    • [ ] Update code
    • [ ] Add docstrings
    • [ ] Update docs
    • [ ] Update config files
  • [ ] OptimizerArgs for LoRA

    • [ ] Update code
    • [ ] Add docstrings
    • [ ] Update docs
    • [ ] Update config files
  • [ ] OptimizerArgs for Adapter

    • [ ] Update code
    • [ ] Add docstrings
    • [ ] Update docs
    • [ ] Update config files
  • [ ] OptimizerArgs for Adapter v2

    • [ ] Update code
    • [ ] Add docstrings
    • [ ] Update docs
    • [ ] Update config files
  • [ ] ensure that both --optimizer torch.optim.AdamW and --optimizer AdamW work

  • [ ] Add tests

rasbt • May 10 '24 20:05

Your jsonargparse example has been super helpful for understanding things a bit better, @carmocca. Many thanks for this!

Maybe it's because it's Friday evening, but my brain just isn't working today. I've been banging my head against how to get this into the finetuning method's setup.

Adding the optimizer subclass to the parser and then calling the finetune command yields the following error:

  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt/litgpt/__main__.py", line 145, in main
    fn(**kwargs)
TypeError: setup() got an unexpected keyword argument 'optimizer.class_path'

But we can't add an optimizer argument to the finetuning setup signature, because then the argument would be registered twice. Conceptually, I'm kind of stuck here.

Also, how would we get the args into

optimizer = instantiate_class(model.parameters(), init=args["optimizer"])

if we don't pass them on from the main() function in __main__.py? I suppose I could do

import jsonargparse
import torch

parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"})
args = parser.parse_args()

in the finetuning script, but then it would erase all the previously parsed arguments.

rasbt • May 10 '24 20:05

As far as integrating into the scripts, I would:

  • Create an optimizer argument in https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/finetune/lora.py#L40
  • To avoid the duplicate registration, you need to skip it when the function arguments are added: https://github.com/omni-us/jsonargparse/blob/2de15ddfb1c02c2f7b3fe913ad11f13c5cb65dff/jsonargparse/_signatures.py#L166 and https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/main.py#L121
  • Call instantiate_class here: https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/finetune/lora.py#L185-L187

This should be enough to unblock you. The not-so-nice thing is that the CLI args structure leaks into the actual script, meaning that users who don't go through the CLI will have to create this class_path/init_args dictionary manually.
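For users who skip the CLI, that dictionary is just a class path plus constructor kwargs. The sketch below uses a hypothetical helper, instantiate_from_config, that behaves roughly like the instantiate_class call above (it is not that function). To stay runnable without PyTorch it instantiates collections.deque, whose (iterable, maxlen=...) signature mirrors an optimizer's (params, lr=...); in litgpt the class path would be something like torch.optim.AdamW and the positional argument model.parameters().

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(args: Any, init: Dict[str, Any]) -> Any:
    """Hypothetical helper: import init["class_path"] and call it with args and init_args."""
    # A single positional argument (such as model.parameters()) is wrapped in a tuple.
    if not isinstance(args, tuple):
        args = (args,)
    module_name, class_name = init["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(*args, **init.get("init_args", {}))


# The dictionary a non-CLI user would have to write by hand.
optimizer_config: Dict[str, Any] = {
    "class_path": "collections.deque",  # in litgpt this would be e.g. "torch.optim.AdamW"
    "init_args": {"maxlen": 2},  # in litgpt this would be e.g. {"lr": 1e-3}
}

result = instantiate_from_config([1, 2, 3], init=optimizer_config)
print(result)  # deque([2, 3], maxlen=2)
```

The leak carmocca mentions is visible here: a Python caller has to know the class_path/init_args layout rather than passing an optimizer object or a few plain kwargs.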

carmocca • May 13 '24 10:05

Awesome, thanks so much, this was a great help! I figured it out and got it working. Many thanks; I learned something new again!

rasbt • May 14 '24 12:05

I now have it working as follows:

litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m 

# Specify optimizer and optimizer args:
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --optimizer torch.optim.SGD \
  --optimizer.init_args.lr 1000

But the way I'm passing the optimizer kwargs feels a bit hacky. Is there a built-in or better way to handle this, @carmocca? The thing is that when I pass an --optimizer argument, it also passes the optimizer's init args to setup as additional flattened kwargs:

kwargs = {
    'optimizer.class_path': 'torch.optim.SGD',
    'optimizer.init_args.dampening': 0.0,
    'optimizer.init_args.differentiable': False,
    'optimizer.init_args.foreach': None,
    'optimizer.init_args.lr': 0.001,
    'optimizer.init_args.maximize': False,
    'optimizer.init_args.momentum': 0.0,
    'optimizer.init_args.nesterov': False,
    'optimizer.init_args.weight_decay': 0.0
}

That's why I added this step to parse them into class_path and init_args:

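    optimizer_class_path = None
    optimizer_init_args = {}
    # Iterate over a copy so keys can be deleted from kwargs inside the loop
    for key, value in list(kwargs.items()):
        if key.startswith("optimizer"):
            if "class_path" in key:
                # e.g. "optimizer.class_path" -> "torch.optim.SGD"
                optimizer_class_path = value
            elif "init_args" in key:
                # e.g. "optimizer.init_args.lr" -> init arg "lr"
                init_arg_key = key.split(".")[-1]
                optimizer_init_args[init_arg_key] = value
            # Remove the flattened key so setup() doesn't receive it
            del kwargs[key]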

Everything seems to work, but I wonder whether there's a better way to do it.
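A more generic alternative to the loop above, offered as a sketch rather than what was eventually committed, is to pop every key under a prefix and rebuild the nested class_path/init_args structure from the full dotted path. pop_nested is a hypothetical helper name.

```python
from typing import Any, Dict


def pop_nested(kwargs: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Remove all 'prefix.a.b' keys from kwargs and return them as a nested dict."""
    nested: Dict[str, Any] = {}
    # Materialize the matching keys first so kwargs can be mutated while looping.
    for key in [k for k in kwargs if k.startswith(prefix + ".")]:
        value = kwargs.pop(key)
        # "optimizer.init_args.lr" -> parents ["init_args"], leaf "lr"
        *parents, leaf = key[len(prefix) + 1:].split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


kwargs = {
    "checkpoint_dir": "checkpoints/EleutherAI/pythia-160m",
    "optimizer.class_path": "torch.optim.SGD",
    "optimizer.init_args.lr": 0.001,
    "optimizer.init_args.momentum": 0.0,
}
optimizer = pop_nested(kwargs, "optimizer")
print(optimizer)
# {'class_path': 'torch.optim.SGD', 'init_args': {'lr': 0.001, 'momentum': 0.0}}
print(kwargs)
# {'checkpoint_dir': 'checkpoints/EleutherAI/pythia-160m'}
```

Matching on "optimizer." (with the dot) avoids accidentally catching an unrelated key that merely starts with "optimizer", and rebuilding from the full path keeps any deeper nesting intact instead of keeping only the last segment.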

rasbt • May 14 '24 14:05

@rasbt I pushed a commit with what I would suggest. The str code path could be improved if we want to expose arguments like the learning rate outside of the CLI, but that should be straightforward to implement.

Also, FYI: you don't need to specify the .init_args substring on the command line; for example, --optimizer.lr works in place of --optimizer.init_args.lr.
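As a hypothetical sketch of how the str code path could expose arguments such as the learning rate outside the CLI, the helper below accepts either a class name or a class_path/init_args dict, normalizes both to the dict form, and merges extra keyword arguments into init_args. The name normalize_optimizer_arg and the torch.optim fallback module are assumptions, not litgpt's API; prepending the module also illustrates the "AdamW" vs "torch.optim.AdamW" to-do item at the string level.

```python
from typing import Any, Dict, Union


def normalize_optimizer_arg(
    optimizer: Union[str, Dict[str, Any]],
    default_module: str = "torch.optim",  # assumption: where bare names are looked up
    **init_args: Any,
) -> Dict[str, Any]:
    """Turn 'AdamW', 'torch.optim.AdamW', or a class_path dict into one dict form."""
    if isinstance(optimizer, str):
        # A bare name such as "AdamW" gets the default module prepended.
        class_path = optimizer if "." in optimizer else f"{default_module}.{optimizer}"
        config: Dict[str, Any] = {"class_path": class_path, "init_args": {}}
    else:
        # Copy so the caller's dict is not mutated.
        config = {
            "class_path": optimizer["class_path"],
            "init_args": dict(optimizer.get("init_args", {})),
        }
    # Keyword arguments (e.g. lr=...) override whatever init_args already contains.
    config["init_args"].update(init_args)
    return config


print(normalize_optimizer_arg("AdamW", lr=2e-4))
# {'class_path': 'torch.optim.AdamW', 'init_args': {'lr': 0.0002}}
```

Keeping everything in the dict form means a single instantiate step can serve both CLI and Python callers.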

carmocca • May 21 '24 17:05

The only caveat now is that the class path still needs to be specified; specifying only the learning rate doesn't work:

litgpt finetune full --optimizer.lr 200 --checkpoint_dir checkpoints/EleutherAI/pythia-160m

error: Parser key "optimizer":
  Not a valid subclass of Optimizer. Got value: NestedArg(key='lr', val='200')
  Subclass types expect one of:
  - a class path (str)
  - a dict with class_path entry
  - a dict without class_path but with init_args entry (class path given previously)

So the optimizer always needs to be specified explicitly:

litgpt finetune full --optimizer AdamW --optimizer.lr 200 --checkpoint_dir checkpoints/EleutherAI/pythia-160m

Do you know if that's a jsonargparse thing, @carmocca? Since we already set a default value in the setup method, this seems a bit odd to me.
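The error output lists the forms jsonargparse accepts, including "a dict without class_path but with init_args entry (class path given previously)", which suggests the default in the setup() signature is not treated as a previously given class path. The merge the thread expects could look like the hypothetical sketch below; DEFAULT_OPTIMIZER (AdamW with lr=1e-3) is an assumed stand-in for whatever default setup() declares, and merge_with_default is neither a jsonargparse nor a litgpt function.

```python
from typing import Any, Dict, Optional

# Hypothetical default, standing in for the one declared in setup().
DEFAULT_OPTIMIZER: Dict[str, Any] = {"class_path": "torch.optim.AdamW", "init_args": {"lr": 1e-3}}


def merge_with_default(
    user: Optional[Dict[str, Any]], default: Dict[str, Any] = DEFAULT_OPTIMIZER
) -> Dict[str, Any]:
    """Fill in class_path (and unspecified init_args) from the default."""
    if not user:
        return {"class_path": default["class_path"], "init_args": dict(default["init_args"])}
    if "class_path" in user and user["class_path"] != default["class_path"]:
        # A different optimizer: the default's init_args may not apply, so don't inherit them.
        return {"class_path": user["class_path"], "init_args": dict(user.get("init_args", {}))}
    # Same (or omitted) class path: user-supplied init_args override the default ones.
    merged_args = {**default["init_args"], **user.get("init_args", {})}
    return {"class_path": default["class_path"], "init_args": merged_args}


print(merge_with_default({"init_args": {"lr": 200.0}}))
# {'class_path': 'torch.optim.AdamW', 'init_args': {'lr': 200.0}}
```

Dropping the default's init_args when the class changes is a deliberate choice here, since keyword arguments valid for one optimizer may be rejected by another.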

rasbt • May 21 '24 18:05

I hope this is ready now, @carmocca.

rasbt • May 22 '24 23:05

The Azure failure does look real:

>       fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval, optimizer)
E       TypeError: fit() takes 9 positional arguments but 10 were given

/__w/6/s/extensions/thunder/pretrain.py:229: TypeError
----------------------------- Captured stderr call -----------------------------
Missing logger folder: /tmp/pytest-of-root/pytest-0/test_pretrain0/out/logs/tensorboard
Seed set to 42
=========================== short test summary info ============================
FAILED tests/test_thunder_pretrain.py::test_pretrain - TypeError: fit() takes 9 positional arguments but 10 were given

carmocca • May 23 '24 15:05

It does. Let me investigate ...

rasbt • May 23 '24 15:05

Should be fixed for good now, @carmocca. I can switch the link to the original TinyStories now that you've seen the green checks haha 😆

rasbt • May 23 '24 16:05