
Some missing derivative with ranges?

Open · ChrisRackauckas opened this issue 2 years ago · 4 comments

I ran a course the other week and a student found this:

using Flux
using Statistics

# Create NN
NN = Chain(Dense(2,16,tanh),
           Dense(16,16,tanh),
           Dense(16,1),
           first)
NN([0.5,0.5])

# Grid data for sampling
#xg = collect(0:0.1:1)  # Training will start, although it shows NaNs while running (left unfinished)
#yg = collect(0:0.1:1)
xg = 0:0.1:1    # DOES NOT WORK WITH RANGES??
yg = 0:0.1:1

# Finite differences, 2nd order: central stencil f''(x) ≈ (f(x+ϵ) - 2*f(x) + f(x-ϵ)) / ϵ^2
ϵ = sqrt(eps(Float32))
NNxx(x,y) = ( NN([x+ϵ,y]) - 2*NN([x,y]) + NN([x-ϵ,y]) ) / ϵ^2
NNyy(x,y) = ( NN([x,y+ϵ]) - 2*NN([x,y]) + NN([x,y-ϵ]) ) / ϵ^2

# Create loss function for our physical relation
function loss_ode()
    summ = 0.
    for xp in xg
        for yp in yg
            summ += abs2( NNxx(xp,yp) + NNyy(xp,yp) + sin(π*xp)*sin(π*yp) )
        end
    end
    return summ
end
loss_ode()

# Define loss function for boundary conditions
function loss_BC()
    summ = 0.
    summ += sum(abs2,NN([0,yp]) for yp in yg)                  # u(0,y) = 0
    summ += sum(abs2,NN([1,yp]) for yp in yg)   # u(1,y) = -sin(π*1)*sin(π*y)
    summ += sum(abs2,NN([xp,0]) for xp in xg)                  # u(x,0) = 0
    summ += sum(abs2,NN([xp,1]) for xp in xg)   # u(x,1) = -sin(π*x)*sin(π*1)
    return summ
end
loss_BC()

# Combine loss functions
loss() = loss_ode() + loss_BC()
loss()

# Train the NN
opt = Flux.Descent(0.01)
data = Iterators.repeated((),1000)
iter = 0
cb = function()
    global iter += 1
    if iter % 100 == 0
        display(loss())
    end
end
display(loss())
Flux.train!(loss, Flux.params(NN), data, opt; cb=cb)

When you run this you get:

ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})
Closest candidates are:
  zero(::Union{Type{P}, P}) where P<:Dates.Period at C:\Users\accou\.julia\juliaup\julia-1.8.0-rc1+0~x64\share\julia\stdlib\v1.8\Dates\src\periods.jl:53
  zero(::SA) where SA<:StaticArrays.StaticArray at C:\Users\accou\.julia\packages\StaticArrays\0T5rI\src\linalg.jl:101
  zero(::StatsBase.Histogram{T, N, E}) where {T, N, E} at C:\Users\accou\.julia\packages\StatsBase\n494Y\src\hist.jl:562
  ...
Stacktrace:
  [1] (::Zygote.var"#134#135"{Bool})(Δ::Tuple{Float64, Float64})
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:49
  [2] (::Zygote.var"#1571#back#136"{Zygote.var"#134#135"{Bool}})(Δ::Tuple{Float64, Float64})
    @ Zygote C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [3] Pullback
    @ .\twiceprecision.jl:84 [inlined]
  [4] (::typeof(∂(add12)))(Δ::Tuple{Float64, Float64})
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
  [5] Pullback
    @ .\twiceprecision.jl:501 [inlined]
  [6] (::typeof(∂(unsafe_getindex)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
  [7] Pullback
    @ .\range.jl:876 [inlined]
  [8] (::typeof(∂(iterate)))(Δ::Tuple{Float64, Nothing})
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
  [9] Pullback
    @ .\reduce.jl:60 [inlined]
 [10] (::typeof(∂(_foldl_impl)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [11] Pullback
    @ .\reduce.jl:48 [inlined]
 [12] (::typeof(∂(foldl_impl)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [13] Pullback
    @ .\reduce.jl:44 [inlined]
 [14] (::typeof(∂(mapfoldl_impl)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [15] Pullback (repeats 2 times)
    @ .\reduce.jl:162 [inlined]
 [16] (::typeof(∂(mapfoldl)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [17] Pullback
    @ .\reduce.jl:294 [inlined]
 [18] (::typeof(∂(#mapreduce#263)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\reduce.jl:294 [inlined]
 [20] (::typeof(∂(mapreduce)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [21] Pullback
    @ .\reduce.jl:520 [inlined]
 [22] (::typeof(∂(#sum#266)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [23] Pullback
    @ .\reduce.jl:520 [inlined]
 [24] (::typeof(∂(sum)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [25] Pullback
    @ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:40 [inlined]
 [26] (::typeof(∂(loss_BC)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [27] Pullback
    @ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:46 [inlined]
 [28] #208
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
 [29] #1750#back
    @ C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [30] Pullback
    @ C:\Users\accou\.julia\packages\Flux\js6mP\src\optimise\train.jl:120 [inlined]
 [31] (::Zygote.var"#89#90"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#37)), Zygote.Context})(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:357
 [32] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:76
 [33] macro expansion
    @ C:\Users\accou\.julia\packages\Flux\js6mP\src\optimise\train.jl:119 [inlined]
 [34] macro expansion
    @ C:\Users\accou\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [35] train!(loss::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, data::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}},
 opt::Descent; cb::var"#17#18")
    @ Flux.Optimise C:\Users\accou\.julia\packages\Flux\js6mP\src\optimise\train.jl:117
 [36] top-level scope
    @ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:60

But if you collect the ranges, you're fine.
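For example, a minimal contrast (first calls iterate on its argument):

using Zygote
gradient(first, 0:1.0)           # ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})
gradient(first, collect(0:1.0))  # works, giving ([1.0, 0.0],)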

ChrisRackauckas · Jun 22 '22

Is loss_BC required to repro this?

ToucheSir · Jun 23 '22

It is not. Could you please try to simplify before asking for help?

This appears to give the same error:

julia> using Zygote

julia> function loss2()
           summ = 0.
           for xp in 0:0.1:1
               summ += abs2(sin(π*xp))
           end
           summ
       end;

julia> gradient(loss2)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})

Edit: even simpler, anything which calls iterate on a range seems to go wrong:

julia> gradient(r -> sum(x for x in r), 0:1.0)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})

julia> gradient(first, 0:1.0)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Float64})

but with LinRange it makes a structural tangent, and UnitRange is sometimes non-differentiable?

julia> Zygote.gradient(r -> sum(x for x in r), LinRange(0, 1, 2))
((start = 1.0, stop = 1.0, len = nothing, lendiv = -1.0),)

julia> Zygote.gradient(first, LinRange(0, 1, 2))
((start = 1.0, stop = nothing, len = nothing, lendiv = nothing),)

julia> gradient(r -> sum(x for x in r), 0:1)  # UnitRange
(nothing,)

julia> gradient(first, 0:1)
((start = 1, stop = nothing),)

mcabbott · Jun 23 '22

Yeah sorry, just pasted what I had from lecture notes because it was either that or no report. I'll get back around to cleaning this up in July.

ChrisRackauckas · Jun 23 '22

One slightly crude way to solve this is to ensure that iterate on the range calls getindex, which has a rule:

import ChainRulesCore: rrule, rrule_via_ad, HasReverseMode, RuleConfig

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i)
    @show i  # debug print, to confirm the rule is being hit
    rrule_via_ad(config, getindex, x, i)
end
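With that method loaded, the failing examples above should run, going through getindex's rule and accumulating a dense Vector as the gradient of the range:

gradient(r -> sum(x for x in r), 0:1.0)  # should now yield a dense tangent, roughly ([1.0, 1.0],)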

It might be possible to do something smarter and directly handle iterate. Or at least to avoid accumulating a dense vector to represent the gradient with respect to the range, via something like https://github.com/JuliaDiff/ChainRulesCore.jl/issues/437.
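A rough, untested sketch of what handling iterate directly might look like, assuming StepRangeLen's ref/step/offset internals and producing a structural tangent for the range:

import ChainRulesCore: rrule, NoTangent, Tangent

# Sketch only: differentiate iterate on a StepRangeLen directly.
# Since r[i] == r.ref + (i - r.offset) * r.step, an element tangent dy
# contributes dy to ref and (i - r.offset) * dy to step.
function rrule(::typeof(iterate), r::StepRangeLen, i::Integer=1)
    if i > length(r)
        return nothing, _ -> (NoTangent(), NoTangent(), NoTangent())
    end
    function iterate_pullback(dy_state)
        dy = dy_state[1]  # tangent of the element; the state is non-differentiable
        dr = Tangent{typeof(r)}(ref = dy, step = (i - r.offset) * dy)
        return (NoTangent(), dr, NoTangent())
    end
    return (r[i], i + 1), iterate_pullback
end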

mcabbott · Jul 05 '22