Zygote.jl
Error with `Zygote.gradient` for `foldl` and `sum`
MWE:
using Zygote, LinearAlgebra
# Problem size and random test data shared by both loss closures below.
# Statement order is kept so the global RNG is consumed in the same sequence.
N = 4
u0 = rand(N)                      # constant vector scaled by `Diagonal(p)` in each loss
ps = rand(N)                      # point at which the gradients are taken
mats = ntuple(_ -> rand(N, N), 2) # (A, B): folded onto v in order by `foldl`
nums = ntuple(_ -> rand(), 2)     # (α, β): reduced with `sum`
# Loss l = sum(B * A * Diagonal(p) * u0) with (A, B) = mats, computed via a
# left `foldl` that starts from `init = v`.
# Each intermediate is wrapped in `Zygote.hook`, so the cotangent flowing back
# through it is printed during the backward pass.
loss_m = function (p)
    # v = Diagonal(p) * u0. Only the *type* of its cotangent is printed, which
    # makes a dropped gradient (`Nothing`) easy to spot.
    v = Zygote.hook(Diagonal(p) * u0) do Δ
        println("Δv: ", typeof(Δ))
        Δ
    end
    # Fold the operators onto v in order: for mats = (A, B) this is B * (A * v).
    w = foldl(mats; init = v) do acc, op
        op * acc
    end
    w = Zygote.hook(w) do Δ
        println("Δw: ", Δ)
        Δ
    end
    l = Zygote.hook(sum(w)) do Δ
        println("Δl: ", Δ)
        Δ
    end
    return l
end
println("fwd")
@time display(loss_m(ps))
println("bwd")
# BUG: prints `(nothing,)`. The cotangent reaching v is `Nothing`, so the
# gradient through foldl's `init = v` is dropped instead of propagated.
@time display(Zygote.gradient(loss_m, ps))
# Loss l = sum((α + β) * Diagonal(p) * u0) with (α, β) = nums.
# The scalar factor comes from a mapped `sum` called with an explicit `init`
# keyword. That is the call whose pullback fails under Zygote in this report
# (`MethodError: no method matching iterate(::Nothing)` raised from
# `chain_rrule_kw`). At the time, ChainRules had no `rrule` for `sum` with `init`.
loss_n = function (p)
    # v = Diagonal(p) * u0. Only the *type* of its cotangent is printed.
    v = Diagonal(p) * u0
    v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v)
    # `sum` adds the scalars, so the factor is α + β (NOT the product αβ).
    # `convert(Number, a)` returns each Float64 unchanged. It is presumably there
    # only to exercise the mapped `sum(f, xs; init)` form.
    w = sum(a -> convert(Number, a), nums; init=zero(eltype(nums))) * v # w = (α + β) * v
    w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w)
    l = sum(w)
    l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l)
    return l
end
println("fwd")
@time display(loss_n(ps))
println("bwd")
# ERRORS: MethodError: no method matching iterate(::Nothing), raised while
# Zygote builds the pullback for `sum(f, nums; init=...)`.
@time display(Zygote.gradient(loss_n, ps))
julia> include("examples/ad/zy.jl")
fwd
4.339451806053281
0.021413 seconds (44.18 k allocations: 2.637 MiB, 99.38% compilation time)
bwd
Δl: 1.0
Δw: Fill(1.0, 4)
Δv: Nothing
(nothing,)
0.139943 seconds (444.37 k allocations: 23.545 MiB, 99.45% compilation time)
fwd
1.5660193401267022
0.355185 seconds (1.11 M allocations: 65.174 MiB, 99.67% compilation time)
bwd
ERROR: LoadError: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
...
Stacktrace:
[1] indexed_iterate(I::Nothing, i::Int64)
@ Base ./tuple.jl:91
[2] chain_rrule_kw
@ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:229 [inlined]
[3] macro expansion
@ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0 [inlined]
[4] _pullback(::Zygote.Context, ::Base.var"#sum##kw", ::NamedTuple{(:init,), Tuple{Float64}}, ::typeof(sum), ::var"#49#54", ::Tuple{Float64, Float64})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:9
[5] _pullback
@ ~/.julia/dev/PDEInterfaces/examples/ad/zy.jl:29 [inlined]
[6] _pullback(ctx::Zygote.Context, f::var"#47#52", args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
[7] _pullback(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
[8] pullback(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
[9] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
[10] top-level scope
@ ./timing.jl:242
[11] include(fname::String)
@ Base.MainInclude ./client.jl:476
[12] top-level scope
@ REPL[2]:1
[13] top-level scope
@ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
in expression starting at /home/vedantpu/.julia/dev/PDEInterfaces/examples/ad/zy.jl:38
Reference: https://github.com/SciML/SciMLOperators.jl/pull/94
The `sum` case works when I remove the `init` keyword argument, but I'm still curious why it doesn't work with it.
`foldl` not tracking the `init` keyword is https://github.com/JuliaDiff/ChainRules.jl/issues/567; you could try https://github.com/JuliaDiff/ChainRules.jl/pull/569.
`sum` not supporting `init` is also bad. Could you open an issue on ChainRules.jl?
julia> ChainRules.rrule(sum, [1,2,3]; init=4)
ERROR: MethodError: no method matching rrule(::typeof(sum), ::Vector{Int64}; init::Int64)
Closest candidates are:
rrule(::typeof(sum), ::AbstractArray; dims) got unsupported keyword argument "init"
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/mapreduce.jl:28
rrule(::typeof(sum), ::Any, ::AbstractArray{Bool}; sum_pullback) got unsupported keyword argument "init"
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:82