Flux.jl
`rrule` for `Recur`
Before: Diffractor does not like mutation, and thus fails on RNNs:
```julia
julia> using Flux, Zygote, Diffractor

julia> Zygote.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
((cell = (σ = nothing, Wi = Float32[0.0027779222 0.008236016; 0.0027779222 0.008236016; 0.0027779222 0.008236016], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], b = Float32[0.0027290469, 0.0027290469, 0.0027290469], state0 = nothing), state = Float32[0.00818714; 0.00818714; 0.00818714;;]),)

julia> Diffractor.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
ERROR: MethodError: no method matching copy(::Nothing)
...
Stacktrace:
  [1] perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{1}}, args::Any)
    @ Diffractor ~/.julia/packages/Diffractor/XDXfC/src/stage1/generated.jl:22
...
  [4] setproperty!
    @ ./Base.jl:38 [inlined]
```
After:
```julia
julia> Diffractor.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
(Tangent{Flux.Recur}(cell = Tangent{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}}}(σ = ChainRulesCore.NoTangent(), b = Float32[0.0027290469, 0.0027290469, 0.0027290469], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], Wi = Float32[0.0027779222 0.008236016; 0.0027779222 0.008236016; 0.0027779222 0.008236016]), state = Float32[0.00818714; 0.00818714; 0.00818714;;]),)
```
```julia
# And with Array{T,3}
julia> Zygote.gradient(m -> sum(abs2, m(reshape(1:24, 2,3,4).+0f0)), RNN(2 => 3; init=Flux.ones32))
((cell = (σ = nothing, Wi = Float32[0.019653967 0.03929506; 0.019653967 0.03929506; 0.019653967 0.03929506], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], b = Float32[0.019641092, 0.019641092, 0.019641092], state0 = nothing), state = Float32[0.058923274 0.058923274 0.058923274; 0.058923274 0.058923274 0.058923274; 0.058923274 0.058923274 0.058923274]),)
```
The `@opt_out` is needed to keep the `Array{T,3}` case working on Zygote; it does not yet work on Diffractor, for lack of a rule.
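For readers unfamiliar with the mechanism, the shape of such a rule might look roughly like the sketch below. This is a hedged illustration only, not the PR's actual code: the pullback structure and the `@opt_out` signature are assumptions based on the ChainRulesCore API (`RuleConfig`, `rrule_via_ad`, `@opt_out`), and details such as tangent packaging may differ from the real implementation.

```julia
# Hypothetical sketch — not the PR's implementation.
using Flux, ChainRulesCore

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              m::Flux.Recur, x::AbstractMatrix)
    # Differentiate one cell application m.cell(m.state, x) via the outer AD.
    (new_state, y), cell_pullback = rrule_via_ad(config, m.cell, m.state, x)
    m.state = new_state   # the mutation stays outside what the AD traces
    function recur_pullback(dy)
        # The cell returned a (state, output) tuple, so pass a tuple tangent:
        dcell, dstate, dx = cell_pullback((ZeroTangent(), dy))
        dm = Tangent{typeof(m)}(; cell = dcell, state = dstate)
        return dm, dx
    end
    return y, recur_pullback
end

# Opt out of this rule for 3-dimensional input, so Zygote keeps using
# its own handling there (assumed signature):
ChainRulesCore.@opt_out ChainRulesCore.rrule(
    ::RuleConfig, ::Flux.Recur, ::AbstractArray{<:Any,3})
```

The key point is that the state update happens inside the rrule's primal computation, opaque to the differentiating AD, so no mutation needs to be traced.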
This rrule may mean that Zygote stores the state in repeated applications of the rule, instead of in its configuration IdDict, which may matter for performance. On one very crude test it seems to be an improvement; perhaps others have more serious benchmarks?
```julia
julia> @btime Zygote.gradient(m -> sum(abs2, m($(randn(Float32, 2, 100)))), $(RNN(2 => 3)));
min 23.833 μs, mean 94.316 μs (97 allocations, 23.12 KiB)    # before
min 23.458 μs, mean 27.116 μs (97 allocations, 23.12 KiB)    # after

julia> @btime Zygote.gradient(m -> sum(abs2, m($(randn(Float32, 20, 100)))), $(LSTM(20 => 30)));
min 138.375 μs, mean 1.508 ms (112 allocations, 505.38 KiB)  # before
min 127.416 μs, mean 190.538 μs (62 allocations, 496.33 KiB) # after
```
More serious correctness tests might also be a good idea; I haven't looked at the test file for this.

Edit: marked as draft since many tests fail; I presume this is giving wrong gradients.