Metalhead.jl
Add model implementations
Below are models that we still need to implement
Will implement
These are models for which we are actively seeking implementations/weights.
Model implementations
- [x] VGG
- [x] ResNet
- [x] DenseNet
- [x] SqueezeNet
- [x] GoogLeNet
- [x] Inception v3
- [x] Inception v4 (see #170)
- [x] Inception ResNet v2 (see #170)
- [x] Xception (see #170)
- [x] ResNeXt
- [x] MobileNet v1
- [x] MobileNet v2
- [x] MLPMixer
- [x] Vision Transformer (see #105)
- [x] EfficientNet (see #113 #171)
- [x] EfficientNet v2
- [ ] CaiT
- [ ] LeViT
- [ ] Swin
- [ ] MobileViT
- [x] ConvNeXt (see #119)
- [x] UNet (see #210)
- [ ] DeiT
Pre-trained weights
- [x] VGG (see #164)
- [x] ResNet (see #164)
- [x] DenseNet
- [x] SqueezeNet
- [ ] GoogLeNet
- [ ] Inception v3
- [ ] Inception v4
- [ ] Inception ResNet v2
- [ ] Xception
- [x] ResNeXt
- [ ] MobileNet v1
- [ ] MobileNet v2
- [ ] MLPMixer
- [x] Vision Transformer
- [ ] EfficientNet
- [ ] EfficientNet v2
- [ ] CaiT
- [ ] LeViT
- [ ] Swin
- [ ] MobileViT
- [ ] ConvNeXt
- [ ] UNet
- [ ] DeiT
Will consider implementing
These are models that aren't necessarily must-haves, but we are happy to accept contributions that add them.
- ESRGAN (see #118).
cc @theabhirath please add the models you are planning on implementing to this list
A starting point for EfficientNet can be found at https://github.com/pxl-th/EfficientNet.jl (see #106).
I'm planning on
at first. I also want to start on object detection and semantic segmentation models, but there are a lot of helper functions specific to those fields that I will have to write. If everything CV-related is planned to live in Metalhead, then I'll go ahead and start coding them up.
Hello, can I work on Inception v4 and EfficientNet? (Also, I have already coded ESRGANs.)
There's #113 for EfficientNet. It would be good to port SRGAN etc. into Metalhead as well.
I could provide the pretrained weights for VGG and ResNet converted from PyTorch (once some minor changes to the Metalhead models are merged so that they are equivalent to the PyTorch models).
(I would be very interested in an ESRGAN implementation. 😃)
Also see #109 for pretrained models. It's hard to know beforehand which weights will transfer well between PyTorch and Flux, but if we have some pretrained weights that we can validate, that would be welcome!
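One way to validate converted weights, sketched here under the assumption that you already have logits from the PyTorch model and from the converted Flux model for the same preprocessed input batch, is to compare them numerically and to check that the top-5 predicted classes agree. `compare_logits` is a hypothetical helper, not part of Flux or Metalhead, and it assumes both inputs are `(nclasses, batch)` matrices (i.e. PyTorch's `(batch, nclasses)` output transposed). The tolerance is an arbitrary choice.

```julia
# Hypothetical helper (not part of Flux or Metalhead): compare logits from a converted Flux
# model with reference logits from PyTorch for the same inputs. Both are (nclasses, batch).
function compare_logits(y_flux::AbstractMatrix, y_ref::AbstractMatrix; k::Int = 5, rtol::Real = 1e-4)
    size(y_flux) == size(y_ref) ||
        throw(DimensionMismatch("logit shapes differ: $(size(y_flux)) vs $(size(y_ref))"))
    maxdiff = maximum(abs.(y_flux .- y_ref))             # worst elementwise deviation
    approx_equal = isapprox(y_flux, y_ref; rtol = rtol)  # norm-based relative tolerance
    # Top-k class indices for each sample (column), most likely first
    topk(y) = [partialsortperm(col, 1:k; rev = true) for col in eachcol(y)]
    same_topk = topk(y_flux) == topk(y_ref)
    return (; maxdiff, approx_equal, same_topk)
end

# Toy example: 6 classes, 2 samples
y_ref = [2.0 0.1; 1.0 3.0; 0.5 0.2; -1.0 1.5; 0.0 0.3; 0.7 -0.4]
println(compare_logits(y_ref .+ 1e-6, y_ref))  # small numerical noise: should pass
y_bad = copy(y_ref); y_bad[1, 1] = -5.0
println(compare_logits(y_bad, y_ref))          # a corrupted logit: the top-5 of sample 1 changes
```

Checking the top-k agreement alongside the raw difference helps separate harmless floating-point noise from a real porting mistake such as a mis-ordered or unflipped kernel.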
As mentioned, there is already a PR for EfficientNet, but InceptionNet v4 would be very welcome! ESRGAN would be welcome too.
Regarding the VGG and ResNet weights converted from PyTorch: yes, I think your flow will work well for both of those models. Please submit PRs to MetalheadWeights when you have them!
How do I contribute models here? I'm fairly new.
The Flux contribution guide has some info as well as links for how to get started with making your first contribution.
I'll briefly summarize the process for this repo (apologies if you already know this):
- Fork this repo (there is a button to "fork" on the front page). This creates your own copy of the repo.
- Clone your fork locally onto your machine.
- Create a new branch for the model or feature that you are trying to add.
- Edit the code to add the new feature or create a new file that implements the model. It would be useful to skim the existing model code so you get an idea of what helper functions already exist and the model building style that the repo follows.
- Make commits to your feature branch and push them to your fork.
- Go to "Pull Requests" on this repo and click "New Pull Request." Choose your feature branch from your fork as the source.
- We will review your pull request and provide feedback. You can incorporate the feedback by making more commits to your feature branch. The pull request will update automatically. Together, we'll iterate the code until the design is ready to merge.
- At this point, one of the maintainers will merge your pull request, and you're done!
I hope this is the right place for this discussion, but it looks like using PyTorch weights might not be too complicated. I've had success opening Torch's .pth files using Pickle.jl:
using Downloads: download
using Pickle
using Flux, Metalhead
# Links from https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py
model_urls = Dict(
"vgg11" => "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13" => "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
)
model_name = "vgg11"
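datadir = joinpath(@__DIR__, "data")
path = joinpath(datadir, "$model_name.pth")
mkpath(datadir)                                         # create the data directory if needed
isfile(path) || download(model_urls[model_name], path)  # download the checkpoint only once
# Torchvision stores conv kernels as (out, in, kH, kW) (NCHW convention), while Flux
# expects (kW, kH, in, out) (WHCN), so the dimension order is reversed.
# PyTorch's Conv2d computes a cross-correlation, whereas Flux's Conv flips the kernel,
# so the two spatial dimensions are reversed as well.
# Dense weights are (out, in) in both frameworks and are passed through unchanged.
function permute_weights(A::AbstractArray{T, N}) where {T, N}
    if N == 4
        B = permutedims(A, (4, 3, 2, 1))
        return B[end:-1:1, end:-1:1, :, :]
    end
    return A
end
torchweights = Pickle.Torch.THload(path)  # OrderedDict: "features.0.weight" => array, ...
weights = map(permute_weights, collect(values(torchweights)))
model = VGG11()
# Parameters are matched by position, so this relies on Flux.params(model) having the
# same order as the checkpoint entries. Later Flux versions deprecate Flux.loadparams!
# in favor of Flux.loadmodel!.
Flux.loadparams!(model, weights)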
where torchweights is an OrderedDict:
julia> torchweights
OrderedCollections.OrderedDict{Any, Any} with 22 entries:
"features.0.weight" => [0.288164 0.401512 0.216151; -0.3528 -0.574001 -0.028024;…
"features.0.bias" => Float32[0.193867, 0.304219, 0.18251, -1.11219, 0.0441538,…
"features.3.weight" => [0.0302545 0.0595999 … 0.00978228 0.0180896; -0.00835048 …
"features.3.bias" => Float32[-0.0372088, -0.115514, 0.148786, -0.106784, 0.153…
"features.6.weight" => [-0.0107434 0.00274947 … 0.0393466 0.0168702; -0.0110679 …
"features.6.bias" => Float32[0.0696629, -0.0745776, 0.0681913, -0.115447, 0.11…
"features.8.weight" => [-0.0158845 -0.0120116 … -0.0287082 0.00195862; -0.026421…
"features.8.bias" => Float32[-0.00635051, 0.031504, 0.0732542, 0.0478025, 0.32…
"features.11.weight" => [0.0145356 -0.0262731 … -0.0032392 0.0459495; -0.00929091…
"features.11.bias" => Float32[-0.0146084, 0.187013, -0.0683434, 0.0223707, -0.0…
"features.13.weight" => [0.0514621 -0.0490013 … -0.00740175 0.00124351; 0.0382785…
"features.13.bias" => Float32[0.324421, -0.00724723, 0.0839103, 0.180003, 0.075…
"features.16.weight" => [0.00122506 -0.0199895 … -0.0369922 -0.0188395; -0.023495…
"features.16.bias" => Float32[-0.0753394, 0.19634, 0.0544855, 0.0230991, 0.2478…
"features.18.weight" => [0.0104981 -0.0085396 … -0.00996796 0.00263586; -0.022834…
"features.18.bias" => Float32[-0.0132562, -0.128536, -0.021685, 0.0401009, 0.14…
"classifier.0.weight" => Float32[-0.00160107 -0.00533261 … 0.00406176 0.0014802; 0…
"classifier.0.bias" => Float32[0.0252175, -0.00486407, 0.0436967, 0.0097529, 0.0…
"classifier.3.weight" => Float32[0.00620286 -0.0210906 … -0.00731366 -0.0212599; -…
"classifier.3.bias" => Float32[0.0449707, 0.0848237, 0.0727477, 0.0816414, 0.074…
"classifier.6.weight" => Float32[-0.0118874 0.0186423 … 0.0170721 0.0105425; 0.033…
"classifier.6.bias" => Float32[0.0146753, -0.00972212, -0.0238722, -0.0290253, -…
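As a sanity check on this listing: VGG11 has 8 convolutional layers and 3 fully connected layers, each contributing one weight and one bias, so 8 × 2 + 3 × 2 = 22 arrays. Because Flux.loadparams! matches arrays to the model's parameters by position, the Flux model needs exactly that many parameter arrays in the same order. The sketch below rebuilds the key names from the layer indices shown above and counts layers by their "module.index" prefix; `count_layers` is a hypothetical helper written only for this illustration.

```julia
# Layer indices as they appear in the keys printed above (torchvision's vgg11)
conv_idx = [0, 3, 6, 8, 11, 13, 16, 18]  # "features.<i>" entries: conv layers
fc_idx = [0, 3, 6]                       # "classifier.<i>" entries: linear layers

# Rebuild the 22 key names in checkpoint order: weight, then bias, for each layer
keys_vgg11 = vcat([["features.$i.weight", "features.$i.bias"] for i in conv_idx]...,
                  [["classifier.$i.weight", "classifier.$i.bias"] for i in fc_idx]...)

# Hypothetical helper: count distinct layers by their "<module>.<index>" prefix
function count_layers(names::AbstractVector{<:AbstractString})
    layers = unique(join(split(name, '.')[1:2], '.') for name in names)
    nconv = count(startswith("features."), layers)
    ndense = count(startswith("classifier."), layers)
    return nconv, ndense
end

nconv, ndense = count_layers(keys_vgg11)
println("conv layers: $nconv, dense layers: $ndense, arrays: $(length(keys_vgg11))")
```

This prints `conv layers: 8, dense layers: 3, arrays: 22`, matching the 22 entries in the OrderedDict.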
If this looks promising to you in any way, I'd be glad to open a draft PR with some guidance. :)
Edit: transposed weights according to Alexander's feedback. Thanks!
Indeed! In this PR I have taken a similar approach (but calling PyTorch via PyCall): https://github.com/FluxML/MetalheadWeights/pull/2
Sometimes the weights have to be transposed, as you found from the error message: https://github.com/FluxML/MetalheadWeights/pull/2/files#diff-796965064d4fe9651f28a44c7b3dcd003b4690f37e619bb91782c283cee5a159R99
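To make the layout issue concrete: PyTorch stores a conv kernel as (out, in, kH, kW), while Flux expects (kW, kH, in, out), so `permutedims(A, (4, 3, 2, 1))` maps `A[o, i, h, w]` to `B[w, h, i, o]`. A second, easy-to-miss difference is that PyTorch's `Conv2d` computes a cross-correlation, whereas Flux's `Conv` computes a true convolution (it flips the kernel), so ported kernels generally also need both spatial dimensions reversed; loading them into Flux's `CrossCor` layer instead is the alternative. The sketch below checks both facts in plain Julia, using naive single-channel `xcorr2d`/`conv2d` helpers with "valid" padding that are hypothetical and written only for this illustration, not Flux itself.

```julia
# 1. Layout: PyTorch (out, in, kH, kW) -> Flux (kW, kH, in, out) by reversing the dimension order
A = reshape(collect(1:2*3*4*5), 2, 3, 4, 5)  # stand-in kernel with out=2, in=3, kH=4, kW=5
B = permutedims(A, (4, 3, 2, 1))
println(size(B))                             # (5, 4, 3, 2)
mapping_ok = all(A[o, i, h, w] == B[w, h, i, o] for o in 1:2, i in 1:3, h in 1:4, w in 1:5)

# 2. Flipping. Cross-correlation (what PyTorch's Conv2d does): slide k over x as-is
function xcorr2d(x::AbstractMatrix, k::AbstractMatrix)
    kh, kw = size(k)
    out = zeros(promote_type(eltype(x), eltype(k)), size(x, 1) - kh + 1, size(x, 2) - kw + 1)
    for j in axes(out, 2), i in axes(out, 1)
        out[i, j] = sum(x[i:i+kh-1, j:j+kw-1] .* k)
    end
    return out
end

# True convolution (what Flux's Conv does): the kernel is indexed back to front
function conv2d(x::AbstractMatrix, k::AbstractMatrix)
    kh, kw = size(k)
    out = zeros(promote_type(eltype(x), eltype(k)), size(x, 1) - kh + 1, size(x, 2) - kw + 1)
    for j in axes(out, 2), i in axes(out, 1), b in 1:kw, a in 1:kh
        out[i, j] += x[i+a-1, j+b-1] * k[kh-a+1, kw-b+1]
    end
    return out
end

x = reshape(collect(1.0:42.0), 6, 7)
k = [1.0 2.0 0.0; -1.0 3.0 1.0; 0.5 0.0 2.0]   # deliberately asymmetric kernel
k_flipped = k[end:-1:1, end:-1:1]              # reverse both spatial dimensions
println(conv2d(x, k_flipped) ≈ xcorr2d(x, k))  # true: flipping the kernel makes them agree
println(conv2d(x, k) ≈ xcorr2d(x, k))          # false: without flipping they differ
```

With a symmetric kernel the flip would go unnoticed, which is why an asymmetric one is used here.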
The ability to use PyTorch weights without needing PyTorch installed is indeed very nice!
This seems to be also relevant: https://github.com/FluxML/Flux.jl/issues/1907
I would like to give LeViT a try, if someone else is not already working on it.
@darsnack I would like to take a crack at MobileViT! Shall I go ahead and create a new issue for it? I'll try to convert PyTorch weights from apple/cvnets.
Go for it! For both ViT-based models, make sure to look at the existing ViT implementation and the Layers submodule. Torchvision models are a good reference, but ultimately we want a Flux-idiomatic model, not a simple clone.