NNlib.jl
BatchNorm and Dropout
Does it make sense to put functional forms of BatchNorm and Dropout into NNlib so that other packages could simply import them from here?
Yeah, that's a great idea.
BatchNorm is going to be tricky because of its "statistics update step", which mutates the running mean and variance during the forward pass. The current thinking for handling this under Zygote is to do something along the lines of https://gist.github.com/staticfloat/a509b1e1cb1fb556028779722c2531e6
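For concreteness, the rough shape is something like the following training-mode sketch (the `batchnorm` name and signature here are illustrative, not the gist's actual code):

```julia
using Statistics
using Zygote

function batchnorm(x::AbstractArray{T,N}, γ, β, μ, σ²;
                   momentum = T(0.1), ϵ = T(1e-5)) where {T,N}
    dims = (1:N-2..., N)              # reduce over all but the channel dim
    μ_b  = mean(x; dims)
    σ²_b = var(x; dims, mean = μ_b, corrected = false)
    # The "statistics update step": mutate the running stats inside an
    # ignored block so Zygote never differentiates through the mutation.
    Zygote.ignore() do
        μ  .= (1 - momentum) .* μ  .+ momentum .* vec(μ_b)
        σ² .= (1 - momentum) .* σ² .+ momentum .* vec(σ²_b)
    end
    x̂ = (x .- μ_b) ./ sqrt.(σ²_b .+ ϵ)
    shape = ntuple(i -> i == N - 1 ? length(γ) : 1, N)
    return reshape(γ, shape) .* x̂ .+ reshape(β, shape)
end
```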
Now that Flux's normalization interface has been re-worked and GPU batchnorm moved from CUDA.jl -> NNlibCUDA, perhaps we should revisit this. The only reason https://github.com/FluxML/Flux.jl/tree/master/src/cuda exists at all now is to accommodate a non-standard implementation of batchnorm, so getting rid of that would be great.
I was looking into porting the functional form of the normalization layers here, but I'm not sure how to handle the Zygote.ignore block without having NNlib depend on Zygote.
This concern was raised earlier and is addressed by https://github.com/FluxML/Flux.jl/pull/1509
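Assuming that fix follows the usual pattern of replacing Zygote.ignore with ChainRulesCore's equivalent (ChainRulesCore is a much lighter dependency, and Zygote respects its rules), the update step would look roughly like this (`update_stats!` is a hypothetical helper, not an NNlib function):

```julia
using ChainRulesCore: ignore_derivatives

function update_stats!(μ, σ², μ_b, σ²_b, momentum)
    # Runs during the forward pass, but any AD that consumes ChainRules
    # (including Zygote) treats the whole block as a constant.
    ignore_derivatives() do
        μ  .= (1 - momentum) .* μ  .+ momentum .* vec(μ_b)
        σ² .= (1 - momentum) .* σ² .+ momentum .* vec(σ²_b)
    end
    return nothing
end
```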
I don't think there's any reason these dropout functions need to live in Flux. Shall we move them over? Happy to volunteer for a copy-paste PR if we're all in agreement. This would also unblock https://github.com/FluxML/Flux.jl/pull/1572.
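For reference, the moved code would be roughly this shape (the standalone `dropout` signature and the `mask_size` helper are illustrative, not the exact Flux code):

```julia
using Random

# Shape of the dropout mask: full size along `dims`, singleton elsewhere,
# so the mask broadcasts across the untouched dimensions.
mask_size(x, ::Colon) = size(x)
mask_size(x, dims) = ntuple(i -> i in dims ? size(x, i) : 1, ndims(x))

function dropout(rng::AbstractRNG, x::AbstractArray{T}, p; dims = :) where {T}
    p == 0 && return copy(x)
    # Inverted dropout: zero each element with probability p and scale the
    # survivors by 1/(1-p), so no rescaling is needed at inference time.
    mask = rand(rng, T, mask_size(x, dims)) .> T(p)
    return x .* mask .* T(1 / (1 - p))
end

dropout(x::AbstractArray, p; dims = :) = dropout(Random.default_rng(), x, p; dims)
```

With the kernel living here and AD-agnostic, Flux's Dropout layer would just call it.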