Implementation of lookahead
From https://arxiv.org/abs/2006.14567 (taken from the Projects page for this repo); a rough sketch of the update rule is included below for reference.
Currently testing on my own, but early results so far have been very promising.
Currently untested:
- Multi-GPU setup, since I only have one GPU
- Loading from checkpoints that didn't use lookahead originally (theoretically supported, but unverified)
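For anyone skimming the PR, here's a minimal sketch of the basic lookahead step the paper builds on (as I understand it, applied to both the G and D optimizers in the GAN setting). The class and parameter names (`LookaheadSketch`, `k`, `alpha`) are just my own illustration, not necessarily the exact code in this PR: an inner "fast" optimizer takes `k` steps, then the "slow" weights are pulled toward the fast weights by a factor `alpha` and copied back.

```python
# Illustrative sketch only -- not the exact implementation in this PR.
class LookaheadSketch:
    def __init__(self, base_optimizer, k=5, alpha=0.5):
        self.optimizer = base_optimizer  # inner "fast" optimizer, e.g. Adam
        self.k = k                       # sync slow weights every k fast steps
        self.alpha = alpha               # slow-weight step size
        self.step_count = 0
        # snapshot of the slow weights, one tensor per parameter
        self.slow_weights = [
            [p.clone().detach() for p in group["params"]]
            for group in base_optimizer.param_groups
        ]

    def step(self):
        self.optimizer.step()            # one fast step
        self.step_count += 1
        if self.step_count % self.k == 0:
            # slow <- slow + alpha * (fast - slow), then copy slow back into fast
            for group, slow_group in zip(self.optimizer.param_groups, self.slow_weights):
                for p, slow in zip(group["params"], slow_group):
                    slow.add_(p.data - slow, alpha=self.alpha)
                    p.data.copy_(slow)
```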
Tested multi-GPU on Azure and verified that it at least ran for >100 iterations and produced the expected outputs. That's about as much as I can validate for now.
Also tested resuming training from a run that was started without lookahead, to confirm that works as well.
Friendly ping to @lucidrains 😄. In my own testing, lookahead gave excellent improvements in output quality when training without attention; I'm interested to see whether others see similar improvements. My results weren't as good with attention applied, though all of my testing was done before the recent attention fix, so it might perform better now.
@lucidrains any chance of merging this? :smiley:
Hm, I've tried this and it doesn't actually seem to work:
File "/mnt/storage3/training/stylegan2-lookahead/bin/stylegan2_pytorch", line 8, in <module>
sys.exit(main())
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/stylegan2_pytorch/cli.py", line 215, in main
fire.Fire(train_from_folder)
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/fire/core.py", line 466, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/stylegan2_pytorch/cli.py", line 206, in train_from_folder
run_training(0, 1, model_args, data, load_from, disc_load_from, new, num_train_steps, name, seed)
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/stylegan2_pytorch/cli.py", line 66, in run_training
retry_call(model.train, tries=3, exceptions=NanException)
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/retry/api.py", line 101, in retry_call
return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, logger)
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/retry/api.py", line 33, in __retry_internal
return f()
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 1079, in train
self.GAN.D_opt.zero_grad()
File "/mnt/storage3/training/stylegan2-lookahead/lib/python3.10/site-packages/torch/optim/optimizer.py", line 231, in zero_grad
foreach = self.defaults.get('foreach', False)
AttributeError: 'Lookahead' object has no attribute 'defaults'
It looks like something changed in PyTorch over the past year that breaks this code. I promise it did work when I made the PR 😄. Unfortunately, I don't have a setup available right now to debug and fix the issue.
@tannisroot If you're feeling adventurous, the issue might be worked around by calling `super().__init__()` (with the appropriate arguments) in `Lookahead`'s `__init__`, which should give the wrapper the `defaults` attribute that newer PyTorch's `zero_grad()` expects. If you do try that, let me know whether it works.
Unfortunately, I'm not very familiar with implementing a PyTorch optimizer. The code for Lookahead came directly from the paper. So if that doesn't work, I'm all out of ideas, haha.
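To make that concrete, this is roughly the kind of change I have in mind (completely untested sketch, only showing the `__init__` change; the real `Lookahead` class has more to it, e.g. the slow-weight syncing): make the wrapper subclass `torch.optim.Optimizer` and call the base constructor with the inner optimizer's param groups and defaults, so attributes like `defaults` exist when `zero_grad()` goes looking for them.

```python
# Untested sketch of the workaround discussed above -- not the full Lookahead
# implementation, just the constructor change that should address the
# AttributeError in the traceback.
from torch.optim import Optimizer


class Lookahead(Optimizer):
    def __init__(self, base_optimizer, k=5, alpha=0.5):
        self.optimizer = base_optimizer  # inner "fast" optimizer
        self.k = k
        self.alpha = alpha
        # Calling the base constructor populates self.defaults, self.param_groups
        # and self.state, which is exactly what zero_grad() is complaining about.
        super().__init__(base_optimizer.param_groups, base_optimizer.defaults)
```

Whether mirroring the inner optimizer's param groups like this interacts badly with the rest of the wrapper is exactly the part I can't test right now, so treat it as a starting point only.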