Zygote.jl
Zygote.jl copied to clipboard
Some missing derivative with ranges?
I ran a course the other week and a student found this:
using Flux
using Statistics
# Create NN: maps a 2-vector (x, y) to a scalar; the trailing `first` unwraps
# the 1-element output vector so NN([x, y]) returns a plain number.
NN = Chain(Dense(2,16,tanh),
Dense(16,16,tanh),
Dense(16,1),
first)
NN([0.5,0.5])
# Grid data for sampling the unit square [0,1] x [0,1]
#xg = collect(0:0.1:1) # Will start the training although it shows NaN's when going through it (program unfinished)
#yg = collect(0:0.1:1)
xg = 0:0.1:1 # DOES NOT WORK WITH RANGES?? — iterating this range inside the loss triggers the Zygote MethodError below
yg = 0:0.1:1
# Finite differences, 2nd order: central second-difference approximations of
# u_xx and u_yy with step ϵ chosen near sqrt(machine epsilon) for Float32.
ϵ = sqrt(eps(Float32))
NNxx(x,y) = ( NN([x+ϵ,y]) - 2*NN([x,y]) + NN([x-ϵ,y]) ) / ϵ^2
NNyy(x,y) = ( NN([x,y+ϵ]) - 2*NN([x,y]) + NN([x,y-ϵ]) ) / ϵ^2
# Create loss function for our physical relation
# Interior ("physics") loss: sum of squared PDE residuals
# u_xx + u_yy + sin(πx)·sin(πy) over every point of the sampling grid.
function loss_ode()
    total = 0.
    for xp in xg, yp in yg   # same nesting order as two explicit loops: xp outer, yp inner
        total += abs2(NNxx(xp, yp) + NNyy(xp, yp) + sin(π * xp) * sin(π * yp))
    end
    return total
end
loss_ode()
# Define loss function for boundary conditions
# Boundary loss: push the network output toward zero on all four edges of the
# unit square (the target boundary values are 0 there, since sin(π·0) = sin(π·1) = 0).
function loss_BC()
    edge_penalties = (
        sum(abs2, NN([0, yp]) for yp in yg),  # u(0,y) = 0
        sum(abs2, NN([1, yp]) for yp in yg),  # u(1,y) = -sin(π·1)·sin(πy) = 0
        sum(abs2, NN([xp, 0]) for xp in xg),  # u(x,0) = 0
        sum(abs2, NN([xp, 1]) for xp in xg),  # u(x,1) = -sin(πx)·sin(π·1) = 0
    )
    return sum(edge_penalties)
