KernelFunctions.jl
SimpleKernel performance with Zygote
It looks like we've got a performance bug in the above. Specifically, the primal looks to be fine:
julia> @benchmark kernelmatrix(SEKernel(), $(randn(100)))
BenchmarkTools.Trial:
memory estimate: 158.25 KiB
allocs estimate: 8
--------------
minimum time: 92.684 μs (0.00% GC)
median time: 155.135 μs (0.00% GC)
mean time: 181.207 μs (12.32% GC)
maximum time: 15.995 ms (98.92% GC)
--------------
samples: 10000
evals/sample: 1
while Zygote's forwards-pass is more than 200x slower:
julia> @benchmark Zygote.pullback(kernelmatrix, SEKernel(), $(randn(100)))
BenchmarkTools.Trial:
memory estimate: 7.25 MiB
allocs estimate: 290068
--------------
minimum time: 34.106 ms (0.00% GC)
median time: 35.180 ms (0.00% GC)
mean time: 36.456 ms (2.37% GC)
maximum time: 60.521 ms (25.00% GC)
--------------
samples: 138
evals/sample: 1
Given that the bar for a reasonably performant forwards-pass is 1-2x the primal, this is definitely sub-optimal.
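For reference, the slowdown can be worked out directly from the two reports above. This is only arithmetic on the numbers already shown; the variable names are made up for illustration.

```julia
# Timings copied from the two benchmark reports above, converted to seconds.
primal_min, primal_median = 92.684e-6, 155.135e-6
pullback_min, pullback_median = 34.106e-3, 35.180e-3

# Slowdown of Zygote.pullback relative to the primal.
min_ratio = pullback_min / primal_min            # roughly 368
median_ratio = pullback_median / primal_median   # roughly 227

# Allocation counts from the same reports (290068 vs 8).
alloc_ratio = 290068 / 8

println((min_ratio, median_ratio, alloc_ratio))
```

Either way you slice it, this is two orders of magnitude above the 1-2x target, and the allocation count suggests something is allocating per element.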
I'm reasonably sure that the culprit is the use of Base.Fix1 here: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/13985cc6cb8903004f33e6fe9d2540571594a3c8/src/matrix/kernelmatrix.jl#L95
`map`ing with interesting callable types is generally a bad idea in conjunction with Zygote, unfortunately.
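For anyone following along, here is a minimal sketch of what `Base.Fix1` is: a callable struct that stores a function and its first argument, so it computes the same values as the equivalent closure. `scale` is a hypothetical stand-in for `kappa`, not anything from KernelFunctions.

```julia
# Hypothetical two-argument function standing in for `kappa(κ, d)`.
scale(a, d) = a * d

fixed = Base.Fix1(scale, 2.0)   # callable struct with fields `f` and `x`
closure = d -> scale(2.0, d)    # anonymous function capturing the same value

ds = [0.5, 1.0, 1.5]

# The struct exposes what it captured...
@assert fixed.f === scale && fixed.x == 2.0

# ...and both callables give identical results under `map`.
@assert map(fixed, ds) == map(closure, ds) == [1.0, 2.0, 3.0]
```

So any difference in performance comes from how the AD tool handles the two callable types, not from the values they compute.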
@theogf @devmotion I'm not sure what our best options are here. I think there are two questions here:
- could we tweak the current implementation a bit into something that Zygote likes, and get sane performance?
- what escape hatches do we have / could we have to make it straightforward to hand-improve performance where necessary?
For example, the use of Base.Fix1 here might actually be a real win, because we could hand-optimise the implementation of map(::Base.Fix1{typeof(kappa), SEKernel}, ::Array{<:Real}) quite straightforwardly. Unfortunately, this would essentially mean that any kernel that we care about having good performance with would require us to hand-implement stuff. If we did this using ForwardDiff, it might be fine though 🤷
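A rough sketch of what the ForwardDiff-based version of such a hand-written rule might compute is below. It is not wired into Zygote, and `kappa_se` and `map_with_pullback` are hypothetical names; `kappa_se` assumes the squared-exponential form `exp(-d / 2)` applied to a squared distance.

```julia
using ForwardDiff

# Hypothetical stand-in for `kappa(::SEKernel, d)`; assumed form exp(-d / 2),
# where `d` is a squared distance.
kappa_se(d::Real) = exp(-d / 2)

# Hypothetical helper: applies `f` elementwise and returns a pullback that uses
# ForwardDiff for the scalar derivatives instead of tracing `f` in reverse mode.
function map_with_pullback(f, D::AbstractArray{<:Real})
    Y = map(f, D)
    dY = map(d -> ForwardDiff.derivative(f, d), D)
    # Reverse-mode rule for an elementwise map: scale the cotangent by f'(d).
    pullback(Δ) = Δ .* dY
    return Y, pullback
end

D = abs2.(randn(4, 4))              # stand-in for a matrix of squared distances
K, back = map_with_pullback(kappa_se, D)
dD = back(ones(size(D)))            # cotangent w.r.t. D for an all-ones seed
```

The appeal is that the helper is generic in `f`, so it would not need to be rewritten per kernel; whether it can be attached cleanly to `map(::Base.Fix1, ...)` is the open question.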
n.b. it's definitely not the pairwise computations:
julia> @benchmark Zygote.pullback(KernelFunctions.pairwise, KernelFunctions.SqEuclidean(), randn(100))
BenchmarkTools.Trial:
memory estimate: 82.25 KiB
allocs estimate: 45
--------------
minimum time: 28.473 μs (0.00% GC)
median time: 59.303 μs (0.00% GC)
mean time: 74.604 μs (15.96% GC)
maximum time: 15.654 ms (99.13% GC)
--------------
samples: 10000
evals/sample: 1
they've been hand-optimised for ages, so it would have been really surprising if there were a problem there.
I can reproduce the benchmarks. And replacing Fix1 with a regular closure seems to fix the problem:
julia> function kernelmatrix2(κ::KernelFunctions.SimpleKernel, x::AbstractVector)
           return map(d -> KernelFunctions.kappa(κ, d), KernelFunctions.pairwise(KernelFunctions.metric(κ), x))
       end
kernelmatrix2 (generic function with 1 method)
julia> @benchmark kernelmatrix2($(SEKernel()), $(randn(100)))
BenchmarkTools.Trial:
memory estimate: 158.25 KiB
allocs estimate: 8
--------------
minimum time: 80.734 μs (0.00% GC)
median time: 84.674 μs (0.00% GC)
mean time: 97.187 μs (6.26% GC)
maximum time: 2.333 ms (95.08% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark Zygote.pullback($kernelmatrix2, $(SEKernel()), $(randn(100)))
BenchmarkTools.Trial:
memory estimate: 1019.70 KiB
allocs estimate: 65
--------------
minimum time: 160.745 μs (0.00% GC)
median time: 196.239 μs (0.00% GC)
mean time: 235.936 μs (15.31% GC)
maximum time: 2.178 ms (88.50% GC)
--------------
samples: 10000
evals/sample: 1
It would be interesting to figure out why exactly Fix1 leads to much worse performance with Zygote, since then maybe one could fix the underlying problem here. E.g., both implementations seem to lead to the same type instabilities:
julia> @code_warntype Zygote.pullback(kernelmatrix, SEKernel(), randn(100))
Variables
#self#::Core.Const(ZygoteRules.pullback)
f::Core.Const(KernelFunctions.kernelmatrix)
args::Core.PartialStruct(Tuple{SqExponentialKernel, Vector{Float64}}, Any[Core.Const(Squared Exponential Kernel), Vector{Float64}])
#41::Zygote.var"#41#42"{_A} where _A
@_5::Int64
back::typeof(∂(kernelmatrix))
y::Any
Body::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
1 ─ %1 = Core.tuple(f)::Core.Const((KernelFunctions.kernelmatrix,))
│ %2 = Core._apply_iterate(Base.iterate, Zygote._pullback, %1, args)::Tuple{Any, typeof(∂(kernelmatrix))}
│ %3 = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│ (y = Core.getfield(%3, 1))
│ (@_5 = Core.getfield(%3, 2))
│ %6 = Base.indexed_iterate(%2, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{typeof(∂(kernelmatrix)), Int64}, Any[typeof(∂(kernelmatrix)), Core.Const(3)])
│ (back = Core.getfield(%6, 1))
│ %8 = y::Any
│ %9 = Zygote.:(var"#41#42")::Core.Const(Zygote.var"#41#42")
│ %10 = Core.typeof(back)::Type{typeof(∂(kernelmatrix))} where _A
│ %11 = Core.apply_type(%9, %10)::Type{Zygote.var"#41#42"{_A}} where _A
│ (#41 = %new(%11, back))
│ %13 = #41::Zygote.var"#41#42"{_A} where _A
│ %14 = Core.tuple(%8, %13)::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
└── return %14
julia> @code_warntype Zygote.pullback(kernelmatrix2, SEKernel(), randn(100))
Variables
#self#::Core.Const(ZygoteRules.pullback)
f::Core.Const(kernelmatrix2)
args::Core.PartialStruct(Tuple{SqExponentialKernel, Vector{Float64}}, Any[Core.Const(Squared Exponential Kernel), Vector{Float64}])
#41::Zygote.var"#41#42"{_A} where _A
@_5::Int64
back::typeof(∂(kernelmatrix2))
y::Any
Body::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
1 ─ %1 = Core.tuple(f)::Core.Const((kernelmatrix2,))
│ %2 = Core._apply_iterate(Base.iterate, Zygote._pullback, %1, args)::Tuple{Any, typeof(∂(kernelmatrix2))}
│ %3 = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│ (y = Core.getfield(%3, 1))
│ (@_5 = Core.getfield(%3, 2))
│ %6 = Base.indexed_iterate(%2, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{typeof(∂(kernelmatrix2)), Int64}, Any[typeof(∂(kernelmatrix2)), Core.Const(3)])
│ (back = Core.getfield(%6, 1))
│ %8 = y::Any
│ %9 = Zygote.:(var"#41#42")::Core.Const(Zygote.var"#41#42")
│ %10 = Core.typeof(back)::Type{typeof(∂(kernelmatrix2))} where _A
│ %11 = Core.apply_type(%9, %10)::Type{Zygote.var"#41#42"{_A}} where _A
│ (#41 = %new(%11, back))
│ %13 = #41::Zygote.var"#41#42"{_A} where _A
│ %14 = Core.tuple(%8, %13)::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
└── return %14
I'm also finding that your fix fixes the problem, and agree that it would be good to get to the bottom of this.
In the meantime, I'll open a PR with the fix.
@DhairyaLGandhi any thoughts on the general topic of Base.Fix1 vs a regular closure? I don't have a strong intuition about what's going on here.