
SimpleKernel performance with Zygote

willtebbutt opened this issue • 3 comments

It looks like we've got a performance bug in `kernelmatrix` for `SimpleKernel`s when differentiated with Zygote. Specifically, the primal looks to be fine:

julia> @benchmark kernelmatrix(SEKernel(), $(randn(100)))
BenchmarkTools.Trial:
  memory estimate:  158.25 KiB
  allocs estimate:  8
  --------------
  minimum time:     92.684 μs (0.00% GC)
  median time:      155.135 μs (0.00% GC)
  mean time:        181.207 μs (12.32% GC)
  maximum time:     15.995 ms (98.92% GC)
  --------------
  samples:          10000
  evals/sample:     1

while Zygote's forward pass is about 100x worse:

julia> @benchmark Zygote.pullback(kernelmatrix, SEKernel(), $(randn(100)))
BenchmarkTools.Trial:
  memory estimate:  7.25 MiB
  allocs estimate:  290068
  --------------
  minimum time:     34.106 ms (0.00% GC)
  median time:      35.180 ms (0.00% GC)
  mean time:        36.456 ms (2.37% GC)
  maximum time:     60.521 ms (25.00% GC)
  --------------
  samples:          138
  evals/sample:     1

Given that the bar for a reasonably performant forward pass is 1-2x the primal, this is definitely sub-optimal.

I'm reasonably sure that the culprit is the use of Base.Fix1 here: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/13985cc6cb8903004f33e6fe9d2540571594a3c8/src/matrix/kernelmatrix.jl#L95

Unfortunately, mapping non-trivial callable types over arrays is generally a bad idea in conjunction with Zygote.
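For context, `Base.Fix1(f, x)` is just a callable struct that partially applies `f`'s first argument, so the two spellings below are semantically identical (this uses a toy stand-in for `kappa`, not the real KernelFunctions code):

```julia
# Base.Fix1(f, x) behaves like d -> f(x, d). With a toy two-argument function
# standing in for KernelFunctions.kappa:
kappa_toy(k, d) = exp(-d / 2)

via_fix1    = map(Base.Fix1(kappa_toy, :k), [0.0, 1.0, 2.0])  # current style
via_closure = map(d -> kappa_toy(:k, d), [0.0, 1.0, 2.0])     # explicit closure

via_fix1 == via_closure  # identical results; only Zygote's AD performance differs
```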

@theogf @devmotion I'm not sure what our best options are. I think there are two questions here:

  1. could we tweak the current implementation a bit into something that Zygote likes, and get sane performance?
  2. what escape hatches do we have / could we have to make it straightforward to hand-improve performance where necessary?

For example, the use of Base.Fix1 here might actually be a real win, because we could quite straightforwardly hand-optimise map(::Base.Fix1{typeof(kappa), SEKernel}, ::Array{<:Real}). Unfortunately, that would essentially mean hand-implementing a rule for every kernel whose performance we care about. If we did it generically using ForwardDiff, it might be fine though 🤷
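To make the second question concrete, a ForwardDiff-backed escape hatch could look roughly like the following. This is an untested sketch, not KernelFunctions' actual implementation; it assumes ChainRulesCore and ForwardDiff are available and that the Fix1-based map call is kept in place:

```julia
using ChainRulesCore, ForwardDiff
using KernelFunctions: kappa, SEKernel

# Hypothetical hand-written rule intercepting the map over the distance matrix.
# ForwardDiff gives ∂kappa/∂d elementwise on the forward pass, so the pullback
# reduces to an elementwise product, bypassing Zygote's closure machinery.
function ChainRulesCore.rrule(
    ::typeof(map), f::Base.Fix1{typeof(kappa),<:SEKernel}, D::AbstractMatrix{<:Real}
)
    y = map(f, D)
    dydD = map(d -> ForwardDiff.derivative(f, d), D)
    function map_kappa_pullback(Δ)
        # SEKernel carries no trainable parameters, so f gets NoTangent().
        return NoTangent(), NoTangent(), Δ .* dydD
    end
    return y, map_kappa_pullback
end
```

The same pattern would have to be repeated (or generated) per kernel type, which is exactly the maintenance cost flagged above.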

n.b. it's definitely not the pairwise computations:

julia> @benchmark Zygote.pullback(KernelFunctions.pairwise, KernelFunctions.SqEuclidean(), randn(100))
BenchmarkTools.Trial:
  memory estimate:  82.25 KiB
  allocs estimate:  45
  --------------
  minimum time:     28.473 μs (0.00% GC)
  median time:      59.303 μs (0.00% GC)
  mean time:        74.604 μs (15.96% GC)
  maximum time:     15.654 ms (99.13% GC)
  --------------
  samples:          10000
  evals/sample:     1

they've been hand-optimised for ages, so it would have been really surprising if there were a problem there.

willtebbutt (Mar 28 '21 23:03)

I can reproduce the benchmarks. And replacing Fix1 with a regular closure seems to fix the problem:

julia> function kernelmatrix2(κ::KernelFunctions.SimpleKernel, x::AbstractVector)
           return map(d -> KernelFunctions.kappa(κ, d), KernelFunctions.pairwise(KernelFunctions.metric(κ), x))
       end
kernelmatrix2 (generic function with 1 method)

julia> @benchmark kernelmatrix2($(SEKernel()), $(randn(100)))
BenchmarkTools.Trial: 
  memory estimate:  158.25 KiB
  allocs estimate:  8
  --------------
  minimum time:     80.734 μs (0.00% GC)
  median time:      84.674 μs (0.00% GC)
  mean time:        97.187 μs (6.26% GC)
  maximum time:     2.333 ms (95.08% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark Zygote.pullback($kernelmatrix2, $(SEKernel()), $(randn(100)))
BenchmarkTools.Trial: 
  memory estimate:  1019.70 KiB
  allocs estimate:  65
  --------------
  minimum time:     160.745 μs (0.00% GC)
  median time:      196.239 μs (0.00% GC)
  mean time:        235.936 μs (15.31% GC)
  maximum time:     2.178 ms (88.50% GC)
  --------------
  samples:          10000
  evals/sample:     1

It would be interesting to figure out why exactly Fix1 leads to much worse performance with Zygote, since understanding that might let us fix the underlying problem. E.g., both implementations seem to lead to the same type instabilities:

julia> @code_warntype Zygote.pullback(kernelmatrix, SEKernel(), randn(100))
Variables
  #self#::Core.Const(ZygoteRules.pullback)
  f::Core.Const(KernelFunctions.kernelmatrix)
  args::Core.PartialStruct(Tuple{SqExponentialKernel, Vector{Float64}}, Any[Core.Const(Squared Exponential Kernel), Vector{Float64}])
  #41::Zygote.var"#41#42"{_A} where _A
  @_5::Int64
  back::typeof(∂(kernelmatrix))
  y::Any

Body::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
1 ─ %1  = Core.tuple(f)::Core.Const((KernelFunctions.kernelmatrix,))
│   %2  = Core._apply_iterate(Base.iterate, Zygote._pullback, %1, args)::Tuple{Any, typeof(∂(kernelmatrix))}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_5 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{typeof(∂(kernelmatrix)), Int64}, Any[typeof(∂(kernelmatrix)), Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = y::Any
│   %9  = Zygote.:(var"#41#42")::Core.Const(Zygote.var"#41#42")
│   %10 = Core.typeof(back)::Type{typeof(∂(kernelmatrix))} where _A
│   %11 = Core.apply_type(%9, %10)::Type{Zygote.var"#41#42"{_A}} where _A
│         (#41 = %new(%11, back))
│   %13 = #41::Zygote.var"#41#42"{_A} where _A
│   %14 = Core.tuple(%8, %13)::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
└──       return %14

julia> @code_warntype Zygote.pullback(kernelmatrix2, SEKernel(), randn(100))
Variables
  #self#::Core.Const(ZygoteRules.pullback)
  f::Core.Const(kernelmatrix2)
  args::Core.PartialStruct(Tuple{SqExponentialKernel, Vector{Float64}}, Any[Core.Const(Squared Exponential Kernel), Vector{Float64}])
  #41::Zygote.var"#41#42"{_A} where _A
  @_5::Int64
  back::typeof(∂(kernelmatrix2))
  y::Any

Body::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
1 ─ %1  = Core.tuple(f)::Core.Const((kernelmatrix2,))
│   %2  = Core._apply_iterate(Base.iterate, Zygote._pullback, %1, args)::Tuple{Any, typeof(∂(kernelmatrix2))}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_5 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{typeof(∂(kernelmatrix2)), Int64}, Any[typeof(∂(kernelmatrix2)), Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = y::Any
│   %9  = Zygote.:(var"#41#42")::Core.Const(Zygote.var"#41#42")
│   %10 = Core.typeof(back)::Type{typeof(∂(kernelmatrix2))} where _A
│   %11 = Core.apply_type(%9, %10)::Type{Zygote.var"#41#42"{_A}} where _A
│         (#41 = %new(%11, back))
│   %13 = #41::Zygote.var"#41#42"{_A} where _A
│   %14 = Core.tuple(%8, %13)::Tuple{Any, Zygote.var"#41#42"{_A} where _A}
└──       return %14

devmotion (Mar 29 '21 06:03)

I'm also finding that your change fixes the problem, and I agree that it would be good to get to the bottom of this.

In the meantime, I'll open a PR with the fix.

willtebbutt (Mar 29 '21 09:03)

@DhairyaLGandhi any thoughts on the general topic of Base.Fix1 vs a regular closure? I don't have a strong intuition about what's going on here.

willtebbutt (Mar 29 '21 09:03)