Metalhead.jl
Implementation of EfficientNetv2
This is an implementation of EfficientNetv2. There's clearly some code duplication between the EffNets and the MobileNets, but figuring out how to unify that API and its design is perhaps best left for after the models themselves are in place.
TODO
- [x] Figure out a way to unify the EfficientNet and EfficientNetv2 lower-level API into a single efficientnet function
- [ ] More elaborate docs for mbconv and fused_mbconv
P.S. Memory issues mean that the tests are more fragmented than they ought to be; I'm not sure how we can go about addressing those in the short/medium term.
Side note - using basic_conv_bn for the Inception models seems to have fixed their gradient times somehow, which are now much more manageable. Maybe the extra biases were causing a slowdown.
Intermittent failures for the ConvNeXt and ConvMixer testsets on Julia 1 seem memory-related (the nightly tests pass; a funny thing for this to suddenly start occurring). The pretrained ResNets will be broken until https://github.com/FluxML/Flux.jl/pull/2041 lands. Everything else should pass, though.
CI is now green :) (modulo ResNet note above)
Has anyone here succeeded in training EfficientNetv2, or is anyone aware of example code that does so? I'm finding, with an Adam and logitcrossentropy approach, that it learns to always return one of the classes.
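For context, my training step is roughly the following sketch (model construction and data loading omitted; model, train_loader, and device are placeholders here rather than my exact code):

```julia
using Flux

# Adam + logitcrossentropy setup, explicit-gradient style
opt_state = Flux.setup(Adam(0.001), model)

for (x, y) in train_loader              # y is a one-hot encoded batch
    x, y = device(x), device(y)
    loss, grads = Flux.withgradient(model) do m
        Flux.logitcrossentropy(m(x), y)
    end
    Flux.update!(opt_state, model, grads[1])
end
```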
On ImageNet? I used the MobileNets in this package a while back and trained them on CIFAR-100. Looking at that code, I used Adam and logit CE loss. Also a bunch of data augmentation. MobileNets share some code overlap with EfficientNets in the package.
I'm trying on a custom dataset but perhaps it makes sense to start with something tried and tested.
Yeah if something like CIFAR-10 doesn't work, then at least someone here can try and reproduce.
I've tried training EfficientNetV2(:small; ... on CIFAR-10 here and get
Would you mind sanity-checking my approach?
Switching to EfficientNet(:b0; ... I get
Okay I'll make my own script over the weekend and sanity check.
Thanks!
Comparing to this example https://www.kaggle.com/code/paultimothymooney/efficientnetv2-with-tensorflow?scriptVersionId=120655976&cellId=27
- It uses SGD, not Adam.
- Learning rate is 0.005 vs the default Adam 0.001
- Primary momentum is the same
- It uses label smoothing
And performance maxes out within a few epochs
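In Flux terms that setup would look roughly like the sketch below (model is a placeholder, and the 0.1 smoothing factor is an assumption on my part; the notebook's exact value isn't listed above):

```julia
using Flux

# SGD with momentum instead of Adam: 0.005 learning rate, 0.9 momentum
opt_state = Flux.setup(Momentum(0.005, 0.9), model)

# label smoothing applied to the one-hot targets (0.1 is an assumed factor)
smoothed_logitce(ŷ, y) = Flux.logitcrossentropy(ŷ, Flux.Losses.label_smoothing(y, 0.1f0))
```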
@darsnack I was just wondering whether you have had a chance to test this or check my code for any obvious issues. Thanks!
Using https://github.com/IanButterworth/EfficientNet-Training/commit/ec8008408b49a38042a0c40ed58d27bd460234d2 and switching just the model to ResNet gets a much more reasonable result
model = ResNet(18; pretrain=false, inchannels=3, nclasses=length(labels)) |> device
Other than there being some issue with the model, could it be that EfficientNet requires some subtle difference in the loss function or optimizer setups?
EfficientNet (not V2)
model = EfficientNet(:b0; pretrain=false, inchannels=3, nclasses=length(labels)) |> device
I was trying a lot of these on CIFAR-10 during my GSoC and was facing issues, including with ResNet – the accuracies were nowhere near what I could get with PyTorch, for example. I remember trying to debug this but then had to give up since I got occupied with other commitments. IIRC at the time one theory was that there might be something wrong with the gradients, but we never quite managed to get enough training runs on the GPU to confirm. I could try running these again if I got a GPU somehow 😅
The difference between EfficientNet and ResNet is weird, but not unexpected because they do use different kinds of layers. Maybe MobileNet vs EfficientNet tells us more? The underlying code over there is the exact same because of the way the models are constructed. Even for ResNet, though, the loss curve looks kinda bumpy...
Interesting.
I was considering setting up a job to do the same for all appropriate models in Metalhead on my machine, and tracking any useful metrics while doing so for a comparison table. I have a relatively decent setup so it might not take too long.
I can understand that full training isn't a part of CI because of resources, but I think it needs to be demonstrated that these models are trainable somewhere public.
I've started putting together a benchmarking script here. https://github.com/FluxML/Metalhead.jl/pull/264 It'd be great to get feedback on it.
As @theabhirath mentioned, this has come up before, but we never got to the bottom of why because of time constraints. If you have some bandwidth to do some diagnostic runs looking at whether gradients, activations, etc. are misbehaving and where, we could narrow the problem down to something actionable, whether that be a bug in Metalhead itself or something further upstream like Flux, Zygote, or CUDA.
Ok. I think my minimum goal then is to characterise the issue (is it all models?) and make it clear in the docs which models cannot be expected to be trainable.
As someone coming to this package to use it to train on a custom dataset, it's been quite a time sink just to get to that understanding.
If it's an either-or deal between that and finding out why a specific model (e.g. EfficientNet) has problems, I'd lean towards the latter. I suspect any issues lie at least partly at the level of some of the shared layer helpers, so addressing those would help other models too. This kind of deep debugging is also something that I, and probably @darsnack, are less likely to have time to do, whereas running a convergence test for all models would be more straightforward. But if the goal is to do both, that sounds good to me.
I lean more towards what I'm currently trying because I'm not sure I have the knowledge/skill set to dive in and debug.
Tbh if I find a model that trains and is performant I may declare victory, but share findings.
Or maybe I strike lucky in a dive.
Sorry for the late reply. I did start writing a script, but I never got around to starting my testing. I kept meaning to reply as soon as I did, but it's been a busy week.
Looks like we might have some duplicate work. Either way, I uploaded my code to a repo here. I've added all of you as collaborators. The biggest help right now would be if someone could add all the remaining models + configs to the main.jl script. Output from the script is being logged here. It should be a public wandb workspace that we can all view.
Based on the current runs that finished, it looks like all the EfficientNetv2 models have some bug. Only the :small variant trains to completion. The rest all drop to NaN loss at some point, with the larger variants dropping to NaN in the first epoch.
ResNets have the quirk Abhirath noticed during the summer where the models start off fine and then begin to overfit to the training data.
I also have AlexNet and EfficientNet (v1) queued up. Let's see how those do.
I will say though that the ResNet loss curves don't look as bad as I remember them. Perhaps in this case, a different learning rate would fix things.
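(Side note on the NaN runs above: a quick, illustrative way to find which parameter array goes non-finite first, assuming a Flux/Metalhead model, is something like the sketch below; first_nonfinite_param and model are hypothetical names, not part of the script.)

```julia
using Flux

# After a training step, return the index and size of the first trainable
# array containing a NaN or Inf, or `nothing` if all parameters are finite.
function first_nonfinite_param(model)
    for (i, p) in enumerate(Flux.params(model))
        any(!isfinite, p) && return (index = i, size = size(p))
    end
    return nothing
end
```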
My recollection was that the PyTorch resnets converged faster and overfit less easily even with less help from data augmentation. Is it straightforward to do a head-to-head comparison?
I'm able to replicate the poor training behaviour of EfficientNet-B0 on a local machine, which happens to have an AMD GPU. This suggests the problem may not lie with anything CUDA-specific.
Great, thanks. How does one go about debugging this kind of thing? Are there generic tools for visualizing the model that could help?
I'll modify the script to log gradient norms by layer and also do some local debugging just to sanity check the output isn't obviously wrong.
I'll also add MobileNet to the script. I think that might be a good reference model to compare against assuming it does converge. If it works, that would narrow the issue down to the specific "inverted residual block" that EfficientNet uses.
If either of you have a machine with a decently powerful CPU, I think a CPU-only run would be interesting to see if we can isolate GPU libraries as a possibility altogether.
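For the gradient-norm logging mentioned above, I'm thinking of something roughly like this (a sketch using the implicit-params style; model, x, and y are placeholders for whatever the script passes in):

```julia
using Flux, LinearAlgebra

# Rough sketch: one gradient norm per trainable array, computed on a single
# batch (x, y) with a logitcrossentropy loss.
function param_grad_norms(model, x, y)
    ps = Flux.params(model)
    gs = Flux.gradient(() -> Flux.logitcrossentropy(model(x), y), ps)
    [norm(gs[p]) for p in ps if gs[p] !== nothing]
end
```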
I'm giving EfficientNetv2 a go on the CPU with -t16. Looks like each epoch will take 3 hours.
julia> versioninfo()
Julia Version 1.9.4
Commit 8e5136fa297 (2023-11-14 08:46 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 32 × AMD Ryzen 9 5950X 16-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, znver3)
Threads: 16 on 32 virtual cores