Integrals.jl

Zygote.gradient fails with CubaCuhre and batch != 0

Status: Open · KirillZubov opened this issue 3 years ago · 1 comment

# Original MWE (Quadrature.jl era): differentiate a Cuba Cuhre integral with
# respect to the parameters `p`, first without batching.
using Quadrature, ForwardDiff, FiniteDiff, Zygote, Cuba
# Scalar-valued integrand: sum of sin(x_i * p_i) over the components of one point `x`.
f(x,p) = sum(sin.(x .* p))
# Integrate over the 3-dimensional box [1, 3]^3.
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]

# batch=0 presumably disables batching, so `f` sees a single point per call.
prob = QuadratureProblem(f,lb,ub,p; batch=0)
# `[1]` takes the first component of the returned solution.
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]

# Rebuilds and solves the problem for a given `p`, so each AD backend below
# can differentiate the integral value with respect to `p`.
function testf(p)
    prob = QuadratureProblem(f,lb,ub,p, batch=0)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
end
# Reverse-mode (Zygote), finite differences, and forward-mode gradients.
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# work fine: all three backends succeed for batch=0

# Same problem with batch=10. NOTE(review): `f` above still returns a single
# scalar, while a batched integrand is expected to return one value per input
# point; the follow-up comment identifies this as a bug in the MWE itself.
prob = QuadratureProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]


# Redefines `testf(p)` (same signature, so it replaces the batch=0 method)
# to use the batched problem.
function testf(p)
    prob = QuadratureProblem(f,lb,ub,p, batch=10)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# work fine: finite differences and ForwardDiff succeed for batch=10

# Zygote fails here with `MethodError: Cannot convert Array{Float64,1} to
# Float64`, raised from the Quadrature.jl adjoint (see the stack trace).
dp1 = Zygote.gradient(testf,p)

ERROR: MethodError: Cannot `convert` an object of type Array{Float64,1} to an object of type Float64
Closest candidates are:
  convert(::Type{T}, ::ArrayInterface.StaticInt{N}) where {T<:Number, N} at /Users/kirill/.julia/packages/ArrayInterface/rw2kK/src/static.jl:18
  convert(::Type{R}, ::T) where {R<:Real, T<:ReverseDiff.TrackedReal} at /Users/kirill/.julia/packages/ReverseDiff/jFRo1/src/tracked.jl:255
  convert(::Type{T}, ::Unitful.Quantity) where T<:Real at /Users/kirill/.julia/packages/Unitful/1t88N/src/conversion.jl:145
  ...
Stacktrace:
 [1] setindex! at ./array.jl:849 [inlined]
 [2] macro expansion at ./multidimensional.jl:802 [inlined]
 [3] macro expansion at ./cartesian.jl:64 [inlined]
 [4] macro expansion at ./multidimensional.jl:797 [inlined]
 [5] _unsafe_setindex!(::IndexLinear, ::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Base.Slice{Base.OneTo{Int64}}, ::Int64) at ./multidimensional.jl:789
 [6] _setindex! at ./multidimensional.jl:785 [inlined]
 [7] setindex!(::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Function, ::Int64) at ./abstractarray.jl:1153
 [8] (::Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}})(::Array{Float64,2}, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:524
 [9] __solvebp_call(::QuadratureProblem{false,Array{Float64,1},Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}}}, ::CubaCuhre, ::Quadrature.ReCallVJP{Quadrature.ZygoteVJP}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}; reltol::Float64, abstol::Float64, maxiters::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:437
 [10] (::Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:546
 [11] #65#back at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [12] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [13] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{Quadrature.var"#65#back#64"{Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}}},Tuple{NTuple{8,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [14] #solve#10 at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:149 [inlined]
 [15] (::typeof(∂(#solve#10)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [16] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [17] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#10)),Tuple{NTuple{5,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [18] (::typeof(∂(solve##kw)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [19] testf at ./none:3 [inlined]
 [20] (::typeof(∂(testf)))(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#41#42"{typeof(∂(testf))})(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45
 [22] gradient(::Function, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [23] top-level scope at none:1

— KirillZubov, Nov 21 '20 13:11

I think there is a bug in the MWE: a scalar-valued `f` is incompatible with batching. Namely, a batched integrand should return a vector whose length matches the last axis of the input points, i.e. one value per point (see the FAQ for more details).

I've adapted the MWE to the current version of Integrals.jl, modified the integrand to do what I think was intended, and can confirm that it works on the master branch:

# Adapted MWE for Integrals.jl (the successor of Quadrature.jl), non-batched case.
using Integrals, ForwardDiff, FiniteDiff, Zygote, Cuba
# `dims=1` reduces over the first axis. For a single point (a vector) this
# returns a 1-element vector; for a matrix of points stored column-wise it
# returns one value per column, which is what a batched integrand must produce.
f(x,p) = sum(sin.(x .* p); dims=1)
# Integrate over the 3-dimensional box [1, 3]^3.
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]

prob = IntegralProblem(f,lb,ub,p)
# `[1]` extracts the single component of the (vector-valued) integral.
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]

# Rebuilds and solves the problem for a given `p`, so each AD backend below
# can differentiate the integral value with respect to `p`.
function testf(p)
    prob = IntegralProblem(f,lb,ub,p)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
end
# Reverse-mode (Zygote), finite differences, and forward-mode gradients.
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# work fine: all three backends succeed without batching

# Batched case: `f` now returns one value per input point (see `dims=1` above).
# NOTE(review): newer Integrals.jl releases may expose batching through a
# different API than the `batch` keyword — verify against the current docs.
prob = IntegralProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]


# Redefines `testf(p)` (same signature, so it replaces the non-batched method)
# to use the batched problem.
function testf(p)
    prob = IntegralProblem(f,lb,ub,p, batch=10)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# work fine: finite differences and ForwardDiff succeed for batch=10

# The call that originally failed; reported to work on the Integrals.jl
# master branch as of this comment (Mar 2024).
dp1 = Zygote.gradient(testf,p)

Since some bugs in the current release that affect `CubaCuhre` are already fixed on the master branch, I'll wait until the next release to close the issue.

— lxvm, Mar 09 '24 23:03