
Implementation of EfficientNetv2

theabhirath opened this issue 3 years ago · 1 comment

This is an implementation of EfficientNetv2. There's clearly some code duplication between the EffNets and the MobileNets, but figuring out how to unify that API and its design is perhaps best left for after the models themselves are in place.

TODO

  • [x] Figure out a way to unify the EfficientNet and EfficientNetv2 lower level API into a single efficientnet function
  • [ ] More elaborate docs for mbconv and fused_mbconv
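Until those docs land, here is a rough sketch of how the two blocks differ (hypothetical Flux code, not Metalhead's actual implementation; strides, squeeze-excitation, and the skip connection are omitted for brevity):

```julia
using Flux

# MBConv: 1x1 expansion conv -> depthwise 3x3 conv -> 1x1 projection.
function mbconv_sketch(inch, outch; expansion = 4)
    mid = inch * expansion
    Chain(
        Conv((1, 1), inch => mid),                 # expand channels
        BatchNorm(mid, swish),
        DepthwiseConv((3, 3), mid => mid; pad = 1),
        BatchNorm(mid, swish),
        # (squeeze-and-excitation block would go here)
        Conv((1, 1), mid => outch),                # project back down
        BatchNorm(outch),
    )
end

# Fused-MBConv: the expansion and depthwise convs are fused into one
# regular 3x3 conv, which is faster on accelerators for early stages.
function fused_mbconv_sketch(inch, outch; expansion = 4)
    mid = inch * expansion
    Chain(
        Conv((3, 3), inch => mid; pad = 1),        # fused expand + spatial conv
        BatchNorm(mid, swish),
        Conv((1, 1), mid => outch),                # project back down
        BatchNorm(outch),
    )
end
```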

P.S. Memory issues mean the tests are more fragmented than they ought to be; I'm not sure how we can go about addressing those in the short/medium term.

theabhirath avatar Aug 11 '22 03:08 theabhirath

Side note - using basic_conv_bn for the Inception models seems to have fixed their gradient times somehow, which are now much more manageable. Maybe the extra biases were causing a slowdown

theabhirath avatar Aug 16 '22 08:08 theabhirath

Intermittent failures for the ConvNeXt and ConvMixer testsets on Julia 1 seem memory-related (the nightly tests pass; odd that this suddenly started occurring). The pretrained ResNets will be broken until https://github.com/FluxML/Flux.jl/pull/2041 lands. Everything else should pass, though.

theabhirath avatar Aug 23 '22 19:08 theabhirath

CI is now green :) (modulo ResNet note above)

ToucheSir avatar Aug 24 '22 02:08 ToucheSir

Has anyone here succeeded in training EfficientNetv2, or is anyone aware of example code that does so? I'm finding that with Adam and a logitcrossentropy loss, it learns to always predict a single class.

IanButterworth avatar Dec 07 '23 20:12 IanButterworth

On ImageNet? I used the MobileNets in this package a while back and trained them on CIFAR-100. Looking at that code, I used Adam and logit CE loss. Also a bunch of data augmentation. MobileNets share some code overlap with EfficientNets in the package.
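Roughly the setup I mean (a sketch from memory, not my exact script; the data loader and augmentation pipeline are omitted, and `train_loader` is assumed to yield WHCN batches with one-hot labels):

```julia
using Flux
using Flux.Losses: logitcrossentropy
using Metalhead

# Hypothetical: a small MobileNet configured for CIFAR-100's 100 classes.
model = MobileNetv3(:small; nclasses = 100)
opt_state = Flux.setup(Adam(1f-3), model)

for (x, y) in train_loader          # x: image batch, y: one-hot labels
    loss, grads = Flux.withgradient(model) do m
        logitcrossentropy(m(x), y)
    end
    Flux.update!(opt_state, model, grads[1])
end
```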

darsnack avatar Dec 07 '23 22:12 darsnack

I'm trying on a custom dataset but perhaps it makes sense to start with something tried and tested.

IanButterworth avatar Dec 07 '23 23:12 IanButterworth

Yeah if something like CIFAR-10 doesn't work, then at least someone here can try and reproduce.

darsnack avatar Dec 08 '23 12:12 darsnack

I've tried training EfficientNetV2(:small; ... on CIFAR-10 here and get

[screenshot: training curves]

Would you mind sanity-checking my approach?

IanButterworth avatar Dec 08 '23 18:12 IanButterworth

Switching to EfficientNet(:b0; ... I get

[screenshot: training curves]

IanButterworth avatar Dec 08 '23 20:12 IanButterworth

Okay I'll make my own script over the weekend and sanity check.

darsnack avatar Dec 08 '23 20:12 darsnack

Thanks!

Comparing to this example https://www.kaggle.com/code/paultimothymooney/efficientnetv2-with-tensorflow?scriptVersionId=120655976&cellId=27

  • It uses SGD, not Adam.
  • Learning rate is 0.005 vs the default Adam 0.001
  • Primary momentum is the same
  • It uses label smoothing

And performance maxes out within a few epochs
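For concreteness, matching that recipe in Flux would look roughly like this (a sketch; the 0.1 smoothing factor is illustrative, not the notebook's exact value):

```julia
using Flux
using Flux.Losses: logitcrossentropy, label_smoothing

# SGD with momentum instead of Adam, per the Kaggle example:
# learning rate 0.005, momentum 0.9.
opt = Momentum(0.005, 0.9)

# Label smoothing is applied to the one-hot targets before the loss.
loss(m, x, y) = logitcrossentropy(m(x), label_smoothing(y, 0.1f0))
```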

IanButterworth avatar Dec 08 '23 20:12 IanButterworth

@darsnack I was just wondering whether you have had a chance to test this or check my code for any obvious issues. Thanks!

IanButterworth avatar Dec 12 '23 07:12 IanButterworth

Using https://github.com/IanButterworth/EfficientNet-Training/commit/ec8008408b49a38042a0c40ed58d27bd460234d2 and switching just the model to ResNet gets a much more reasonable result

model = ResNet(18; pretrain=false, inchannels=3, nclasses=length(labels)) |> device
[screenshot: training curves]

Other than there being some issue with the model itself, could it be that EfficientNet requires some subtle difference in the loss function or optimizer setup?

IanButterworth avatar Dec 13 '23 17:12 IanButterworth

EfficientNet (not V2)

model = EfficientNet(:b0; pretrain=false, inchannels=3, nclasses=length(labels)) |> device
[screenshot: training curves]

IanButterworth avatar Dec 13 '23 18:12 IanButterworth

I was trying a lot of these on CIFAR-10 during my GSoC and was facing issues, including with ResNet – the accuracies were nowhere near what I could get with PyTorch, for example. I remember trying to debug this but then had to give up since I got occupied with other commitments. IIRC at the time one theory was that there might be something wrong with the gradients, but we never managed to run enough training examples on the GPU to confirm. I could try running these again if I got a GPU somehow 😅

The difference between EfficientNet and ResNet is weird, but not unexpected because they do use different kinds of layers. Maybe MobileNet vs EfficientNet tells us more? The underlying code over there is the exact same because of the way the models are constructed. Even for ResNet, though, the loss curve looks kinda bumpy...

theabhirath avatar Dec 14 '23 03:12 theabhirath

Interesting.

I was considering setting up a job to do the same for all appropriate models in Metalhead on my machine, and tracking any useful metrics while doing so for a comparison table. I have a relatively decent setup so it might not take too long.

I can understand that full training isn't a part of CI because of resources, but I think it needs to be demonstrated that these models are trainable somewhere public.

IanButterworth avatar Dec 14 '23 03:12 IanButterworth

I've started putting together a benchmarking script here: https://github.com/FluxML/Metalhead.jl/pull/264. It'd be great to get feedback on it.

IanButterworth avatar Dec 14 '23 17:12 IanButterworth

As @theabhirath mentioned, this has come up before but we never got to the bottom of why because of time constraints. If you have some bandwidth to do some diagnostic runs looking at whether gradients, activations etc are misbehaving and where they're misbehaving, we could narrow the problem down to something actionable. Whether that be a bug in Metalhead itself or something further upstream like Flux, Zygote or CUDA.
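For the gradient side, something like this would be a starting point (an illustrative sketch, not a polished tool; the helper just walks the explicit-gradient NamedTuple that `Flux.gradient` returns):

```julia
# Recursively collect the L2 norm of every array in a (nested) gradient
# structure, keyed by its field path. Plain Julia; no Flux required here.
function grad_norms!(out, g, path = "model")
    if g isa AbstractArray{<:Number}
        out[path] = sqrt(sum(abs2, g))
    elseif g isa Union{NamedTuple,Tuple}
        for (k, v) in pairs(g)
            v === nothing || grad_norms!(out, v, string(path, ".", k))
        end
    end
    return out
end

# Hypothetical usage with Flux's explicit gradients:
#   grads = Flux.gradient(m -> loss(m, x, y), model)[1]
#   norms = grad_norms!(Dict{String,Float64}(), grads)
#   # vanishing, exploding, or NaN entries point at the misbehaving layer
```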

ToucheSir avatar Dec 15 '23 00:12 ToucheSir

Ok. I think my minimum goal then is to characterise the issue (does it affect all models?) and make it clear in the docs which ones cannot be expected to be trainable.

As someone coming to this package to train on a custom dataset, it's been quite a time sink just to get to that understanding.

IanButterworth avatar Dec 15 '23 01:12 IanButterworth

If it's an either-or deal between that and finding out why a specific model (e.g. EfficientNet) has problems, I'd lean towards the latter. I suspect any issues lie at least partly in some of the shared layer helpers, so addressing those would help other models too. This kind of deep debugging is also something I (and probably @darsnack) am less likely to have time for, whereas running a convergence test for all models would be more straightforward. But if the goal is to do both, that sounds good to me.

ToucheSir avatar Dec 15 '23 01:12 ToucheSir

I lean more towards what I'm currently trying because I'm not sure I have the knowledge/skill set to dive in and debug.

Tbh if I find a model that trains and is performant I may declare victory, but share findings.

Or maybe I strike lucky in a dive.

IanButterworth avatar Dec 15 '23 01:12 IanButterworth

Sorry for the late reply. I did start writing a script, but I never got around to starting my testing. I kept meaning to reply as soon as I did, but it's been a busy week.

Looks like we might have some duplicate work. Either way, I uploaded my code to a repo here. I've added all of you as collaborators. The biggest help right now would be if someone could add all the remaining models + configs to the main.jl script. Output from the script is being logged here. It should be a public wandb workspace that we can all view.

darsnack avatar Dec 15 '23 02:12 darsnack

Based on the runs that have finished, it looks like the EfficientNetv2 models have some bug: only the :small variant trains to completion. The rest all drop to NaN loss at some point, with the larger variants hitting NaN in the first epoch.

ResNets have the quirk Abhirath noticed during the summer where the models start off fine, then begin to overfit the training data.

I also have AlexNet and EfficientNet (v1) queued up. Let's see how those do.
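For localising the NaNs, a layer-by-layer forward check is a cheap first step. Here `first_nonfinite` is just an illustrative helper, not something in Metalhead:

```julia
# Walk an iterable of layers (e.g. a Flux Chain's `.layers`) and return the
# index of the first layer whose output contains a NaN or Inf, or `nothing`
# if every intermediate activation is finite.
function first_nonfinite(layers, x)
    for (i, layer) in enumerate(layers)
        x = layer(x)
        all(isfinite, x) || return i
    end
    return nothing
end

# Toy usage: the second "layer" divides by zero, so it is flagged.
first_nonfinite((x -> x .+ 1, x -> x ./ 0, identity), [1.0])  # returns 2
```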

darsnack avatar Dec 15 '23 12:12 darsnack

I will say though that the ResNet loss curves don't look as bad as I remember them. Perhaps in this case, a different learning rate would fix things.

darsnack avatar Dec 15 '23 12:12 darsnack

My recollection was that the PyTorch resnets converged faster and overfit less easily even with less help from data augmentation. Is it straightforward to do a head-to-head comparison?

ToucheSir avatar Dec 15 '23 14:12 ToucheSir

I'm able to replicate the poor training behaviour of EfficientNet-B0 on a local machine, which happens to have an AMD GPU. This suggests the problem may not lie with anything CUDA-specific.

ToucheSir avatar Dec 16 '23 06:12 ToucheSir

Great, thanks. How does one go about debugging this kind of thing? Are there generic tools for visualizing the model that could help?

IanButterworth avatar Dec 16 '23 13:12 IanButterworth

I'll modify the script to log gradient norms by layer and also do some local debugging just to sanity check the output isn't obviously wrong.

I'll also add MobileNet to the script. I think that might be a good reference model to compare against assuming it does converge. If it works, that would narrow the issue down to the specific "inverted residual block" that EfficientNet uses.

darsnack avatar Dec 16 '23 17:12 darsnack

If either of you have a machine with a decently powerful CPU, I think a CPU-only run would be interesting to see if we can isolate GPU libraries as a possibility altogether.

ToucheSir avatar Dec 16 '23 20:12 ToucheSir

I'm giving EfficientNetv2 a go on CPU with -t16. Looks like each epoch will take ~3 hours.

[screenshot: epoch progress]

julia> versioninfo()
Julia Version 1.9.4
Commit 8e5136fa297 (2023-11-14 08:46 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 5950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, znver3)
  Threads: 16 on 32 virtual cores

IanButterworth avatar Dec 17 '23 01:12 IanButterworth