AbstractGPs.jl

Zygote errors with parameterized mean functions and multidimensional input

Open simsurace opened this issue 2 years ago • 10 comments

When trying to differentiate logpdf or other scalar functions with a parameterized mean function and multidimensional input, there are errors:

using AbstractGPs
using Zygote

pars = [1., 0.]

function build_model(pars)
    a, b = pars
    return GP(x -> a * first(x) + b, SEKernel())
end

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars) # works

function test_logpdf2(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf2(pars)
Zygote.gradient(test_logpdf2, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

function test_mean(pars)
    f = build_model(pars)
    x, _ = rand_data_2d(10)
    return sum(mean(f(x, 1e-3)))
end

test_mean(pars)
Zygote.gradient(test_mean, pars) # ERROR: Pullback on AbstractVector{<:AbstractVector}.

function test_post_mean(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    fp = posterior(f(x, 1e-3), y)
    return sum(mean(fp(x, 1e-3)))
end

test_post_mean(pars)
Zygote.gradient(test_post_mean, pars) 
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

Is there a simple fix? The error for test_mean gives a suggestion to overload a kernelmatrix method, but that does not seem to be the issue since we are talking about the mean here. Why does the existing rrule for RowVecs not suffice?

simsurace avatar Dec 06 '22 14:12 simsurace

Hmm, this issue has come up repeatedly recently. It is also a problem in this Stheno.jl issue, I suspect for the same reasons.

To be honest, the simplest solution is going to be to implement AbstractGPs._map_meanfunction for your custom mean function and a RowVecs / ColVecs input, so that you can be sure that it's differentiable. So, something like

function AbstractGPs._map_meanfunction(f::CustomMean{typeof(your_mean_function)}, x::RowVecs)
    # return the vector of mean values for x, computed in a way that is differentiable
end

etc. @simsurace could you let me know if this solves the problem?

I can't see this issue surrounding the mean function getting fixed any time soon (because it's AD-related), so I'm wondering whether we should change our approach to documenting it, e.g. making it clear that you might well need to implement _map_meanfunction if you're using a custom mean function.

willtebbutt avatar Dec 06 '22 14:12 willtebbutt

Thanks for the suggestion. Maybe I misunderstood, but this did not solve the issue:

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return [f.f.a * first(xi) + f.f.b for xi in x]
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

simsurace avatar Dec 06 '22 15:12 simsurace

Ah, sorry, I meant something more like

function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # RowVecs stores one input point per row, so the first coordinates are column 1
    return f.f.a * x.X[:, 1] .+ f.f.b
end

so that you're interacting with the underlying matrix.

willtebbutt avatar Dec 06 '22 15:12 willtebbutt
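To illustrate the suggestion of working with the underlying matrix, here is a minimal plain-Julia sketch that needs no packages. The matrix `X` stands in for the one wrapped by `RowVecs` (one input point per row), and the values of `a` and `b` are hypothetical. It only checks that evaluating the linear mean point by point gives the same vector as slicing the first column and broadcasting. It does not touch AbstractGPs or Zygote.

```julia
# X plays the role of the matrix wrapped by RowVecs: each row is one input point.
X = [0.1 0.9;
     0.4 0.6;
     0.7 0.3]
a, b = 2.0, 1.0  # hypothetical parameter values

# Point-by-point form: apply the mean function to every row in turn.
m_rows = [a * first(row) + b for row in eachrow(X)]

# Matrix form: take the first coordinate of every point at once (column 1) and broadcast.
m_matrix = a .* X[:, 1] .+ b

@assert m_rows ≈ m_matrix
@assert length(m_matrix) == size(X, 1)
```

For a `ColVecs` input the points are the columns instead, so the corresponding slice would be `X[1, :]`.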

Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?

simsurace avatar Dec 06 '22 15:12 simsurace

P.S. CustomMean currently does not seem to be exported or documented, for that matter.

simsurace avatar Dec 06 '22 15:12 simsurace

> Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?

That should indeed work!

willtebbutt avatar Dec 06 '22 16:12 willtebbutt

I ended up with a general struct FunctionOfTime{Tf} <: MeanFunction with overloads that map its field over slices. This works. Thanks for the tips!

simsurace avatar Dec 06 '22 16:12 simsurace

I also ran into this issue recently, and because you end up hitting the def in KernelFunctions, debugging is somewhat confusing :confused:

Probably worth an entry in the docs + maybe changing the error in KernelFunctions?

torfjelde avatar Sep 12 '23 10:09 torfjelde

I agree that the docs should probably be improved here

willtebbutt avatar Sep 12 '23 13:09 willtebbutt

We should probably add a note about when you need to implement mean_vector yourself for CustomMean here and here, and provide an example.

willtebbutt avatar Sep 12 '23 13:09 willtebbutt
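An example of the kind suggested above might look like the following sketch. It assumes an AbstractGPs version in which the hook is called `mean_vector` and is reachable as `AbstractGPs.mean_vector`, as the last comment implies (earlier in the thread the same role is played by `_map_meanfunction`). It also follows the approach discussed earlier of subtyping `MeanFunction` directly rather than wrapping in `CustomMean`. `LinearMean` is the type from the thread, and the parameter values are hypothetical.

```julia
using AbstractGPs

# Mean function m(x) = a * x[1] + b, as in the thread; values used below are hypothetical.
struct LinearMean{T} <: AbstractGPs.MeanFunction
    a::T
    b::T
end

# RowVecs: each row of x.X is one input point, so the first coordinates form column 1.
AbstractGPs.mean_vector(m::LinearMean, x::RowVecs) = m.a .* x.X[:, 1] .+ m.b

# ColVecs: each column of x.X is one input point, so the first coordinates form row 1.
AbstractGPs.mean_vector(m::LinearMean, x::ColVecs) = m.a .* x.X[1, :] .+ m.b

X = rand(10, 2)
m = LinearMean(2.0, 1.0)
f = GP(m, SEKernel())

# The specialised method agrees with evaluating the mean point by point.
@assert AbstractGPs.mean_vector(m, RowVecs(X)) ≈ [2.0 * first(xi) + 1.0 for xi in RowVecs(X)]
@assert mean(f(RowVecs(X), 1e-3)) ≈ 2.0 .* X[:, 1] .+ 1.0
```

Only `RowVecs` and `ColVecs` inputs are covered in this sketch; other input types would need their own methods. According to the reports earlier in the thread, defining the hook this way on the underlying matrix is what allowed `Zygote.gradient` to go through.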