
Second derivative of Matern in zero is wrong

Open FelixBenning opened this issue 2 years ago • 23 comments

julia> using KernelFunctions: MaternKernel

julia> k = MaternKernel(ν=5)
Matern Kernel (ν = 5, metric = Distances.Euclidean(0.0))

julia> import ForwardDiff as FD

julia> kx(x,y) = FD.derivative(t -> k(x+t, y), 0)
kx (generic function with 1 method)

julia> dk(x,y) = FD.derivative(t -> kx(x, y+t), 0)
dk (generic function with 1 method)

julia> dk(0,0)
0.0

This is wrong, because for a centered GP $Z$ with covariance function $k$

dk(x,y) $= \partial_x \partial_y k(x,y) = \partial_x \partial_y \mathbb{E}[Z(x)Z(y)] = \mathbb{E}[\partial_x Z(x)\, \partial_y Z(y)] = \text{Cov}(Z'(x), Z'(y)),$

and $\text{Cov}(Z'(0), Z'(0)) > 0$.

$\nu=5$ should leave plenty of headroom for numerical errors, since it implies the GP is four times (mean-square) differentiable.
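For concreteness, here is a finite-difference check (a sketch independent of the package, using the closed-form Matérn-5/2 kernel with unit lengthscale, since $\nu = 5/2$ has a simple closed form): the mixed second derivative at x = y is strictly positive.

```julia
# Sketch: verify Cov(Z'(0), Z'(0)) > 0 by central finite differences,
# using the closed-form Matérn-5/2 kernel with unit lengthscale.
κ(d) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)
k52(x, y) = κ(abs(x - y))

# Taylor expansion at zero: κ(d) = 1 - (5/6)d² + O(d⁴), so the exact
# mixed second derivative ∂x∂y k(x, y) at x = y is -κ''(0) = 5/3.
h = 1e-4
d2k = (k52(h, h) - k52(h, -h) - k52(-h, h) + k52(-h, -h)) / (4h^2)
d2k  # ≈ 5/3 > 0, while AD returns 0.0
```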

FelixBenning avatar Jun 05 '23 14:06 FelixBenning

This is most likely due to this implementation:

function _matern(ν::Real, d::Real)
    if iszero(d)
        return one(d)
    else
        y = sqrt(2ν) * d
        b = log(besselk(ν, y))
        return exp((one(d) - ν) * oftype(y, logtwo) - loggamma(ν) + ν * log(y) + b)
    end
end

https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/basekernels/matern.jl#L43-L45

The `iszero` branch should not return a constant but rather a Taylor polynomial, so that autodiff through this branch works. I am looking for a reference...
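A hedged sketch of what that could look like (the helper name `_matern_taylor` is made up; the coefficient comes from the second-order Taylor expansion of the Matérn kernel at zero, which is only valid for ν > 1):

```julia
# Hypothetical sketch (not the package implementation): for ν > 1 and unit
# lengthscale, the Matérn kernel expands as κ(d) = 1 - (ν / (2(ν - 1))) d² + O(d⁴),
# so returning this polynomial near zero would give AD the correct curvature.
_matern_taylor(ν::Real, d::Real) = 1 - (ν / (2 * (ν - 1))) * d^2

# Sanity check against the closed-form Matérn-5/2 kernel (coefficient 5/6):
κ52(d) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)
_matern_taylor(5 / 2, 1e-3) ≈ κ52(1e-3)  # the two agree up to O(d⁴)
```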

FelixBenning avatar Jun 05 '23 14:06 FelixBenning

This might be due to the hardcoded (constant) value for x = y; in that case the problem possibly doesn't exist with ForwardDiff 0.11.

devmotion avatar Jun 05 '23 14:06 devmotion

The if block should not be constant but rather a taylor polynomial so that autodiff in this branch works.

Shouldn't be necessary in ForwardDiff 0.11: It skips measure zero branches (which fixes some problems but broke existing code).

devmotion avatar Jun 05 '23 14:06 devmotion

https://github.com/cgeoga/BesselK.jl

https://arxiv.org/pdf/2201.00090.pdf

would be a possibility

FelixBenning avatar Jun 05 '23 14:06 FelixBenning

Shouldn't be necessary in ForwardDiff 0.11: It skips measure zero branches (which fixes some problems but broke existing code).

wow neat - when is that going to be released?

FelixBenning avatar Jun 05 '23 14:06 FelixBenning

The package and paper are mainly concerned with derivatives with respect to the order ν, which does not seem to be the issue in the OP.

when is that going to be released?

I don't know. Initially the change was released in a non-breaking 0.10.x version, but it broke a lot of downstream packages that rely on the current behaviour. So it was reverted and re-applied only on the master branch; recent releases are made from a separate 0.10.x branch without this change. I don't think anyone plans to release 0.11 any time soon, because the same thing will happen and nobody wants to invest the time to fix all the broken downstream code.

devmotion avatar Jun 05 '23 15:06 devmotion

@devmotion you are probably right about the references - damn, I thought I had seen it somewhere. Can't find it at the moment. Given that the variance of the derivatives is really important for everything, this is not really a corner case and will break working with derivatives of GPs...

FelixBenning avatar Jun 05 '23 15:06 FelixBenning

Are you interested in a Taylor expansion for the `if x == y` branch? Given the AD promise:

Automatic Differentiation compatibility: all kernel functions which ought to be differentiable using AD packages like ForwardDiff.jl or Zygote.jl should be.

Same issue with the `RationalQuadraticKernel`, by the way (although that does not appear to branch in the same way):

julia> using KernelFunctions: RationalQuadraticKernel

julia> k = RationalQuadraticKernel()
Rational Quadratic Kernel (α = 2.0, metric = Distances.Euclidean(0.0))

julia> kx(x,y) = FD.derivative(t->k(x+t, y), 0)
kx (generic function with 1 method)

julia> dk(x,y) = FD.derivative(t->kx(x,y+t), 0)
dk (generic function with 1 method)

julia> dk(0,0)
0.0
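For reference (a sketch, not using the package): with α = 2 the rational quadratic kernel is $(1 + d^2/(2\alpha))^{-\alpha} = 1 - d^2/2 + O(d^4)$, so the correct mixed second derivative at x = y is 1, not 0. A finite-difference check:

```julia
# Sketch: central finite differences for the rational quadratic kernel, α = 2.
κ_rq(d) = (1 + d^2 / 4)^(-2)
k_rq(x, y) = κ_rq(abs(x - y))

h = 1e-4
d2k_rq = (k_rq(h, h) - k_rq(h, -h) - k_rq(-h, h) + k_rq(-h, -h)) / (4h^2)
d2k_rq  # ≈ 1.0, the value AD should return at (0, 0)
```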

EDIT: Wikipedia https://en.wikipedia.org/wiki/Mat%C3%A9rn_covariance_function#Taylor_series_at_zero_and_spectral_moments

FelixBenning avatar Jun 05 '23 15:06 FelixBenning

~~I think the Matern kernel is in fact correct:~~ Made a mistake here: Matern is incorrect, but RationalQuadratic is fine.

using KernelFunctions, Enzyme, Plots
k = Matern52Kernel()
mk(d) = k(1.0, 1.0 + d)

mk(0.1)

dr = range(0.0, 0.1; length=200)

p = plot(layout=(1, 2))
plot!(p[1], dr, mk.(dr), label="matern", legend=:topright)

dmkd(x, y) = only(autodiff(
    Forward,
    yt -> only(autodiff_deferred(Forward, k, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
    DuplicatedNoNeed,
    Duplicated(y, 1.0)))
dmkd(1.0, 1.0 + 0.1)

plot!(p[2], dr, dmkd.(1.0, 1.0 .+ dr), label="matern", legend=:topright, title="d^2/(dx1 dx2) ker")

image

but for RationalQuadratic

using KernelFunctions, Enzyme, Plots
k = RationalQuadraticKernel()
mk(d) = k(1.0, 1.0 + d)

mk(0.1)

dr = range(0.0, 0.1; length=200)

p = plot(layout=(1, 2))
plot!(p[1], dr, mk.(dr), label="RQ", legend=:topright)

dmkd(x, y) = only(autodiff(
    Forward,
    yt -> only(autodiff_deferred(Forward, k, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
    DuplicatedNoNeed,
    Duplicated(y, 1.0)))
dmkd(1.0, 1.0 + 0.1)

plot!(p[2], dr, dmkd.(1.0, 1.0 .+ dr), label="RQ", legend=:right, title="d^2/(dx1 dx2) ker")

image

Crown421 avatar Jun 06 '23 16:06 Crown421

~I think the Matern kernel is in fact correct:~ Made a mistake here: Matern is incorrect, but RationalQuadratic is fine.

Neat - so switching to Enzyme would at least fix the rational quadratic.

FelixBenning avatar Jun 06 '23 18:06 FelixBenning

I have been working on this for a bit, and it seems the issue is neither the Matern kernel itself nor a missing Taylor expansion; instead it seems to be related to the Euclidean distance. In the following code I did it all by hand, and get:

using KernelFunctions
using Enzyme
using Plots

k = Matern52Kernel()
mk(d) = k(1.0, 1.0 + d)

kappa(d) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)

r = range(0, 0.2, length=30)

begin
    dist1(x, y) = sqrt((x - y)^2)
    dist2(x, y) = abs(x - y)
    ck1(x, y) = kappa(dist1(x, y))
    ck2(x, y) = kappa(dist2(x, y))

    p = plot(layout=(2, 1))
    plot!(p[1], r, ck1.(1.0, 1.0 .+ r), label="dist1")
    plot!(p[1], r, ck2.(1.0, 1.0 .+ r), label="dist2")
    plot!(p[1], r, mk.(r), label="ref")

    dmkd1(x, y) = only(autodiff(
        Forward,
        yt -> only(autodiff_deferred(Forward, ck1, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
        DuplicatedNoNeed,
        Duplicated(y, 1.0)))
    dmkd2(x, y) = only(autodiff(
        Forward,
        yt -> only(autodiff_deferred(Forward, ck2, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
        DuplicatedNoNeed,
        Duplicated(y, 1.0)))

    plot!(p[2], r, dmkd1.(1.0, 1.0 .+ r), label="dist1")
    plot!(p[2], r, dmkd2.(1.0, 1.0 .+ r), label="dist2")
end

image

Differentiating the distance function on its own works correctly, so there seems to be some odd issue with the chain rule.

image

Crown421 avatar Aug 23 '23 17:08 Crown421

If I understand your code correctly, you reimplemented kappa (the first plot checks that it matches the reference implementation) and then took the derivative of this kappa.

But your kappa does not have an if iszero(d) like the general matern implementation does cf.

https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/basekernels/matern.jl#L43-L45

As autodiff simply takes the derivative of whichever branch it finds itself in, it will differentiate the `iszero` case when d is zero - and the derivative of a constant is zero.
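A minimal, self-contained illustration of this effect (a sketch unrelated to KernelFunctions; `g` is a made-up function with the same kind of constant special case):

```julia
using ForwardDiff

# The same smooth function, with and without a constant branch at zero.
f(d) = exp(-d^2)
g(d) = iszero(d) ? one(d) : exp(-d^2)

# nested forward-mode second derivative
d2(fn, x) = ForwardDiff.derivative(t -> ForwardDiff.derivative(fn, t), x)

d2(f, 0.0)  # -2.0: the true second derivative
d2(g, 0.0)  # 0.0: AD differentiated the constant branch
```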

FelixBenning avatar Aug 24 '23 17:08 FelixBenning

As autodiff simply takes the derivative of the branch it finds itself in

That's not true in general. ForwardDiff#master is supposed to ignore branches of measure zero.

devmotion avatar Aug 24 '23 17:08 devmotion

That's not true in general. ForwardDiff#master is supposed to ignore branches of measure zero.

But it doesn't at the moment, and not in the foreseeable future, as I understood from the reactions when I asked around about it. And the point of my comment was to explain what the issue is at the moment (and this is probably it).

FelixBenning avatar Aug 24 '23 17:08 FelixBenning

But your kappa does not have an if iszero(d) like the general matern implementation does

From my understanding and experiments, the general Matern implementation doesn't matter for Matern52: it is a specialized implementation, and _matern never enters into it.

Further down in the file you linked, you can see the code for Matern52.

Crown421 avatar Aug 25 '23 09:08 Crown421

I didn't know that Matern52 also breaks - I thought only the general version was a problem. I guess in this case we have two problems: the branch in the general case AND the distance function.

FelixBenning avatar Aug 25 '23 12:08 FelixBenning

Continuing to look into it, it seems the issue is with sqrt. The Euclidean distance is (effectively) defined as sqrt((x - y)^2), which then causes issues at x == y. See also this long discussion in Enzyme.

This also explains why ForwardDiff.jl failed. KernelFunctions defines a custom rule for this case (via ChainRules.jl), but ForwardDiff does not use ChainRules as far as I can tell.
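The effect can be isolated to one line with ForwardDiff (a sketch; the NaN arises from sqrt'(0) = Inf being multiplied by the zero inner derivative of (x - y)²):

```julia
using ForwardDiff

dist1(x, y) = sqrt((x - y)^2)  # how the Euclidean distance is effectively computed
dist2(x, y) = abs(x - y)

ForwardDiff.derivative(t -> dist1(t, 0.0), 0.0)  # NaN (Inf * 0 in the chain rule)
ForwardDiff.derivative(t -> dist2(t, 0.0), 0.0)  # finite (DiffRules' convention for abs at 0)
```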

Crown421 avatar Sep 06 '23 15:09 Crown421

Apologies, as I haven't read further up and got linked here from Slack. Regarding the linked Enzyme issue: the derivative of sqrt is undefined at 0. We chose to have it be 0 there instead of NaN, which the linked research says is good for a variety of reasons, though it adds an extra instruction.

Cc @martinjm97

wsmoses avatar Sep 06 '23 15:09 wsmoses

No, ForwardDiff uses its own definitions (functions with `Dual` arguments) and DiffRules. The sqrt example reminds me of https://github.com/JuliaDiff/DiffRules.jl/pull/100, which reverted a ForwardDiff bug introduced by defining the derivative of abs(x) as sign(x) instead of signbit(x) ? -one(x) : one(x). The only difference is how the AD "derivative" at 0 is defined (obviously the function is not differentiable at 0) - but setting it to 0 (which is a valid subderivative) breaks the hessian example in the linked PR (even on ForwardDiff#master). Hence I generally think it's better to define "derivatives" at non-differentiable points by evaluating the left- or right-hand side derivative rather than setting them to some constant (!) subderivative (as argued above in this issue as well, IIRC).

devmotion avatar Sep 06 '23 15:09 devmotion

@devmotion Do you think adding something like this special rule would be the solution in this context?

Crown421 avatar Sep 06 '23 15:09 Crown421

@devmotion fwiw, that's how Enzyme defines its abs derivative (see below). sqrt, however, cannot have this.

julia> using Enzyme
julia> Enzyme.API.printall!(true)

julia> Enzyme.autodiff(Forward, abs, Duplicated(0.0, 1.0))
after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_abs_615_inner.1(double %0) local_unnamed_addr #3 !dbg !10 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #4
  %2 = call double @llvm.fabs.f64(double %0) #4, !dbg !11
  ret double %2, !dbg !13
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal double @fwddiffejulia_abs_615_inner.1(double %0, double %"'") local_unnamed_addr #3 !dbg !14 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #4
  %2 = fcmp fast olt double %0, 0.000000e+00, !dbg !15
  %3 = select fast i1 %2, double -1.000000e+00, double 1.000000e+00, !dbg !15
  %4 = fmul fast double %"'", %3, !dbg !15
  ret double %4
}

(1.0,)

wsmoses avatar Sep 06 '23 15:09 wsmoses

@wsmoses A bit of a side issue, but would it make sense to have that API.printall! function in the Enzyme documentation? I was initially looking for something like this and could not find it.

Crown421 avatar Sep 06 '23 15:09 Crown421

Oh yeah go for it, contributions welcome!

wsmoses avatar Sep 06 '23 15:09 wsmoses