OptimizerArgs
This PR unbundles the OptimizerArgs approach from GaLore in #1192.
Todos
- [ ] OptimizerArgs for full finetuning
  - [ ] Update code
  - [ ] Add docstrings
  - [ ] Update docs
  - [ ] Update config files
- [ ] OptimizerArgs for LoRA
  - [ ] Update code
  - [ ] Add docstrings
  - [ ] Update docs
  - [ ] Update config files
- [ ] OptimizerArgs for Adapter
  - [ ] Update code
  - [ ] Add docstrings
  - [ ] Update docs
  - [ ] Update config files
- [ ] OptimizerArgs for Adapter v2
  - [ ] Update code
  - [ ] Add docstrings
  - [ ] Update docs
  - [ ] Update config files
- [ ] Ensure that both --optimizer torch.optim.AdamW and --optimizer AdamW work
- [ ] Add tests
Your jsonargparse example has been super helpful for understanding things a bit more, @carmocca. Many thanks for this!
Maybe it's because it's Friday evening, but my brain is just not working today. I've been banging my head against how I would get this into the finetuning method's setup.
Adding the optimizer subclass to the parser and then calling the finetune command yields:
File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
sys.exit(main())
File "/teamspace/studios/this_studio/litgpt/litgpt/__main__.py", line 145, in main
fn(**kwargs)
TypeError: setup() got an unexpected keyword argument 'optimizer.class_path'
But we can't add an optimizer argument to the finetuning setup signature because then it's a duplicate command. Conceptually, I am kind of stuck here.
Also, how would we get the args in
optimizer = instantiate_class(model.parameters(), init=args["optimizer"])
if we don't pass them on from the main() function in __main__.py? I suppose I could do the
parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"})
args = parser.parse_args()
in the finetuning script, but then it would erase all the previous arguments.
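For reference, here is a minimal, self-contained sketch of those two pieces in isolation: a throwaway parser that only knows about the optimizer subclass argument, plus the instantiate_class call. The SGD example values and the placeholder model are illustrative assumptions, not litgpt code.

```python
import jsonargparse
import torch
from lightning.pytorch.cli import instantiate_class

# Standalone parser that only knows about the optimizer subclass argument.
parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(
    torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"}
)

# Simulate: --optimizer torch.optim.SGD --optimizer.init_args.lr 0.1
cfg = parser.parse_args(["--optimizer", "torch.optim.SGD", "--optimizer.init_args.lr", "0.1"])
init = cfg.as_dict()["optimizer"]  # {"class_path": "torch.optim.SGD", "init_args": {...}}

model = torch.nn.Linear(4, 4)  # placeholder model in place of the finetuned LLM
optimizer = instantiate_class(model.parameters(), init=init)
print(optimizer)
```

This shows the mechanics, but a second parser inside the finetuning script would indeed not know about any of the other setup arguments.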
As far as integrating into the scripts, I would:
1. Create an optimizer argument in https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/finetune/lora.py#L40
2. To avoid the duplicate registration, skip it when the function arguments are added: https://github.com/omni-us/jsonargparse/blob/2de15ddfb1c02c2f7b3fe913ad11f13c5cb65dff/jsonargparse/_signatures.py#L166 https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/main.py#L121
3. Call instantiate_class here: https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/finetune/lora.py#L185-L187
This should be enough to unblock you. The not-so-nice thing is that the CLI args structure leaks into the actual script, meaning that users who don't go through the CLI will have to create this dictionary manually.
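To make steps 1 and 3 concrete, a rough sketch of the wiring; the make_optimizer helper, the default AdamW settings, and the placeholder model are assumptions for illustration, not the actual litgpt implementation.

```python
from typing import Any, Dict, Optional

import torch
from lightning.pytorch.cli import instantiate_class

# Hypothetical default, standing in for the optimizer the script currently hard-codes.
DEFAULT_OPTIMIZER = {"class_path": "torch.optim.AdamW", "init_args": {"lr": 3e-4}}


def make_optimizer(
    model: torch.nn.Module, optimizer: Optional[Dict[str, Any]] = None
) -> torch.optim.Optimizer:
    """Instantiate the optimizer from a CLI-style dict with class_path/init_args."""
    init = optimizer or DEFAULT_OPTIMIZER
    return instantiate_class(model.parameters(), init=init)


if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)  # placeholder model
    print(make_optimizer(model, {"class_path": "torch.optim.SGD", "init_args": {"lr": 0.1}}))
```

The dict shape here is exactly what the CLI would hand over, which is the "CLI structure leaks into the script" trade-off mentioned above.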
Awesome, thanks so much, this was a great help! I figured it out now and got it to work. Many thanks again, I learned something new!
I now got it to work as follows:
litgpt finetune full \
--checkpoint_dir checkpoints/EleutherAI/pythia-160m
# Specify optimizer and optimizer args:
litgpt finetune full \
--checkpoint_dir checkpoints/EleutherAI/pythia-160m \
--optimizer torch.optim.SGD \
--optimizer.init_args.lr 1000
But I feel like the way I am passing the optimizer kwargs is a bit hacky. Is there a built-in/better way to handle it, @carmocca? The thing is that when I pass an --optimizer argument, it also passes additional kwargs to the setup:
kwargs = {
'optimizer.class_path': 'torch.optim.SGD',
'optimizer.init_args.dampening': 0.0,
'optimizer.init_args.differentiable': False,
'optimizer.init_args.foreach': None,
'optimizer.init_args.lr': 0.001,
'optimizer.init_args.maximize': False,
'optimizer.init_args.momentum': 0.0,
'optimizer.init_args.nesterov': False,
'optimizer.init_args.weight_decay': 0.0
}
That's why I added the parsing into class_path and init_args:
optimizer_class_path = None
optimizer_init_args = {}
for key, value in list(kwargs.items()):
    if key.startswith("optimizer"):
        if "class_path" in key:
            optimizer_class_path = value
        elif "init_args" in key:
            # e.g. "optimizer.init_args.lr" -> "lr"
            init_arg_key = key.split(".")[-1]
            optimizer_init_args[init_arg_key] = value
        # drop the optimizer entries so they aren't forwarded as regular kwargs
        del kwargs[key]
Everything seems to work, but I wonder if there isn't a better way to do it?
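For comparison, the same regrouping can be written as a small helper that pops the flattened optimizer.* keys and rebuilds the class_path/init_args pair that instantiate_class expects. The function name and the example kwargs below are just a sketch.

```python
from typing import Any, Dict, Tuple


def pop_optimizer_config(kwargs: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Remove all 'optimizer.*' entries from kwargs and regroup them."""
    class_path = kwargs.pop("optimizer.class_path")
    init_args = {
        key.rsplit(".", 1)[-1]: kwargs.pop(key)
        for key in [k for k in kwargs if k.startswith("optimizer.init_args.")]
    }
    return class_path, init_args


if __name__ == "__main__":
    kwargs = {
        "optimizer.class_path": "torch.optim.SGD",
        "optimizer.init_args.lr": 0.001,
        "optimizer.init_args.momentum": 0.0,
        "checkpoint_dir": "checkpoints/EleutherAI/pythia-160m",
    }
    class_path, init_args = pop_optimizer_config(kwargs)
    print(class_path, init_args, kwargs)
```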
@rasbt I pushed a commit with what I would suggest. The str code path could be improved if we want to expose arguments like the learning rate outside of the CLI, but that should be straightforward to implement.
Also, FYI: you don't need to specify the .init_args substring on the command line.
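For illustration, one shape the str code path could take: accept either a bare optimizer name or a full class_path/init_args dict. The helper name, the torch.optim lookup for bare names, and the extra-kwargs handling are assumptions, not necessarily what the pushed commit does.

```python
from typing import Any, Dict, Union

import torch
from lightning.pytorch.cli import instantiate_class


def instantiate_optimizer(
    model: torch.nn.Module, optimizer: Union[str, Dict[str, Any]], **extra_kwargs: Any
) -> torch.optim.Optimizer:
    """Build an optimizer from either a name like "AdamW" or a CLI-style dict."""
    if isinstance(optimizer, str):
        # Resolve bare names against torch.optim, e.g. "AdamW" -> torch.optim.AdamW.
        optimizer_cls = getattr(torch.optim, optimizer)
        return optimizer_cls(model.parameters(), **extra_kwargs)
    return instantiate_class(model.parameters(), init=optimizer)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)  # placeholder model
    print(instantiate_optimizer(model, "AdamW", lr=3e-4))
    print(instantiate_optimizer(model, {"class_path": "torch.optim.SGD", "init_args": {"lr": 0.1}}))
```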
The only caveat now is that the class path still needs to be specified, i.e., only specifying the learning rate doesn't work:
litgpt finetune full --optimizer.lr 200 --checkpoint_dir checkpoints/EleutherAI/pythia-160m
error: Parser key "optimizer":
Not a valid subclass of Optimizer. Got value: NestedArg(key='lr', val='200')
Subclass types expect one of:
- a class path (str)
- a dict with class_path entry
- a dict without class_path but with init_args entry (class path given previously)
And the optimizer always needs to be specified explicitly:
litgpt finetune full --optimizer AdamW --optimizer.lr 200 --checkpoint_dir checkpoints/EleutherAI/pythia-160m
Do you know if that's a jsonargparse thing, @carmocca? Since we already set a default value in the setup method, I was thinking this is a bit weird.
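If it is a jsonargparse thing, one possible direction would be to register the default at the parser level so that a bare --optimizer.lr can be applied on top of it. This is an untested sketch; whether set_defaults picks up a class_path-only dict for a subclass argument this way would need to be verified against the jsonargparse docs.

```python
import jsonargparse
import torch

parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(
    torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"}
)
# Register AdamW as the default class path so that `--optimizer.lr 200`
# would not require repeating `--optimizer AdamW` on the command line.
parser.set_defaults({"optimizer": {"class_path": "torch.optim.AdamW"}})

cfg = parser.parse_args(["--optimizer.lr", "200"])
print(cfg.as_dict()["optimizer"])
```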
I hope this is ready now @carmocca
The Azure failure does look real:
> fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval, optimizer)
E TypeError: fit() takes 9 positional arguments but 10 were given
/__w/6/s/extensions/thunder/pretrain.py:229: TypeError
----------------------------- Captured stderr call -----------------------------
Missing logger folder: /tmp/pytest-of-root/pytest-0/test_pretrain0/out/logs/tensorboard
Seed set to 42
=========================== short test summary info ============================
FAILED tests/test_thunder_pretrain.py::test_pretrain - TypeError: fit() takes 9 positional arguments but 10 were given
It does. Let me investigate ...
Should be fixed for good now, @carmocca. I can switch the link to the original tinystories now that you have seen the green checks haha 😆