keras-cv
Training script improvement proposal
What does this PR do?
Sorry for not discussing this before opening a pull request, so feel free to reject it entirely if it doesn't fit.
- Fixed typo (Adam -> SGD)
- Fixed? wrong dtype for TPU (float16 -> bfloat16)
- Added warmup with cosine decay as an optional scheduler + a flag for it (see the sketches below)
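On the dtype item: for TPUs the reduced-precision policy would be set along these lines (a one-liner sketch; the script's actual flag handling may differ):

```python
import tensorflow as tf

# bfloat16 is the TPU-friendly reduced-precision dtype; float16 targets GPUs.
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
```

And the added scheduler is roughly this shape: linear warmup to the target LR, an optional hold at the peak, then cosine decay to zero. This is a minimal sketch; the class and argument names are illustrative rather than the PR's exact code:

```python
import math
import tensorflow as tf


class WarmUpCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to `target_lr`, hold at the peak, then cosine decay to zero."""

    def __init__(self, target_lr, total_steps, warmup_steps, hold_steps):
        super().__init__()
        self.target_lr = target_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.hold_steps = hold_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup = tf.cast(self.warmup_steps, tf.float32)
        hold = tf.cast(self.hold_steps, tf.float32)
        total = tf.cast(self.total_steps, tf.float32)

        # Linear warmup from 0 to target_lr over the warmup phase.
        warmup_lr = self.target_lr * step / tf.maximum(warmup, 1.0)

        # Cosine decay from target_lr down to 0 after warmup + hold.
        decay_steps = tf.maximum(total - warmup - hold, 1.0)
        progress = tf.clip_by_value((step - warmup - hold) / decay_steps, 0.0, 1.0)
        cosine_lr = 0.5 * self.target_lr * (1.0 + tf.cos(math.pi * progress))

        return tf.where(
            step < warmup,
            warmup_lr,
            tf.where(step < warmup + hold, self.target_lr, cosine_lr),
        )
```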
Thoughts? @ianstenbit @tanzhenyu @bhack
Another note - I don't know what the device utilization is on the GC environment this is running on. If there's room for more compute, we might want to try increasing steps_per_execution as well. I've seen ~10% speedups on some of my own models with it, which adds up to a lot when pre-training a bunch of models on ImageNet.
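For context, steps_per_execution is a Model.compile argument; a minimal illustration with a throwaway model (not the script's actual setup):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    # Run 32 training steps per tf.function call; larger values cut per-step host
    # overhead, which is where the ~10% speedup mentioned above tends to come from.
    steps_per_execution=32,
)
```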
Should I run format.sh on this script as well?
Yes, the formatter should cover these scripts.
I am doing a test run of this script against a KerasCV EfficientNet
Formatting in a second, thanks! Excited to hear how EffNet does with this script! Can you post the results here?
Formatted 👍
Will do!
@ianstenbit Did the warmup help? :)
Sorry for the delay -- since I did the first test run (when I sent my last set of comments), I've been focusing on some other training runs. I am planning to kick off a training run of an EfficientNet using this script some time this week.
Starting a training run with this change now. In the meantime, could you merge from master to get #935 included? @DavidLandup0
Will do! I'm getting to my PC in a couple of hours, so I'll merge that in and update it according to your comments then. Thanks! Excited to see whether the schedule helps :) No problem in taking the time to test it out!
Sounds great, thanks! I want to add some docstrings. Do I add the casts and push again or do you want to add the casts in and I push the docstrings?
Sorry I missed this before. If I push this it'll add me to the authors on this PR, so I think you ought to just add the casts and push :)
- Changed warmup percentages to be floats between 0 and 1 instead of integers, to avoid further multiplication later (see the sketch after this list)
- Added docstrings
- Replaced NumPy with math because of a single constant call
- Added NUM_IMAGES, used later instead of the dataset's length
- Added tf.cast calls
- Updated default target_lr to 1e-2 (we're using SGD)
- Removed gradient clipnorm from optimizer when using scheduler (shouldn't be a need for it)
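To illustrate how those fractional arguments might be consumed, building on the earlier schedule sketch (NUM_IMAGES, BATCH_SIZE, EPOCHS and the values here are placeholder assumptions, not the script's actual constants):

```python
import tensorflow as tf

NUM_IMAGES = 1281167   # ImageNet-1k training set size, standing in for the dataset length
BATCH_SIZE = 256
EPOCHS = 300

total_steps = (NUM_IMAGES // BATCH_SIZE) * EPOCHS
warmup_fraction = 0.1  # floats in [0, 1], per the change above
hold_fraction = 0.1

schedule = WarmUpCosineDecay(
    target_lr=1e-2,    # the new SGD default
    total_steps=total_steps,
    warmup_steps=int(total_steps * warmup_fraction),
    hold_steps=int(total_steps * hold_fraction),
)
# No clipnorm on the optimizer when the schedule is in use, per the change above.
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)
```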
It's odd that the casts were required. In my own runs, dtypes weren't an issue. Might be a TPU thing?
@ianstenbit How's the training going so far? :D
Good! Running with default settings and 300 epochs. After 44 epochs we're at 66% top-1 accuracy.
Awesome, thanks for the update! It'll start decaying the LR soon then - excited to see how far it goes!
If need be, we can search for a better set of parameters than 0.1 and 0.1. They worked well for me before, so I use them as the defaults, but these can be tuned for different networks, training scripts, etc. At the time it was run, was the target_lr 1e-3 or 1e-2? It was updated in the latest push because SGD's default should be 1e-2 🤔
How did the run end up doing? Tried different parameters?
After 250/300 epochs, EfficientNetV2B3 is at ~77% top-1 accuracy. This isn't good enough to offer the weights (unless it gets at least 1% more in the next 50 epochs, which is possible), but it's close enough that I think this LR schedule code should be merged. We can tweak runs and try different parameters for future runs.
I'm using 1e-1 as a target LR (really it's 0.0125 but multiplied by 8 as I'm using 8 accelerators)
/gcbrun
Looks like this needs to be formatted. Can you please run the linter using sh shell/format.sh?
Thanks for the update! I'll try out a couple of variations on Imagenette to see if some parameter obviously works better, since that might translate into a smaller, but still real, increase on ImageNet. I'll keep you updated with a table of params/results :)
The target lr sounds good. Thanks!
Hmm, I've run the format script before pushing, and it's currently also showing no files changed when run 🤔
This is for the last commit

Oh okay, the action shows what the issue was. There was a trailing whitespace on line 261, in the comment before the newline, which the linter apparently doesn't catch.
Since it looks like the action only mentioned that one whitespace, I've removed it manually and pushed again :)
/gcbrun
Thanks for the merge! I'll keep you updated on some of the experiments with the script on Imagenette :)
Okay, so I did some test runs with this script, substituting ImageNet with Imagenette, and here's a finding for LR warmup. If we only hold for ~10%, it starts decaying too fast and prevents the later epochs from being utilized well enough. I've upped the hold period to 35% in my test script, and the validation accuracy jumped from ~80% to ~86%. Now, that's Imagenette, with fewer training epochs and a faster reduction on plateau, but here's the guiding principle:
In the TensorBoard runs you've had so far, the first reduction happens about ~80% of the way into training. If we match the hold period so that the decay reaches near that point (or the start of the plateau, since the reduction only happens after a plateau), we can lower the learning rate gradually instead of step-reducing it at that point. In the case of Imagenette, the first plateau/reduction hit at ~66% of the training, and a 35% holding period boosted the accuracy significantly.
Maybe we could get the ImageNet accuracy up by some amount by having the warmup_hold_percentage increased to anywhere between 0.3 and 0.4? @ianstenbit
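For a rough sense of what those values mean in epochs, a quick back-of-the-envelope (300 epochs and a 0.1 warmup fraction assumed, matching the run above):

```python
EPOCHS = 300
warmup_fraction = 0.1

# Cosine decay begins once both the warmup and hold phases are over.
for hold_fraction in (0.1, 0.3, 0.35, 0.4):
    decay_start = EPOCHS * (warmup_fraction + hold_fraction)
    print(f"hold={hold_fraction:.2f} -> decay starts around epoch {decay_start:.0f}")

# hold=0.10 -> decay starts around epoch 60
# hold=0.30 -> decay starts around epoch 120
# hold=0.35 -> decay starts around epoch 135
# hold=0.40 -> decay starts around epoch 150
```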
Re-running the experiments with longer training for ReduceLROnPlateau, with a longer plateau period and longer early stopping
ReduceLROnPlateau:
Cosine Decay with 0.1 warmup and 0.45 hold:
Both had a policy where MixUp or CutMix is applied at the batch level with a 50% probability. This seems to have boosted the accuracy a bit as well. Training curve comparison:
P.S. I made the labels a bit confusing - both runs use the same CutMix/MixUp policy, not just one of them. The only difference between these is the LR schedule.
Since the first plot resembles your training plots for ImageNet (although a bit more ragged), I feel like this could very well translate to a small increase even though it was only validated on Imagenette. Do we have the resources to re-run V2B0, or to do V2B3, to see if this helps? I can push a new update with the augmentation policy change and the new default for the holding period.
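For reference, here's a sketch of the kind of batch-level, 50%-probability CutMix/MixUp policy described above, using KerasCV's CutMix and MixUp layers (the wiring, probabilities, and dataset format are assumptions for illustration, not pulled from the script):

```python
import tensorflow as tf
import keras_cv

cut_mix = keras_cv.layers.CutMix()
mix_up = keras_cv.layers.MixUp()


def cut_mix_or_mix_up(samples):
    # `samples` is assumed to be a dict with float "images" and one-hot float "labels",
    # the format KerasCV's CutMix/MixUp preprocessing layers expect.
    choice = tf.random.uniform(())
    return tf.cond(
        choice > 0.5,
        # Half of the batches are left untouched.
        lambda: samples,
        # The other half get either CutMix or MixUp, picked at random.
        lambda: tf.cond(
            choice > 0.25,
            lambda: cut_mix(samples, training=True),
            lambda: mix_up(samples, training=True),
        ),
    )


# Applied per batch, e.g.:
# dataset = dataset.batch(BATCH_SIZE).map(cut_mix_or_mix_up, num_parallel_calls=tf.data.AUTOTUNE)
```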