
Improve type stability of LayerNorm and Dropout

Open · ToucheSir opened this pull request 2 years ago · 8 comments

These two layers made use of explicit or implicit control flow (e.g. branches introduced by default keyword argument values), which Zygote does not handle well. This PR is essentially a set of small hacks to work around that.
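To make the problem concrete, here is a purely illustrative sketch (simplified, not the actual Flux implementation) of the kind of pattern at issue: a `nothing`-defaulted keyword turns the train/test decision into a value-dependent branch whose pullback is hard to infer.

```julia
istraining() = true  # stand-in for the real training-mode check

function maybe_dropout(x, p; active = nothing)
    # `active === nothing` introduces a Union{Nothing,Bool} decision, and the
    # two branches compute the result differently.
    a = active === nothing ? istraining() : active
    return a ? x .* (rand(Float32, size(x)...) .> p) ./ (1 - p) : x
end

maybe_dropout(randn(Float32, 4, 4), 0.5)
```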

Any ideas on how to avoid `return_type` in `_dropout` would be much appreciated, but for now it seems to work.

TODO benchmarks.

PR Checklist

  • [ ] Entry in NEWS.md

ToucheSir · Jun 23 '22

TTFG (time to first gradient) timings using the following snippet:

Test code

```julia
using Metalhead, Flux, Zygote
using Metalhead: ChannelLayerNorm

model = ConvNeXt(:tiny; inchannels=1, nclasses=1).layers
# ChannelLayerNorm isn't type stable yet (for the same reason as LayerNorm wasn't),
# so remove it for this demo
model = fmap(Returns(identity), model; exclude=Base.Fix2(isa, ChannelLayerNorm))

# display(model); println()

loss(m, x) = sum(m(x))

inputs = randn(Float32, 32, 32, 1, 1)
# @time loss(model, inputs)
# @time loss(model, inputs)

loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)

@time loss_grad(model, inputs)
# @time loss_grad(model, inputs)
```

```julia
julia> @time loss_grad(model, inputs)
 34.835647 seconds (87.12 M allocations: 4.701 GiB, 3.14% gc time, 99.38% compilation time) # 0.13.3
 30.679322 seconds (78.88 M allocations: 4.300 GiB, 3.46% gc time, 98.96% compilation time) # this PR
```

Replacing the `Chain{Vector}` with a `Chain{Tuple}` creates a larger gap (a sketch of the conversion follows the timings below):

```julia
julia> @time loss_grad(model, inputs)
 79.846248 seconds (98.87 M allocations: 5.243 GiB, 1.68% gc time, 99.67% compilation time) # 0.13.3
 63.024710 seconds (79.23 M allocations: 4.245 GiB, 1.92% gc time, 99.45% compilation time) # this PR
 52.838056 seconds (70.81 M allocations: 3.745 GiB, 1.98% gc time, 99.60% compilation time) # this PR + Zygote#1248
```
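For reference, a minimal sketch of one way to get the tuple-backed variant (this only rebuilds the outermost Chain; nested vector-backed Chains would need a recursive walk, e.g. with fmap):

```julia
# Splatting the layer vector makes the Chain store a Tuple instead of a Vector.
tuple_model = Chain(model.layers...)
```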

ToucheSir · Jun 27 '22

For kicks, here is Diffractor with https://github.com/JuliaDiff/ChainRules.jl/pull/644:

```julia
julia> @time loss_grad(model, inputs)
 30.442982 seconds (92.61 M allocations: 4.148 GiB, 3.18% gc time, 89.07% compilation time) # tuple chain
 23.051121 seconds (88.06 M allocations: 3.920 GiB, 3.81% gc time, 85.11% compilation time) # vector chain, requires https://github.com/JuliaDiff/Diffractor.jl/pull/82
```

Re-enabling ChannelLayerNorm adds only ~1s to the total. Note that even the tuple Chain here is faster than any tested Zygote configuration.
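For anyone trying to reproduce these numbers, the Diffractor run looks roughly like this (a sketch only; it assumes `Diffractor.gradient` as the entry point plus the patched ChainRules/Diffractor branches linked above):

```julia
using Diffractor

# Same `model`, `inputs` and loss as in the Zygote snippet; only the AD call changes.
diffractor_grad(m, x) = Diffractor.gradient((m, x) -> sum(m(x)), m, x)
@time diffractor_grad(model, inputs)
```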

Edit: added times for vector chains using a patched Diffractor.

ToucheSir · Aug 01 '22

Does Diffractor already work with most Flux models (or at least those with built-in layers)? I was under the impression that it wasn't there yet 😅

theabhirath · Aug 01 '22

Not out of the box, which is why that ChainRules PR is required.

ToucheSir · Aug 01 '22

@ToucheSir Could you try running the LayerNorm gradient on GPU? I have tried that manual broadcast fusion before, but CUDA.@time said it actually allocated more GPU memory.
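For concreteness, the check I mean is roughly the following (a sketch, assuming a CUDA device; the layer width and batch size are arbitrary):

```julia
using Flux, CUDA, Zygote

ln = LayerNorm(64) |> gpu
x  = CUDA.randn(Float32, 64, 128)

# CUDA.@time reports GPU allocations alongside the runtime.
CUDA.@time gradient(m -> sum(m(x)), ln)
```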

chengchingwen · Aug 01 '22

You're right, it allocates one more time, with over 2x the memory overhead. I also found this out the hard way recently while trying to fuse the RNN cell kernels for https://github.com/FluxML/Flux.jl/pull/2023, but forgot about the change here.

ToucheSir · Aug 01 '22

Codecov Report

Merging #2005 (29ef2ff) into master (d66d2c4) will increase coverage by 0.27%. The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##           master    #2005      +/-   ##
==========================================
+ Coverage   87.10%   87.37%   +0.27%     
==========================================
  Files          20       20              
  Lines        1528     1553      +25     
==========================================
+ Hits         1331     1357      +26     
+ Misses        197      196       -1     
```

| Impacted Files | Coverage Δ |
| --- | --- |
| src/Flux.jl | 0.00% <ø> (ø) |
| src/layers/normalise.jl | 90.28% <100.00%> (+1.46%) :arrow_up: |
| src/layers/stateless.jl | 100.00% <100.00%> (ø) |


codecov-commenter · Aug 01 '22

Any updates on this (like benchmarks after unfusing)?

darsnack · Aug 10 '22