
Resolve type stability of evaluation of `KernelSum` (WIP)

Open · simsurace opened this pull request 2 years ago · 2 comments

Summary This is an attempt to patch #458. Since this is probably some deeper Julia type inference issue, this solution is likely to be temporary.

Proposed changes For now, I just added tests that expose the type instability. For those kernels, and only for those, the issue can be fixed by including an evaluation of sum in the module initialization function.

What alternatives have you considered? Re-writing https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/ce7923f8252d183768f247256079a2c8279f7b1b/src/kernels/kernelsum.jl#L46 either by using do or let blocks, or by passing an init argument to sum (sketched below). None of those made the tests pass, although they do work if the function is redefined in the REPL after the module has been loaded.
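A rough, self-contained sketch of the variants that were tried, using a hypothetical toy type rather than the actual KernelSum code (the real definition is at the link above):

struct ToySum{Tk}
    kernels::Tk
end

# generator-based evaluation, roughly what the current implementation does
eval_gen(κ::ToySum, x, y) = sum(k(x, y) for k in κ.kernels)

# do-block variant
eval_do(κ::ToySum, x, y) = sum(κ.kernels) do k
    k(x, y)
end

# let-block variant, capturing x and y explicitly
eval_let(κ::ToySum, x, y) = let x = x, y = y
    sum(k -> k(x, y), κ.kernels)
end

# variant passing an explicit init argument
eval_init(κ::ToySum, x, y) = sum(k -> k(x, y), κ.kernels; init=0.0)

κ = ToySum(((x, y) -> exp(-abs2(x - y)), (x, y) -> exp(-abs(x - y))))
eval_gen(κ, 0.0, 1.0)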

Breaking changes None.

simsurace avatar Jun 23 '22 18:06 simsurace

Codecov Report

Base: 68.82% // Head: 68.87% // Increases project coverage by +0.04% :tada:

Coverage data is based on head (db9cfd9) compared to base (4ce3e87). Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #459      +/-   ##
==========================================
+ Coverage   68.82%   68.87%   +0.04%     
==========================================
  Files          52       52              
  Lines        1344     1346       +2     
==========================================
+ Hits          925      927       +2     
  Misses        419      419              
Impacted Files              Coverage Δ
src/kernels/kernelsum.jl    100.00% <100.00%> (ø)


codecov[bot] avatar Jun 23 '22 18:06 codecov[bot]

There seem to be some unrelated failures.

simsurace avatar Jun 23 '22 18:06 simsurace

This is weird. I'm finding that type-stability issues are resolved with the following implementations:

function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix, x), κ.kernels))
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix(k, x, y), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix_diag, x), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix_diag(k, x, y), κ.kernels))
end

For some reason the implementation that uses the generator doesn't infer properly, but summing over the output of map seems to be fine. This is annoying, because the generator version really ought to infer just as well.

mapreduce also doesn't infer

I guess this is related to map being heavily optimised for Tuples, but perhaps mapreduce and whatever generator is produced in the current implementation aren't?
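A toy way to see the difference (my own example, not package code; the exact inferred types depend on the Julia version) is to compare what Base.return_types reports for each form:

g_map(fs, t) = sum(map(f -> f(t), fs))      # map over the Tuple, then sum
g_gen(fs, t) = sum(f(t) for f in fs)        # generator-based sum
g_mr(fs, t) = mapreduce(f -> f(t), +, fs)   # mapreduce

fs = (x -> x^2, x -> x^3)  # heterogeneous callable types, like a tuple of kernels
Base.return_types(g_map, (typeof(fs), Float64))
Base.return_types(g_gen, (typeof(fs), Float64))
Base.return_types(g_mr, (typeof(fs), Float64))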

willtebbutt avatar Aug 30 '22 13:08 willtebbutt

Hmm, that'd then be similar to the hack I used in https://github.com/JuliaGaussianProcesses/GPLikelihoods.jl/pull/90, right?

simsurace avatar Aug 30 '22 14:08 simsurace

Indeed!

willtebbutt avatar Aug 30 '22 14:08 willtebbutt

I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.

simsurace avatar Aug 30 '22 14:08 simsurace

I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.

Agreed -- I mean, map ought to be completely fine for AD. It would only increase allocations outside the context of reverse-mode AD, where the total amount of memory allocated is less of an issue anyway. Just seems weird that there isn't a decent type-stable version of mapreduce for tuples anywhere in the ecosystem.

willtebbutt avatar Aug 30 '22 14:08 willtebbutt

Anyway, this seems to work on nightly, presumably due to https://github.com/JuliaLang/julia/pull/45789, but unfortunately that hasn't been backported to 1.8.

simsurace avatar Aug 30 '22 15:08 simsurace

Ahh I see. Well I'm happy to wait for 1.9 if it will just fix this -- @simsurace is this causing you unacceptable performance issues, or is it something that you could live with for a few months?

willtebbutt avatar Aug 30 '22 15:08 willtebbutt

Hmm, I just did another minimal example on nightly and it didn't pass the test. I guess things might be more complicated after all. I don't really know how big of a performance issue this is. But hard-coding my kernel as a custom type, where the sum is also hard-coded, gave a moderate improvement and made the loss function fully type-stable.
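For context, the hard-coded kernel I mean is roughly of this shape (a hypothetical sketch, not my actual code; it assumes the exported Kernel abstract type and the exported base kernels):

using KernelFunctions

struct HardCodedSum{K1<:Kernel,K2<:Kernel} <: Kernel
    k1::K1
    k2::K2
end

# the sum over components is written out explicitly, so evaluation is type-stable
(k::HardCodedSum)(x, y) = k.k1(x, y) + k.k2(x, y)

k = HardCodedSum(RBFKernel(), ExponentialKernel())
k(0.0, 1.0)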

simsurace avatar Aug 30 '22 18:08 simsurace

Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing.

willtebbutt avatar Aug 30 '22 18:08 willtebbutt

The following fails on a nightly build on GitHub:

using Test

struct FunctionSum{Tf}
    functions::Tf
end

(F::FunctionSum)(x) = sum(f -> f(x), F.functions)

F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))
@inferred F(0.1)

To the extent that this is similar (JET.jl shows type inference problems that look about the same as in the kernel examples above), this is still a Julia issue. Could someone here verify this error?
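For reference, the JET.jl check I have in mind looks roughly like this, run against the F defined above (@report_opt reports runtime-dispatch problems; the exact output depends on the Julia and JET versions):

using JET
@report_opt F(0.1)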

simsurace avatar Aug 30 '22 18:08 simsurace

Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing.

On a related note, I was thinking about a totally different implementation that would optimize away e.g. multiple uses of the same distance matrix and similar things, using some kind of computational graph. Or is this supposed to be something that the Julia compiler can do on its own? I don't know whether it would help type stability though as long as those kinds of lazy constructions have type inference issues.

simsurace avatar Aug 30 '22 18:08 simsurace

Aha! This seems to work nicely:

# recursive case: peel off the first element and recurse on the tail
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))

# base case: a single-element tuple (more specific, so it takes precedence)
_sum(f::Tf, x::Tuple{Tx}) where {Tf, Tx} = f(x[1])

function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
end

It's just a recursive implementation of sum(f, x) for the case in which x is a Tuple.
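As a quick sanity check on a toy tuple (my own example, not part of the diff), the recursion unrolls over the tuple elements, so this should infer:

using Test
fs = (x -> x^2, x -> x^3)
_sum(f -> f(0.1), fs)            # 0.011
@inferred _sum(f -> f(0.1), fs)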

willtebbutt avatar Aug 30 '22 18:08 willtebbutt

Hmm, how's that going to be with AD?

simsurace avatar Aug 30 '22 18:08 simsurace

Seems to be fine (according to the unit tests). Generally speaking small tuples + recursion are going to be fine with Zygote I believe.

willtebbutt avatar Aug 30 '22 18:08 willtebbutt

I just did some benchmarks, looks fine to me.

using BenchmarkTools
using KernelFunctions
using Test
using Zygote

k = RBFKernel() + RBFKernel() * ExponentialKernel()

@inferred k(0., 1.)

x = rand(100)
@btime kernelmatrix($k, $x)
# before: 207.625 μs (33 allocations: 631.05 KiB)
# after:  212.833 μs (28 allocations: 630.91 KiB)
@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 331.417 μs (402 allocations: 2.54 MiB)
# after:  335.875 μs (235 allocations: 2.54 MiB)
out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 270.625 μs (545 allocations: 1.09 MiB)
# after:  199.208 μs (325 allocations: 1.09 MiB)




k = sum(rand() * RBFKernel() ∘ ScaleTransform(rand()) for _ in 1:20)

@inferred k(0., 1.)

x = rand(100)
@btime kernelmatrix($k, $x)
# before: 1.670 ms (258 allocations: 6.08 MiB)
# after:  1.646 ms (258 allocations: 6.08 MiB)
@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 2.979 ms (3144 allocations: 18.44 MiB)
# after:  2.861 ms (2901 allocations: 18.41 MiB)
out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 1.510 ms (3519 allocations: 6.38 MiB)
# after:  1.483 ms (3223 allocations: 6.35 MiB)

simsurace avatar Aug 30 '22 20:08 simsurace

EDIT: hold on, I now can't reproduce it. I still get type inference errors. I think the check_type_stability function is being compiled away.

simsurace avatar Aug 30 '22 20:08 simsurace

Ok, seems to be fine after all. There are now 15 additional tests being counted. Closes #458.

simsurace avatar Aug 31 '22 08:08 simsurace

This PR looks like it's basically ready to go, other than syncing it up with master. @simsurace is there anything else that you think needs doing?

willtebbutt avatar Sep 23 '22 12:09 willtebbutt

I don't think so.

simsurace avatar Sep 23 '22 14:09 simsurace

Cool. Do you have a few minutes to merge in the master branch to this branch so that we can merge this back into master, or are you happy for me to handle it from here?

willtebbutt avatar Sep 23 '22 14:09 willtebbutt

@simsurace I have merged in changes from master and re-bumped the patch version. Will merge when CI passes.

willtebbutt avatar Sep 26 '22 08:09 willtebbutt

Eurgh. Somehow the build has been broken. Going to need to figure out how it broke before merging.

willtebbutt avatar Sep 26 '22 09:09 willtebbutt

Okay, I'm pretty sure that the failures aren't related to this PR, so I'm going to merge when CI (other than BaseKernels) passes.

willtebbutt avatar Sep 26 '22 12:09 willtebbutt