AbstractGPs.jl
Zygote errors with parameterized mean functions and multidimensional input
When trying to differentiate logpdf or other scalar functions with a parameterized mean function and multidimensional input, there are errors:
using AbstractGPs
using Zygote
pars = [1., 0.]
function build_model(pars)
    a, b = pars
    return GP(x -> a * first(x) + b, SEKernel())
end
rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)
function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end
test_logpdf(pars)
Zygote.gradient(test_logpdf, pars) # works
function test_logpdf2(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    return logpdf(f(x, 1e-3), y)
end
test_logpdf2(pars)
Zygote.gradient(test_logpdf2, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})
function test_mean(pars)
    f = build_model(pars)
    x, _ = rand_data_2d(10)
    return sum(mean(f(x, 1e-3)))
end
test_mean(pars)
Zygote.gradient(test_mean, pars) # ERROR: Pullback on AbstractVector{<:AbstractVector}.
function test_post_mean(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    fp = posterior(f(x, 1e-3), y)
    return sum(mean(fp(x, 1e-3)))
end
test_post_mean(pars)
Zygote.gradient(test_post_mean, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})
Is there a simple fix? The error for test_mean suggests overloading a kernelmatrix method, but that does not seem to be the issue, since we are talking about the mean here. Why does the existing rrule for RowVecs not suffice?
Hmmm, this issue has come up repeatedly recently. It is also a problem in this Stheno.jl issue, I suspect for the same reasons.
To be honest, the simplest solution is going to be to implement AbstractGPs._map_meanfunction for your custom mean function and a RowVecs / ColVecs input, so that you can be sure that it's differentiable. So, something like
function AbstractGPs._map_meanfunction(f::CustomMean{typeof(your_mean_function)}, x::RowVecs)
end
etc. @simsurace could you let me know if this solves the problem?
I can't see this issue surrounding the mean function getting fixed any time soon (because it's AD-related), so I'm wondering whether we should change our approach to documenting it, e.g. making it clear that you might well need to implement _map_meanfunction if you're using a custom mean function.
Thanks for the suggestion. Maybe I misunderstood, but this did not solve the issue:
struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b
using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return [f.f.a * first(xi) + f.f.b for xi in x]
end
function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end
Ah, sorry, I meant something more like
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # RowVecs stores one input per row, so the first components are column 1 of x.X
    return f.f.a * x.X[:, 1] .+ f.f.b
end
so that you're interacting with the underlying matrix.
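To check that a matrix-level expression really computes the same thing as the elementwise mean function, a minimal sketch in plain Julia (no AbstractGPs needed) can help. It assumes the RowVecs convention that each row of the wrapped matrix is one input point; the names X, a, and b are purely illustrative.

```julia
# Each row of X is one 2-d input point, mirroring how RowVecs(X) is indexed.
X = [0.1 0.9;
     0.4 0.6;
     0.7 0.3]
a, b = 2.0, 1.0

# Elementwise definition: apply the mean function to every row in turn.
per_point = [a * first(row) + b for row in eachrow(X)]

# Matrix-level definition: slice the first column once and broadcast over it.
vectorized = a .* X[:, 1] .+ b

@assert per_point ≈ vectorized
```

For a ColVecs input the roles are swapped, and the first components would be X[1, :] instead.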
Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?
P.S. CustomMean currently does not seem to be exported or documented, for that matter.
> I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?
That should indeed work!
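A rough sketch of what the subtype approach might look like follows. It assumes an AbstractGPs version in which the internal hook is still called _map_meanfunction, as it is in this thread; it is not public API and its name may differ in other releases. The RowVecs / ColVecs overloads and the final check are illustrative, not taken from the package.

```julia
using AbstractGPs

# Subtyping MeanFunction directly avoids the CustomMean wrapper.
struct LinearMean{T} <: AbstractGPs.MeanFunction
    a::T
    b::T
end

# Assumption: _map_meanfunction is the internal hook used to evaluate the mean
# at a vector of inputs (the name used in this thread; version-dependent).

# RowVecs: one input per row, so the first components are column 1 of x.X.
AbstractGPs._map_meanfunction(m::LinearMean, x::RowVecs) = m.a .* x.X[:, 1] .+ m.b

# ColVecs: one input per column, so the first components are row 1 of x.X.
AbstractGPs._map_meanfunction(m::LinearMean, x::ColVecs) = m.a .* x.X[1, :] .+ m.b

# Scalar inputs: broadcast over the plain vector.
AbstractGPs._map_meanfunction(m::LinearMean, x::AbstractVector{<:Real}) = m.a .* x .+ m.b

f = GP(LinearMean(1.0, 0.0), SEKernel())
x = RowVecs(rand(10, 2))
m = mean(f(x, 1e-3))  # with a = 1, b = 0 this should equal x.X[:, 1]
```

Because every overload works on the underlying matrix rather than iterating over the vector-of-vectors, the Zygote.gradient calls from the original post can then be retried against a build_model that returns GP(LinearMean(a, b), SEKernel()).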
I ended up with a general struct FunctionOfTime{Tf} <: MeanFunction with overloads that map its field over slices. This works. Thanks for the tips!
I also ran into this issue recently, and because you end up hitting the definition in KernelFunctions, debugging is somewhat confusing.
Probably worth an entry in the docs + maybe changing the error in KernelFunctions?
I agree that the docs should probably be improved here.