Flux.jl
Improve type stability of LayerNorm and Dropout
These two layers made use of explicit or implicit control flow (e.g. default keyword argument values), which Zygote does not like. This PR is essentially a set of small hacks to work around that.
Any ideas on how to avoid `return_type` in `_dropout` would be much appreciated, but for now it seems to work.
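For context, here is a minimal sketch of the kind of `return_type` trick described above. The helper name `_dropout_kernel`, the signature, and the overall structure are my own assumptions for illustration, not the PR's actual code; the point is only that asking inference for the output eltype up front lets both the active and inactive branches return arrays of the same concrete type:

```julia
using Random

# Hypothetical kernel: keep an entry with probability q = 1 - p and rescale it by 1/q.
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(inv(q)) : zero(T)

function _dropout_sketch(rng::AbstractRNG, x::AbstractArray, p::Real, active::Bool)
    q = 1 - p
    # Ask inference for the eltype of the scaled output instead of letting it depend
    # on the runtime value of `active`; this is the role `return_type` plays here.
    T = Core.Compiler.return_type(*, Tuple{eltype(x), typeof(q)})
    active || return T.(x)  # inactive branch matches the active branch's eltype
    mask = _dropout_kernel.(rand!(rng, similar(x, typeof(q))), p, q)
    return x .* mask        # eltype is T by construction
end
```

`Base.promote_op` wraps the same machinery and may be a slightly friendlier spelling of the same idea.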
TODO benchmarks.
PR Checklist
- [ ] Entry in NEWS.md
TTFG (time-to-first-gradient) timings using the following snippet:
Test code
```julia
using Metalhead, Flux, Zygote
using Metalhead: ChannelLayerNorm

model = ConvNeXt(:tiny; inchannels=1, nclasses=1).layers
# ChannelLayerNorm isn't type stable yet (for the same reason as LayerNorm wasn't),
# so remove it for this demo.
model = fmap(Returns(identity), model; exclude=Base.Fix2(isa, ChannelLayerNorm))
# display(model); println()

loss(m, x) = sum(m(x))
inputs = randn(Float32, 32, 32, 1, 1)
# @time loss(model, inputs)
# @time loss(model, inputs)

loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)
@time loss_grad(model, inputs)
# @time loss_grad(model, inputs)
```
```julia
julia> @time loss_grad(model, inputs)
34.835647 seconds (87.12 M allocations: 4.701 GiB, 3.14% gc time, 99.38% compilation time) # 0.13.3
30.679322 seconds (78.88 M allocations: 4.300 GiB, 3.46% gc time, 98.96% compilation time) # this PR
```
Replacing the `Chain{Vector}` with a `Chain{Tuple}` creates a larger gap:
```julia
julia> @time loss_grad(model, inputs)
79.846248 seconds (98.87 M allocations: 5.243 GiB, 1.68% gc time, 99.67% compilation time) # 0.13.3
63.024710 seconds (79.23 M allocations: 4.245 GiB, 1.92% gc time, 99.45% compilation time) # this PR
52.838056 seconds (70.81 M allocations: 3.745 GiB, 1.98% gc time, 99.60% compilation time) # this PR + Zygote#1248
```
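For readers less familiar with the distinction: a `Chain` stores its layers in whatever container it is constructed with, and only the tuple form gives the compiler the full layer types to specialize on, at the cost of longer compile times. A quick illustration, not taken from this thread:

```julia
julia> using Flux

julia> Chain(Dense(2 => 3), Dense(3 => 1)) isa Chain{<:Tuple}    # varargs constructor stores a Tuple
true

julia> Chain([Dense(2 => 3), Dense(3 => 1)]) isa Chain{<:Vector} # vector constructor keeps the Vector
true
```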
For kicks, here is Diffractor with https://github.com/JuliaDiff/ChainRules.jl/pull/644:
```julia
julia> @time loss_grad(model, inputs)
30.442982 seconds (92.61 M allocations: 4.148 GiB, 3.18% gc time, 89.07% compilation time) # tuple chain
23.051121 seconds (88.06 M allocations: 3.920 GiB, 3.81% gc time, 85.11% compilation time) # vector chain, requires https://github.com/JuliaDiff/Diffractor.jl/pull/82
```
Re-enabling `ChannelLayerNorm` adds only ~1s to the total. Note that even the tuple `Chain` here is faster than any tested Zygote configuration.
Edit: added times for vector chains using a patched Diffractor.
Does Diffractor already work with most Flux models (or at least those with built-in layers)? I was under the impression that it wasn't there yet 😅
Not OOTB, which is why that ChainRules PR is required.
@ToucheSir Could you try running the layer norm gradient on the GPU? I have tried that manual broadcast fusion before, but `CUDA.@time` said it actually allocated more GPU memory.
You're right, it allocates one more time for over 2x the memory overhead. I also found this out the hard way recently while trying to fuse the RNN cell kernels for https://github.com/FluxML/Flux.jl/pull/2023, but forgot about the change here.
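For anyone who wants to reproduce that GPU allocation check, a sketch along these lines should do it; the layer width and batch size here are arbitrary choices of mine:

```julia
using CUDA, Flux, Zygote

ln = LayerNorm(64) |> gpu
x = CUDA.randn(Float32, 64, 128)

# CUDA.@time reports GPU allocations alongside wall-clock time, which is what
# flagged the extra allocation from the fused broadcast.
CUDA.@time gradient(x -> sum(ln(x)), x)
CUDA.@time gradient(x -> sum(ln(x)), x)  # second call excludes compilation
```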
Codecov Report
Merging #2005 (29ef2ff) into master (d66d2c4) will increase coverage by 0.27%. The diff coverage is 100.00%.
```diff
@@            Coverage Diff             @@
##           master    #2005      +/-   ##
==========================================
+ Coverage   87.10%   87.37%   +0.27%
==========================================
  Files          20       20
  Lines        1528     1553      +25
==========================================
+ Hits         1331     1357      +26
+ Misses        197      196       -1
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/Flux.jl | 0.00% <ø> (ø) | |
| src/layers/normalise.jl | 90.28% <100.00%> (+1.46%) | :arrow_up: |
| src/layers/stateless.jl | 100.00% <100.00%> (ø) | |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Any updates on this (like benchmarks after unfusing)?