StackOverflow when differentiating QuadGK Integral
When I try to differentiate a numerical integral computed with QuadGK (here the integral of sin(x), as an example), I get a StackOverflowError. Differentiating with Zygote works.
using ForwardDiff
using QuadGK
f(x) = quadgk(u -> sin(u), 0, x, rtol=1e-8)[1]  # integral of sin(u) from 0 to x
ForwardDiff.derivative(f, 1)
ERROR: StackOverflowError:
Stacktrace:...
The expected value is sin(1) ≈ 0.841471, which is what I get from f'(1) with Zygote.
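For reference, the expected value follows from the fundamental theorem of calculus: differentiating the integral with respect to its upper limit returns the integrand.

$$
f(x) = \int_0^x \sin u \,\mathrm{d}u = 1 - \cos x, \qquad f'(x) = \sin x, \qquad f'(1) = \sin 1 \approx 0.841471
$$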
Why did you remove the stacktrace?
Smaller MWE:
julia> using ForwardDiff, QuadGK
julia> f = x -> x
#5 (generic function with 1 method)
julia> QuadGK.cachedrule(ForwardDiff.Dual{ForwardDiff.Tag{typeof(f),Int64},Float16,1}, 1)
ERROR: StackOverflowError:
Stacktrace:
[1] cachedrule(::Type{ForwardDiff.Dual{ForwardDiff.Tag{var"#5#6",Int64},Float16,1}}, ::Int64) at C:\Users\kawcz\.julia\packages\QuadGK\czbUH\src\gausskronrod.jl:249 (repeats 79984 times)
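It looks like the problem is that, for T = ForwardDiff.Dual{ForwardDiff.Tag{typeof(f),Int64},Float16,1}, the call typeof(float(real(one(T)))) in the generic cachedrule method (gausskronrod.jl:249 in the stack trace above) returns T itself, which is still a Number, so that method ends up calling itself with the same arguments: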
julia> typeof(float(real(one(T)))) <: Number
true
A Dual is a Real but not an AbstractFloat, so the T <: AbstractFloat method never catches it, and the recursion only ends when the stack overflows.
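To see the loop condition without ForwardDiff or QuadGK, here is a minimal sketch that mirrors the dispatch structure described above. ToyDual and toyrule are hypothetical stand-ins for ForwardDiff.Dual and QuadGK.cachedrule; their method bodies are assumptions based on the behaviour described in this thread, not copies of either package's source. The snippet only checks the loop condition and does not actually trigger the overflow.

```julia
# Hypothetical stand-in for ForwardDiff.Dual: a Real, but not an AbstractFloat.
struct ToyDual{T<:Real} <: Real
    value::T
    partial::T
end

# Like a dual number, one/real/float keep the result inside the ToyDual type.
Base.one(::Type{ToyDual{T}}) where {T} = ToyDual(one(T), zero(T))
Base.real(d::ToyDual) = d
Base.float(d::ToyDual) = ToyDual(float(d.value), float(d.partial))

# Hypothetical stand-in for QuadGK.cachedrule: a placeholder "rule" for float types...
toyrule(::Type{T}, n::Integer) where {T<:AbstractFloat} = (T, n)
# ...and a generic fallback that converts T to a float type and retries.
toyrule(::Type{T}, n::Integer) where {T<:Number} =
    toyrule(typeof(float(real(one(T)))), n)

# For ordinary numbers the fallback works: Int is converted to Float64.
@show toyrule(Int, 7)                       # (Float64, 7)

# For a ToyDual the "conversion" is a no-op and ToyDual is not an AbstractFloat,
# so the fallback would call itself with identical arguments forever.
D = ToyDual{Float64}
@show typeof(float(real(one(D)))) === D     # true
@show D <: AbstractFloat                    # false
# toyrule(D, 7)  # would recurse until a StackOverflowError
```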
Has there been any progress on this issue?
In my case, I want to differentiate a function several times:
using ForwardDiff, QuadGK
fn(x) = quadgk(u -> sin(u), 0, x)[1]
dfn = x -> ForwardDiff.derivative(fn, x)    # first derivative
d2fn = x -> ForwardDiff.derivative(dfn, x)  # second derivative (nested Dual numbers)
d3fn = x -> ForwardDiff.derivative(d2fn, x)
d4fn ...
For simple functions like fn(x) = 10x^4 + 4x^3 + 5x^2 + 12x + 20, ForwardDiff works as desired, but not for quadgk :(
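For this particular integrand the higher derivatives also have closed forms, which give reference values to check dfn, d2fn, d3fn, and so on against:

$$
\frac{\mathrm{d}^k}{\mathrm{d}x^k}\int_0^x \sin u \,\mathrm{d}u
= \frac{\mathrm{d}^{k-1}}{\mathrm{d}x^{k-1}}\sin x
= \sin\!\left(x + \frac{(k-1)\pi}{2}\right), \qquad k \ge 1,
$$

so dfn(x) = sin x, d2fn(x) = cos x, d3fn(x) = -sin x and d4fn(x) = -cos x.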
I tried the same with Zygote, but it doesn't work out, at least not fast enough: it seems to have trouble computing the result.
At least for the first derivative, adding this definition seems to fix the issue:
using ForwardDiff, QuadGK
# Reuse QuadGK's cached Gauss-Kronrod rule for the Dual's underlying float type T.
@generated function QuadGK.cachedrule(::Type{ForwardDiff.Dual{Tag, T, N}}, n::Integer) where {T<:AbstractFloat,Tag,N}
    cache = haskey(QuadGK.rulecache, T) ? QuadGK.rulecache[T] : (QuadGK.rulecache[T] = Dict{Int,NTuple{3,Vector{T}}}())
    :(haskey($cache, n) ? $cache[n] : ($cache[n] = QuadGK.kronrod($T, n)))
end
This way one can compute, e.g., dfn(2.0). It should not be too complicated to generalise this to higher-order derivatives.
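A sketch of how the fix above might generalise to nested duals: instead of matching only a Dual over an AbstractFloat, peel off one dual layer per call until an AbstractFloat remains. ToyDual and toyrule are hypothetical stand-ins for ForwardDiff.Dual and QuadGK.cachedrule, so this only demonstrates the dispatch pattern and has not been tested against the real packages.

```julia
# Hypothetical toy stand-ins, not the real ForwardDiff/QuadGK definitions.
struct ToyDual{T<:Real} <: Real   # plays the role of ForwardDiff.Dual
    value::T
    partial::T
end
Base.one(::Type{ToyDual{T}}) where {T} = ToyDual(one(T), zero(T))
Base.real(d::ToyDual) = d
Base.float(d::ToyDual) = ToyDual(float(d.value), float(d.partial))

# Plays the role of QuadGK.cachedrule: placeholder "rule" for float types,
# plus the generic fallback that would loop forever on a ToyDual.
toyrule(::Type{T}, n::Integer) where {T<:AbstractFloat} = (T, n)
toyrule(::Type{T}, n::Integer) where {T<:Number} =
    toyrule(typeof(float(real(one(T)))), n)

# The fix: strip one dual layer and retry. This method is more specific than the
# Number fallback, so nested duals (from repeated differentiation) are unwrapped
# one level per call until an AbstractFloat is reached.
toyrule(::Type{ToyDual{T}}, n::Integer) where {T} = toyrule(T, n)

@show toyrule(ToyDual{Float64}, 7)                     # (Float64, 7)
@show toyrule(ToyDual{ToyDual{ToyDual{Float32}}}, 15)  # (Float32, 15)
```

Mapped back onto the real packages, the analogous method would presumably look like `QuadGK.cachedrule(::Type{<:ForwardDiff.Dual{<:Any,T}}, n::Integer) where {T} = QuadGK.cachedrule(T, n)`; this is an untested assumption, and whether quadgk then handles nested duals correctly with the returned float rule would need checking.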