
`rrule` for `Recur`

Open mcabbott opened this issue 3 years ago • 0 comments

Before: Diffractor does not like mutation, and thus fails on RNNs:

```julia
julia> using Flux, Zygote, Diffractor

julia> Zygote.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
((cell = (σ = nothing, Wi = Float32[0.0027779222 0.008236016; 0.0027779222 0.008236016; 0.0027779222 0.008236016], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], b = Float32[0.0027290469, 0.0027290469, 0.0027290469], state0 = nothing), state = Float32[0.00818714; 0.00818714; 0.00818714;;]),)

julia> Diffractor.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
ERROR: MethodError: no method matching copy(::Nothing)
...
Stacktrace:
  [1] perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{1}}, args::Any)
    @ Diffractor ~/.julia/packages/Diffractor/XDXfC/src/stage1/generated.jl:22
...
  [4] setproperty!
    @ ./Base.jl:38 [inlined]
```
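For context (not part of the PR): Flux's `Recur` updates its `state` field in place on each call, and that in-place write is the mutation the `setproperty!` frame above points at. A minimal stand-in in plain Julia, contrasting the mutating call with a functional core that an `rrule` could wrap (`MiniRecur`, `apply_cell`, and `toycell` are illustrative names, not Flux API):

```julia
# Minimal stand-in for Flux.Recur: holds a cell and mutable state.
mutable struct MiniRecur{C,S}
    cell::C
    state::S
end

# Flux-style call: writes `state` in place -- the step Diffractor
# cannot differentiate through.
function (m::MiniRecur)(x)
    m.state, y = m.cell(m.state, x)
    return y
end

# A functional core an rrule can target: same computation, but the
# new state is returned rather than written into the struct.
apply_cell(cell, state, x) = cell(state, x)

# Toy "cell": new_state = state + x, output = 2 * new_state.
toycell(state, x) = (state + x, 2 * (state + x))

m = MiniRecur(toycell, 0.0)
y1 = m(1.0)   # state becomes 1.0
y2 = m(2.0)   # state becomes 3.0
```
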

After:

```julia
julia> Diffractor.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
(Tangent{Flux.Recur}(cell = Tangent{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}}}(σ = ChainRulesCore.NoTangent(), b = Float32[0.0027290469, 0.0027290469, 0.0027290469], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], Wi = Float32[0.0027779222 0.008236016; 0.0027779222 0.008236016; 0.0027779222 0.008236016]), state = Float32[0.00818714; 0.00818714; 0.00818714;;]),)
```

```julia
# And with Array{T,3}

julia> Zygote.gradient(m -> sum(abs2, m(reshape(1:24, 2,3,4).+0f0)), RNN(2 => 3; init=Flux.ones32))
((cell = (σ = nothing, Wi = Float32[0.019653967 0.03929506; 0.019653967 0.03929506; 0.019653967 0.03929506], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], b = Float32[0.019641092, 0.019641092, 0.019641092], state0 = nothing), state = Float32[0.058923274 0.058923274 0.058923274; 0.058923274 0.058923274 0.058923274; 0.058923274 0.058923274 0.058923274]),)
```

The `@opt_out` is needed to keep the `Array{T,3}` case working on Zygote; that case does not yet work on Diffractor, for lack of a rule.
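For reference, `@opt_out` is ChainRulesCore's mechanism for telling ChainRules-aware backends (such as Zygote) to skip a generic `rrule` for a particular signature and differentiate through the function's own code instead. A hedged sketch of the shape, assuming ChainRulesCore is installed (`somefunc` is a hypothetical stand-in, not the actual method this PR opts out):

```julia
using ChainRulesCore

# Hypothetical function standing in for the 3-array Recur call.
somefunc(x) = sum(x)

# Opt out: backends consulting ChainRules will see `nothing` for this
# signature and fall back to differentiating somefunc's own code.
@opt_out rrule(::typeof(somefunc), ::AbstractArray{<:Any,3})
```
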

With this rrule, Zygote may end up storing all the state in repeated applications of the rule, rather than in its configuration IdDict, which may matter for performance. On one very crude test it seems to be an improvement; perhaps others have more serious benchmarks?

```julia
julia> @btime Zygote.gradient(m -> sum(abs2, m($(randn(Float32, 2, 100)))), $(RNN(2 => 3)));
  min 23.833 μs, mean 94.316 μs (97 allocations, 23.12 KiB)  # before
  min 23.458 μs, mean 27.116 μs (97 allocations, 23.12 KiB)  # after

julia> @btime Zygote.gradient(m -> sum(abs2, m($(randn(Float32, 20, 100)))), $(LSTM(20 => 30)));
  min 138.375 μs, mean 1.508 ms (112 allocations, 505.38 KiB)  # before
  min 127.416 μs, mean 190.538 μs (62 allocations, 496.33 KiB)  # after
```

More serious correctness tests might also be a good idea. I haven't looked at the test file for this.
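One standard way to sanity-check a hand-written rule is against finite differences. A self-contained sketch of the idea (toy scalar recurrence, no Flux; all names here are illustrative):

```julia
# "Model": a toy recurrence s += w * x over the inputs, loss = s^2.
function loss(w, xs)
    s = 0.0
    for x in xs
        s += w * x
    end
    return s^2
end

# Analytic gradient: d/dw (w * sum(xs))^2 = 2 * w * sum(xs)^2.
analytic_grad(w, xs) = 2 * w * sum(xs)^2

# Central-difference estimate to compare against.
numeric_grad(w, xs; h = 1e-6) = (loss(w + h, xs) - loss(w - h, xs)) / (2h)

xs = [1.0, 2.0, 3.0]
w = 0.5
ok = isapprox(analytic_grad(w, xs), numeric_grad(w, xs); rtol = 1e-5)
```

The same pattern applies to an `rrule`: evaluate its pullback on a seed cotangent and compare against a central-difference estimate of the same directional derivative.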

Edit: marked as draft, since many tests fail; I presume this is giving wrong gradients.

mcabbott avatar Oct 05 '22 03:10 mcabbott