Integrals.jl
Zygote.gradient fails with CubaCuhre when batch != 0
using Quadrature, ForwardDiff, FiniteDiff, Zygote, Cuba
# Reproduction (Quadrature.jl API), unbatched case: batch=0.
# With batch=0 the integrand f receives a single point x and returns a scalar
# (sum of sin.(x .* p) over all components).
f(x,p) = sum(sin.(x .* p))
# Integration box [1, 3]^3 (points are 3-dimensional) and the parameters p.
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]
prob = QuadratureProblem(f,lb,ub,p; batch=0)
# `[1]` indexes the returned solution (presumably its first/only integral value).
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
# testf wraps problem construction + solve as a function of p only, so the
# gradient d(integral)/dp can be compared across reverse-mode AD (Zygote),
# finite differences (FiniteDiff), and forward-mode AD (ForwardDiff).
function testf(p)
prob = QuadratureProblem(f,lb,ub,p, batch=0)
solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
end
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# All three gradient computations above work fine with batch=0.
# Same problem, now with batch=10 (this redefines the global prob/sol and testf).
# NOTE: f above still returns a single scalar. A batched integrand must instead
# return one value per point, i.e. a vector whose length matches the last axis
# of the input points. This mismatch in the MWE is the bug identified in the
# discussion below.
prob = QuadratureProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
function testf(p)
prob = QuadratureProblem(f,lb,ub,p, batch=10)
solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# FiniteDiff and ForwardDiff above work fine; the Zygote call below throws the
# MethodError shown next (raised inside the package's quadrature adjoint).
dp1 = Zygote.gradient(testf,p)
ERROR: MethodError: Cannot `convert` an object of type Array{Float64,1} to an object of type Float64
Closest candidates are:
convert(::Type{T}, ::ArrayInterface.StaticInt{N}) where {T<:Number, N} at /Users/kirill/.julia/packages/ArrayInterface/rw2kK/src/static.jl:18
convert(::Type{R}, ::T) where {R<:Real, T<:ReverseDiff.TrackedReal} at /Users/kirill/.julia/packages/ReverseDiff/jFRo1/src/tracked.jl:255
convert(::Type{T}, ::Unitful.Quantity) where T<:Real at /Users/kirill/.julia/packages/Unitful/1t88N/src/conversion.jl:145
...
Stacktrace:
[1] setindex! at ./array.jl:849 [inlined]
[2] macro expansion at ./multidimensional.jl:802 [inlined]
[3] macro expansion at ./cartesian.jl:64 [inlined]
[4] macro expansion at ./multidimensional.jl:797 [inlined]
[5] _unsafe_setindex!(::IndexLinear, ::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Base.Slice{Base.OneTo{Int64}}, ::Int64) at ./multidimensional.jl:789
[6] _setindex! at ./multidimensional.jl:785 [inlined]
[7] setindex!(::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Function, ::Int64) at ./abstractarray.jl:1153
[8] (::Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}})(::Array{Float64,2}, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:524
[9] __solvebp_call(::QuadratureProblem{false,Array{Float64,1},Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}}}, ::CubaCuhre, ::Quadrature.ReCallVJP{Quadrature.ZygoteVJP}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}; reltol::Float64, abstol::Float64, maxiters::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:437
[10] (::Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:546
[11] #65#back at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
[12] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
[13] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{Quadrature.var"#65#back#64"{Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}}},Tuple{NTuple{8,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[14] #solve#10 at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:149 [inlined]
[15] (::typeof(∂(#solve#10)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[16] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
[17] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#10)),Tuple{NTuple{5,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[18] (::typeof(∂(solve##kw)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[19] testf at ./none:3 [inlined]
[20] (::typeof(∂(testf)))(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[21] (::Zygote.var"#41#42"{typeof(∂(testf))})(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45
[22] gradient(::Function, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
[23] top-level scope at none:1
I think there is a bug in the MWE: a scalar-valued `f` is incompatible with batching. A batched integrand must return a vector whose length matches the last axis of the input points (see the FAQ for more details).
I've adapted the MWE to the current version of Integrals.jl, modified the integrand to do what I think was intended, and can confirm that it works on the master branch:
using Integrals, ForwardDiff, FiniteDiff, Zygote, Cuba
# Adapted reproduction (Integrals.jl API), unbatched case.
# Reducing with dims=1 makes the same integrand usable with and without
# batching. For a single point (a Vector) it returns a 1-element vector. For a
# batch of points stored along the last axis (columns of a matrix) it returns
# one value per column.
# NOTE(review): for a matrix input the result is a 1×N matrix, not a Vector;
# the discussion reports this works on master. Confirm, or wrap in vec(...).
f(x,p) = sum(sin.(x .* p); dims=1)
# Integration box [1, 3]^3 and the parameters p (unchanged from the original MWE).
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]
# No `batch` keyword here; presumably the unbatched default.
prob = IntegralProblem(f,lb,ub,p)
# maxiters is raised from 100 (original MWE) to 1000.
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
# testf exposes the integral as a function of p alone so that Zygote,
# FiniteDiff and ForwardDiff gradients can be compared.
function testf(p)
prob = IntegralProblem(f,lb,ub,p)
solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
end
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# All three gradient computations above work fine without batching.
# Same problem with batch=10 (this redefines the global prob/sol and testf).
# Unlike the original MWE, f now returns one value per input point, as
# batching requires.
prob = IntegralProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
function testf(p)
prob = IntegralProblem(f,lb,ub,p, batch=10)
solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# FiniteDiff and ForwardDiff above work fine; per the discussion, the Zygote
# call below (the one that failed in the original report) also works on master.
dp1 = Zygote.gradient(testf,p)
Because some bugs that affect CubaCuhre in the current release are already fixed on the master branch, I'll wait until the next release before closing this issue.