end
loss_BC()
# Combine loss functions: total objective = PDE residual + boundary penalty
loss() = loss_ode() + loss_BC()
loss()
# Train the NN
opt = Flux.Descent(0.01)
data = Iterators.repeated((),1000)  # 1000 gradient steps; loss takes no data arguments
iter = 0
# Callback: print the current loss every 100 iterations (mutates the global counter).
cb = function()
global iter += 1
if iter % 100 == 0
display(loss())
end
end
display(loss())
Flux.train!(loss, Flux.params(NN), data, opt; cb=cb)  # errors here — see the stacktrace below
When you run this you get:
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})
Closest candidates are:
zero(::Union{Type{P}, P}) where P<:Dates.Period at C:\Users\accou\.julia\juliaup\julia-1.8.0-rc1+0~x64\share\julia\stdlib\v1.8\Dates\src\periods.jl:53
zero(::SA) where SA<:StaticArrays.StaticArray at C:\Users\accou\.julia\packages\StaticArrays\0T5rI\src\linalg.jl:101
zero(::StatsBase.Histogram{T, N, E}) where {T, N, E} at C:\Users\accou\.julia\packages\StatsBase\n494Y\src\hist.jl:562
...
Stacktrace:
[1] (::Zygote.var"#134#135"{Bool})(Δ::Tuple{Float64, Float64})
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:49
[2] (::Zygote.var"#1571#back#136"{Zygote.var"#134#135"{Bool}})(Δ::Tuple{Float64, Float64})
@ Zygote C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
[3] Pullback
@ .\twiceprecision.jl:84 [inlined]
[4] (::typeof(∂(add12)))(Δ::Tuple{Float64, Float64})
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[5] Pullback
@ .\twiceprecision.jl:501 [inlined]
[6] (::typeof(∂(unsafe_getindex)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[7] Pullback
@ .\range.jl:876 [inlined]
[8] (::typeof(∂(iterate)))(Δ::Tuple{Float64, Nothing})
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[9] Pullback
@ .\reduce.jl:60 [inlined]
[10] (::typeof(∂(_foldl_impl)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[11] Pullback
@ .\reduce.jl:48 [inlined]
[12] (::typeof(∂(foldl_impl)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[13] Pullback
@ .\reduce.jl:44 [inlined]
[14] (::typeof(∂(mapfoldl_impl)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[15] Pullback (repeats 2 times)
@ .\reduce.jl:162 [inlined]
[16] (::typeof(∂(mapfoldl)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[17] Pullback
@ .\reduce.jl:294 [inlined]
[18] (::typeof(∂(#mapreduce#263)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[19] Pullback
@ .\reduce.jl:294 [inlined]
[20] (::typeof(∂(mapreduce)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[21] Pullback
@ .\reduce.jl:520 [inlined]
[22] (::typeof(∂(#sum#266)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[23] Pullback
@ .\reduce.jl:520 [inlined]
[24] (::typeof(∂(sum)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[25] Pullback
@ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:40 [inlined]
[26] (::typeof(∂(loss_BC)))(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
[27] Pullback
@ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:46 [inlined]
[28] #208
@ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
[29] #1750#back
@ C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
[30] Pullback
@ C:\Users\accou\.julia\packages\Flux\js6mP\src\optimise\train.jl:120 [inlined]
[31] (::Zygote.var"#89#90"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#37)), Zygote.Context})(Δ::Float64)
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:357
[32] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:76
[33] macro expansion
@ C:\Users\accou\.julia\packages\Flux\js6mP\src\optimise\train.jl:119 [inlined]
[34] macro expansion
@ C:\Users\accou\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
[35] train!(loss::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, data::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}},
opt::Descent; cb::var"#17#18")
@ Flux.Optimise C:\Users\accou\.julia\packages\Flux\js6mP\src\optimise\train.jl:117
[36] top-level scope
@ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:60
But if you collect the ranges you're fine.
Is `loss_BC` required to repro this?
It is not. Could you please try to simplify before asking for help?
This appears to give the same error:
julia> using Zygote
julia> function loss2()
summ = 0.
for xp in 0:0.1:1
summ += abs2(sin(π*xp))
end
summ
end;
julia> gradient(loss2)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})
Edit: even simpler — anything which calls `iterate` on a range seems to go wrong:
julia> gradient(r -> sum(x for x in r), 0:1.0)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})
julia> gradient(first, 0:1.0)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})
but with LinRange it makes a structural tangent, and UnitRange is sometimes non-differentiable?
julia> Zygote.gradient(r -> sum(x for x in r), LinRange(0, 1, 2))
((start = 1.0, stop = 1.0, len = nothing, lendiv = -1.0),)
julia> Zygote.gradient(first, LinRange(0, 1, 2))
((start = 1.0, stop = nothing, len = nothing, lendiv = nothing),)
julia> gradient(r -> sum(x for x in r), 0:1) # UnitRange
(nothing,)
julia> gradient(first, 0:1)
((start = 1, stop = nothing),)
Yeah sorry, just pasted what I had from lecture notes because it was either that or no report. I'll get back around to cleaning this up in July.
One slightly crude way to solve this is to ensure that `iterate` on the range calls `getindex`, which has a rule:
import ChainRulesCore: rrule, HasReverseMode, RuleConfig
# Workaround: give `Base.unsafe_getindex` on ranges an explicit rrule that
# defers to the AD-derived rule for `getindex`, so reverse-mode iteration over
# a range no longer reaches the failing `iterate` pullback. `@show i` is left
# in deliberately to trace which indices get pulled back.
# NOTE(review): this is type piracy on Base (Base function, Base types) —
# acceptable as a local debugging hack, not for shipping in a package.
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i)
@show i
rrule_via_ad(config, getindex, x, i)
end
It might be possible to do something smarter, and directly handle `iterate`. Or at least to avoid accumulating a dense vector to represent the gradient with respect to the range, via something like https://github.com/JuliaDiff/ChainRulesCore.jl/issues/437.