KernelFunctions.jl
Resolve type stability of evaluation of `KernelSum` (WIP)
Summary

This is an attempt to patch #458. Since this is probably some deeper Julia type-inference issue, this solution is likely to be temporary.
Proposed changes
For now, I just added tests that expose the type instability. For those kernels, and only for those, the issue can be fixed by including an evaluation of `sum` in the module initialization function.
What alternatives have you considered?
Re-writing https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/ce7923f8252d183768f247256079a2c8279f7b1b/src/kernels/kernelsum.jl#L46 either by using `do` or `let` blocks, or passing an `init` argument to `sum`. None of those made the tests pass, but they work if the function is redefined in the REPL after the module has been loaded.
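For illustration, the alternatives described above might look roughly like the following on a toy stand-in type. This is a hedged sketch: `FunctionSum` and the `eval_*` names are hypothetical, not the actual code in kernelsum.jl.

```julia
# Toy stand-in for KernelSum; all names here are illustrative only.
struct FunctionSum{Tf}
    functions::Tf
end

# Alternative 1: `sum` with a `do` block.
eval_do(F::FunctionSum, x) = sum(F.functions) do f
    f(x)
end

# Alternative 2: a `let` block that explicitly captures `x` before
# building the closure passed to `sum`.
eval_let(F::FunctionSum, x) = let x = x
    sum(f -> f(x), F.functions)
end

# Alternative 3: passing an `init` argument to `sum`.
eval_init(F::FunctionSum, x) = sum(f -> f(x), F.functions; init = zero(x))
```

All three compute the same result; per the description above, none of them fixed inference when defined inside the module, though they did work when redefined in the REPL.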
Breaking changes

None.
Codecov Report
Base: 68.82% // Head: 68.87% // Increases project coverage by +0.04% :tada:
Coverage data is based on head (db9cfd9) compared to base (4ce3e87). Patch coverage: 100.00% of modified lines in the pull request are covered.
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master     #459      +/-   ##
==========================================
+ Coverage   68.82%   68.87%   +0.04%
==========================================
  Files          52       52
  Lines        1344     1346       +2
==========================================
+ Hits          925      927       +2
  Misses        419      419
```
| Impacted Files | Coverage Δ |
|---|---|
| src/kernels/kernelsum.jl | 100.00% <100.00%> (ø) |
There seem to be some unrelated failures.
This is weird. I'm finding that type-stability issues are resolved with the following implementations:
```julia
function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix, x), κ.kernels))
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix(k, x, y), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix_diag, x), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix_diag(k, x, y), κ.kernels))
end
```
For some reason the implementation that uses the generator doesn't infer properly, but `sum`ming over the output of `map` seems to be fine. This is annoying. `mapreduce` also doesn't infer. I guess this is related to `map` being heavily optimised for `Tuple`s, but perhaps `mapreduce` and whatever generator is produced in the current implementation aren't?
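As a hedged illustration of that difference (behaviour can vary across Julia versions): `map` over a `Tuple` is unrolled by dispatch, so the element types stay concrete, whereas the generator form goes through `Base.Generator`:

```julia
using Test

fs = (x -> x + 1.0, x -> 2x)   # heterogeneous tuple of closures

# `map` over a tuple is unrolled, so the result is a concrete tuple of
# Float64s and `sum` over it infers:
@inferred sum(map(f -> f(0.5), fs))

# The generator form, `sum(f -> f(0.5), fs)`, lowers to a reduction over
# `Base.Generator`, which is the shape that failed to infer in the kernel
# case above. Left commented out, since whether `@inferred` throws here
# is version-dependent:
# @inferred sum(f -> f(0.5), fs)
```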
Hmm, that'd then be similar to the hack I used in https://github.com/JuliaGaussianProcesses/GPLikelihoods.jl/pull/90, right?
Indeed!
I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.
> I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.
Agreed -- I mean, `map` ought to be completely fine for AD. It would only increase the allocations outside of the context of reverse-mode AD, when the total amount of memory allocated is less of an issue anyway. Just seems weird that there isn't a decent type-stable version of `mapreduce` for tuples anywhere in the ecosystem.
Anyway, this seems to work on nightly, presumably due to https://github.com/JuliaLang/julia/pull/45789, but unfortunately that hasn't been backported to 1.8.
Ahh I see. Well I'm happy to wait for 1.9 if it will just fix this -- @simsurace is this causing you unacceptable performance issues, or is it something that you could live with for a few months?
Hmm I just did another minimal example on nightly and it didn't pass the test. I guess things might be more complicated after all. I don't really know how big of a performance issue this is. But hard-coding my kernel in terms of a custom type, where the sum is also hard-coded gave a moderate improvement and made the loss function fully type-stable.
Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing..
The following fails on a nightly build on GitHub:
```julia
using Test

struct FunctionSum{Tf}
    functions::Tf
end

(F::FunctionSum)(x) = sum(f -> f(x), F.functions)

F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))
@inferred F(0.1)
```
To the extent that this is similar (JET.jl shows type inference problems that look about the same as in the kernel examples above), this is still a Julia issue. Could someone here verify this error?
> Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing..
On a related note, I was thinking about a totally different implementation that would optimize away e.g. multiple uses of the same distance matrix and similar things, using some kind of computational graph. Or is this supposed to be something that the Julia compiler can do on its own? I don't know whether it would help type stability though as long as those kinds of lazy constructions have type inference issues.
Aha! This seems to work nicely:
```julia
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))
_sum(f::Tf, x::Tuple{Tx}) where {Tf,Tx} = f(x[1])

function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
end
```
It's just a recursive implementation of `sum(f, x)` for the case in which `x` is a `Tuple`.
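Applied to the failing `FunctionSum` example from earlier in the thread, the recursive helper should make `@inferred` pass. This is a self-contained sketch using the `_sum` definition above:

```julia
using Test

# Recursive `sum(f, x)` for tuples, as defined above: peel off the first
# element and recurse on the tail, with a one-element base case.
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))
_sum(f::Tf, x::Tuple{Tx}) where {Tf,Tx} = f(x[1])

struct FunctionSum{Tf}
    functions::Tf
end

# Evaluate via the recursive helper instead of `sum(f -> f(x), ...)`.
(F::FunctionSum)(x) = _sum(f -> f(x), F.functions)

# The nested case that previously failed `@inferred` should now infer,
# since each recursive call sees a concrete tuple type:
F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))
@inferred F(0.1)
```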
Hmm, how's that going to be with AD?
Seems to be fine (according to the unit tests). Generally speaking small tuples + recursion are going to be fine with Zygote I believe.
I just did some benchmarks, looks fine to me.
```julia
using BenchmarkTools
using KernelFunctions
using Test
using Zygote

k = RBFKernel() + RBFKernel() * ExponentialKernel()
@inferred k(0., 1.)

x = rand(100)
@btime kernelmatrix($k, $x)
# before: 207.625 μs (33 allocations: 631.05 KiB)
# after: 212.833 μs (28 allocations: 630.91 KiB)

@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 331.417 μs (402 allocations: 2.54 MiB)
# after: 335.875 μs (235 allocations: 2.54 MiB)

out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 270.625 μs (545 allocations: 1.09 MiB)
# after: 199.208 μs (325 allocations: 1.09 MiB)

k = sum(rand() * RBFKernel() ∘ ScaleTransform(rand()) for _ in 1:20)
@inferred k(0., 1.)

x = rand(100)
@btime kernelmatrix($k, $x)
# before: 1.670 ms (258 allocations: 6.08 MiB)
# after: 1.646 ms (258 allocations: 6.08 MiB)

@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 2.979 ms (3144 allocations: 18.44 MiB)
# after: 2.861 ms (2901 allocations: 18.41 MiB)

out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 1.510 ms (3519 allocations: 6.38 MiB)
# after: 1.483 ms (3223 allocations: 6.35 MiB)
```
EDIT: hold on, I now can't reproduce it. I still get type inference errors. I think the `check_type_stability` function is being compiled away.
Ok, seems to be fine after all. There are now 15 additional tests being counted. Closes #458.
This PR looks like it's basically ready to go, other than syncing it up with master. @simsurace is there anything else that you think needs doing?
I don't think so.
Cool. Do you have a few minutes to merge in the master branch to this branch so that we can merge this back into master, or are you happy for me to handle it from here?
@simsurace have merged in changes from master and re-bumped the patch. Will merge when CI passes.
Eurgh. Somehow the build has been broken. Going to need to figure out why before merging.
Okay, I'm pretty sure that the failures aren't related to this PR, so I'm going to merge when CI (other than BaseKernels) passes