model-zoo
Continue tutorials on Zygote branch
Hello Devs,
I'm dying to use FluxML properly, but the tutorials section seems to be a bit of a mess. I was talking to @DhairyaLGandhi and checked out the dg/Zygote branch. The code there is much cleaner and, I feel, more consistent. Since this zoo is our gateway to the world, can we clean it up a bit? (I really wish more people used Flux. It's so nice and fun to use.)
Also, as an aside: can we please change the folder structure? CIFAR10 is a dataset, not really a class of models. How about organizing by architecture instead, e.g. VGG/ and so on? That would be much easier to work with and look up.
If that's a yes, I can fix the code (as much as I can manage) and open a PR. What do you say?
Have an awesome day, and I hope you're not Tensor :smiley:
As an example, here is a VGG16 architecture that I think would be nice to have. It is general, so it can be extended to other architectures such as ResNet, and it is much easier to follow. It is essentially a refactoring of the current code that is more readable and extensible: instead of a long list of repeated operations, the same small building blocks can be reused across many types of architectures.
using Flux

# conv_block: 3x3 Conv (with relu) -> BatchNorm, mapping in_channels => out_channels
conv_block(in_channels, out_channels) = Chain(
    Conv((3, 3), in_channels => out_channels, relu, pad = (1, 1), stride = (1, 1)),
    BatchNorm(out_channels))

# double_conv: two conv_blocks (the second maps out_channels => out_channels)
# followed by a 2x2 max pool, as in VGG
double_conv(in_channels, out_channels) = Chain(
    conv_block(in_channels, out_channels),
    conv_block(out_channels, out_channels),
    MaxPool((2, 2)))

# VGG16 (with BatchNorm) for a given number of input channels and output classes.
# Dense(512, 4096) assumes 32x32 inputs (e.g. CIFAR-10): the five 2x2 pools reduce
# the feature map to 1x1x512. Larger inputs need a larger first Dense layer.
vgg16(initial_channels, num_classes) = Chain(
    double_conv(initial_channels, 64),
    double_conv(64, 128),
    conv_block(128, 256),
    double_conv(256, 256),
    conv_block(256, 512),
    double_conv(512, 512),
    conv_block(512, 512),
    double_conv(512, 512),
    x -> reshape(x, :, size(x, 4)),  # flatten to (features, batch)
    Dense(512, 4096, relu),
    Dropout(0.5),
    Dense(4096, 4096, relu),
    Dropout(0.5),
    Dense(4096, num_classes),
    softmax
) |> gpu
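To illustrate the point about replacing a long list of repeated operations, here is a minimal sketch in plain Julia (no Flux needed) that describes VGG16 as data. The names `VGG16_CFG`, `expand_cfg` and `flat_features` are hypothetical, and the convention of marking a max pool with `:M` is borrowed from torchvision's VGG configs, not from this repo. It expands the config into the same sequence of conv and pool layers as `vgg16` above and computes the input size of the first Dense layer for a given image size.

```julia
# Hypothetical config: integers are 3x3 conv output channels,
# :M marks a 2x2 max pool (convention borrowed from torchvision's VGG configs).
const VGG16_CFG = Any[64, 64, :M, 128, 128, :M, 256, 256, 256, :M,
                      512, 512, 512, :M, 512, 512, 512, :M]

# Expand a config into layer specs: (:conv, in, out) for each conv_block
# and (:pool,) for each MaxPool((2, 2)). Also returns the final channel count.
function expand_cfg(cfg, in_channels::Int)
    specs = Tuple[]
    ch = in_channels
    for item in cfg
        if item === :M
            push!(specs, (:pool,))
        else
            push!(specs, (:conv, ch, item))
            ch = item
        end
    end
    return specs, ch
end

# Number of features after flattening, for a square input with side length `side`.
# The 3x3 convs use pad = 1 and stride = 1, so only the 2x2 pools shrink the map.
function flat_features(cfg, in_channels::Int, side::Int)
    _, out_ch = expand_cfg(cfg, in_channels)
    npools = count(==(:M), cfg)
    return out_ch * (side ÷ 2^npools)^2
end

specs, out_ch = expand_cfg(VGG16_CFG, 3)
println(count(s -> s[1] == :conv, specs), " conv layers, ", out_ch, " output channels")
println("32x32 input:   ", flat_features(VGG16_CFG, 3, 32), " features")
println("224x224 input: ", flat_features(VGG16_CFG, 3, 224), " features")
```

Each `(:conv, in, out)` spec would map to `conv_block(in, out)` and each `(:pool,)` to `MaxPool((2, 2))`, for example via `Chain([s[1] == :pool ? MaxPool((2, 2)) : conv_block(s[2], s[3]) for s in specs]...)`. Other VGG depths then only need a different config list, and `flat_features` shows why `Dense(512, 4096)` fits 32x32 inputs, while 224x224 inputs would need 7 * 7 * 512 = 25088 input features.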
Looking forward to your comments :))
@SubhadityaMukherjee Maintainers welcome PRs in general, especially ones that update the outdated material here, so go ahead and prepare one.