
Remove hard-coded `Vector` fields and add vector constructors

Open theogf opened this issue 3 years ago • 29 comments

Right now, kernels like LinearKernel have two problems: they can only take Real arguments (there is no way to pass e.g. another kernel's kernel.c, which is a Vector, as an argument), and they do not allow different AbstractVector types to be stored. The use case of interest is GPUs: one cannot call kernelmatrix on CuArrays without this.

struct LinearKernel{Tc<:Real} <: SimpleKernel
    c::Vector{Tc}

    function LinearKernel(c::Real)
        @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
        return new{typeof(c)}([c])
    end
end

should be

struct LinearKernel{Tc<:Real,Vc<:AbstractVector{<:Tc}} <: SimpleKernel
    c::Vc
    function LinearKernel(c::V) where {T<:Real, V<:AbstractVector{T}}
        @check_args(LinearKernel, first(c), first(c) >= zero(T), "c ≥ 0")
        return new{T,V}(c)
    end
    function LinearKernel(c::Real)
        @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
        C = [c]
        return new{eltype(C),typeof(C)}(C)
    end
end

theogf avatar Jun 10 '21 16:06 theogf

Or we could just not make it a vector at all but a scalar :stuck_out_tongue:

devmotion avatar Jun 10 '21 16:06 devmotion

If we stick with vectors (e.g. because of Flux/Zygote), in the example I guess LinearKernel(c::Real) could be a simple outer constructor LinearKernel(c::Real) = LinearKernel([c]).

devmotion avatar Jun 10 '21 16:06 devmotion

But then how do I do my in-place updates! :laughing:. Also, nobody likes scalars! Especially GPUs!

theogf avatar Jun 10 '21 16:06 theogf

One could e.g. use ParameterHandling and perform in-place updates of the resulting vector 😉 But then you would still have to go back and reconstruct the kernel.
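A rough sketch of what that workflow could look like (hypothetical: it assumes a ParameterHandling.flatten implementation for the kernel, which does not exist at the moment):

using KernelFunctions, ParameterHandling

# assumes a `flatten` implementation for LinearKernel existed
θ, unflatten = ParameterHandling.flatten(LinearKernel(2.0))
θ .*= 0.5              # in-place update of the flat parameter vector...
kernel = unflatten(θ)  # ...but the kernel still has to be reconstructed afterwards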

devmotion avatar Jun 10 '21 17:06 devmotion

Did we decide that Refs were a bad idea, i.e. that we shouldn't be using them because Ref(5.0) doesn't play nicely with Functors.jl?

willtebbutt avatar Jun 14 '21 12:06 willtebbutt

Ref is not captured by Flux.params AFAICT

theogf avatar Jun 14 '21 12:06 theogf

I still think one could consider dropping support for Flux.params and switching to scalar parameters instead. Both Ref and Vector seem a bit unnatural.

devmotion avatar Jun 14 '21 12:06 devmotion

Ref is not captured by Flux.params AFAICT

:(

I still think one could consider dropping support for Flux.params and switching to scalar parameters instead. Both Ref and Vector seem a bit unnatural.

I can definitely see the appeal of this, but I would really like this ecosystem to be compatible with the Flux one e.g. so that you can map the inputs to a GP through a DNN defined using Flux. I'm not saying that this is necessarily a good idea in general, see [1], but I'd at least like to avoid ruling it out.

If we didn't use the Functors API, is there another way to interop nicely with Flux?

[1] - https://arxiv.org/abs/2102.12108

willtebbutt avatar Jun 14 '21 13:06 willtebbutt

One could e.g. use ParameterHandling and perform in-place updates of the resulting vector. But then you would still have to go back and reconstruct the kernel.

You mentionned ParameterHandling and it looks very nice and it seems the only requirement on our side would be to implement the flatten approach. However that means any additional operations on the kernels would need to follow the ParameterHandling approach no?

theogf avatar Jun 14 '21 13:06 theogf

Not necessarily. You can implement flatten for all of the things in your model, but my preferred approach is to use ParameterHandling.jl to make it easy to write a function which builds your model.

For example, see here.

This is a different style from the one that you prefer though @theogf .

willtebbutt avatar Jun 14 '21 13:06 willtebbutt

I haven't thought this through but maybe it is possible to write a custom Flux layer that contains only the parameter vector and a function that builds the kernel:

using Functors

struct KernelLayer{P,F}
    params::P  # parameter vector
    model::F   # function that builds the kernel from `params`
end

# expose only `params` to Functors/Flux
Functors.@functor KernelLayer (params,)

(f::KernelLayer)(x, y) = f.model(f.params)(x, y)
...
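Hypothetical usage of this sketch (the concrete kernel and parameterization here are just for illustration):

using Flux, KernelFunctions

layer = KernelLayer([log(1.5)], θ -> SqExponentialKernel() ∘ ScaleTransform(exp(first(θ))))
layer(1.0, 2.0)    # evaluates the kernel rebuilt from the current parameters
Flux.params(layer) # collects only `layer.params`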

devmotion avatar Jun 14 '21 13:06 devmotion

Not necessarily. You can implement flatten for all of the things in your model, but my preferred approach is to use ParameterHandling.jl to make it easy to write a function which builds your model.

I think I got this part. However, I wanted to say that I am also happy with the destructure approach of Flux/Functors, which seems "similar" to ParameterHandling.

My point is that if we can have params, reconstruct = flatten(my_very_complicated_kernel) I am also happy with that :)

theogf avatar Jun 14 '21 13:06 theogf

I haven't thought this through but maybe it is possible to write a custom Flux layer that contains only the parameter vector and a function that builds the kernel:

Interesting -- how are you imagining this interacting with Flux, @devmotion? I don't think I'm seeing where you're going with this yet.

willtebbutt avatar Jun 14 '21 13:06 willtebbutt

As I said, I didn't spend much thought on it. But the main idea would be that instead of forcing all kernels to be mutable we just have one kernel that just wraps the parameters of a possibly complicated kernel in a Flux-compatible way (e.g. as a Vector) and includes a function that constructs the kernel from the parameters. Due to the functor definition, Flux would only work with and update the parameters.

devmotion avatar Jun 14 '21 13:06 devmotion

As I said, I didn't spend much thought on it. But the main idea would be that instead of forcing all kernels to be mutable we just have one kernel that just wraps the parameters of a possibly complicated kernel in a Flux-compatible way (e.g. as a Vector) and includes a function that constructs the kernel from the parameters. Due to the functor definition, Flux would only work with and update the parameters.

Ok, so that's like a complementary idea, right? We would need a conversion from any kernel to a KernelLayer to then be used with any optimization tool? The conversion could be handled in the back-end by ParameterHandling or something handcrafted.

theogf avatar Jun 14 '21 13:06 theogf

We would need a conversion from any kernel to a KernelLayer to then be used with any optimization tool?

The main idea is that it only requires the output of ParameterHandling.flatten or Functors.functor (if one prefers @theogf's approach) or the custom function that builds the kernel (apparently @willtebbutt's preferred approach). So no additional support is needed.

devmotion avatar Jun 14 '21 13:06 devmotion

I came back to this issue some days ago and thought a bit more about the problems and the suggestion above and started to work on a draft.

Currently my suggestion would be the following breaking changes:

  • remove support for Functors (almost) completely
  • remove all Vector workarounds and make scalar parameters scalar fields
  • implement ParameterHandling.flatten for kernels and transforms (to make @theogf happy 🙂) that handles e.g. positivity constraints, but emphasize the advantages of manually specifying a parameter vector and a function that constructs the kernel from it
  • implement a KernelModel/KernelLayer, basically as suggested above (a rough sketch follows below)
  • add a helper function that constructs such a parameterized kernel from a given, possibly nested, kernel by calling KernelModel(flatten(kernel)...)

My main reasoning is:

  • It seems Functors is too limited and not the right tool for handling parameters: e.g., it does not provide a way to specify the domain of parameters, so in general it seems challenging/impossible to just optimize Flux.Params(kernel) (the example in the docs works since it only trains an input transform but no kernel parameters)
  • ParameterHandling became much more lightweight in 0.4.0 since the dependency on Bijectors was removed
  • Vector fields for scalar parameters are not a satisfying design (IMO), cause allocations, and e.g. cause problems with GPUs
  • a default implementation of flatten with e.g. an exp transformation for positive parameters seems convenient
  • however, it does not allow one to use non-standard transforms (e.g. log1pexp), to fix specific parameters without an additional wrapper, or to introduce dependencies between different parameters (e.g. identical parameters in some kernels in a sum of kernels), which can be done more easily if one defines a function that builds the kernel
  • if the parameter vector and the model function (obtained from flatten or defined manually) are coupled in a kernel, it is easier to integrate a kernel with an array of parameters into a larger modeling pipeline
  • the @functor KernelModel (params,) definition would make it possible to extract the kernel parameters automatically with Flux.Params and to optimize them, in particular if the kernel is part of a larger @functor-based model.
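A rough, hypothetical sketch of the KernelModel piece (names and details are of course up for discussion; it assumes flatten returns the unconstrained parameters and a function that rebuilds the kernel):

using KernelFunctions, Functors
using ParameterHandling: flatten

struct KernelModel{P<:AbstractVector{<:Real},F} <: Kernel
    params::P  # flat, unconstrained parameter vector
    build::F   # function reconstructing the kernel from `params`
end

# only the parameter vector is exposed to Functors/Flux
Functors.@functor KernelModel (params,)

(k::KernelModel)(x, y) = k.build(k.params)(x, y)

# helper: wrap an existing kernel, relying on its `flatten` implementation
KernelModel(kernel::Kernel) = KernelModel(flatten(kernel)...)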

devmotion avatar Oct 14 '21 21:10 devmotion

the @functor KernelModel (params,) definition would make it possible to extract the kernel parameters automatically with Flux.Params and to optimize them, in particular if the kernel is part of a larger @functor-based model.

Could you elaborate a little on what this might look like?

willtebbutt avatar Oct 18 '21 11:10 willtebbutt

E.g., in https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/examples/deep-kernel-learning/ one approach would be to remove the @functor definitions (I don't think they are needed currently either: they are already defined in KernelFunctions and are not useful for kernels without parameters such as SqExponentialKernel) and replace

neuralnet = Chain(Dense(1, 3), Dense(3, 2))
k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)

with

neuralnet = Chain(Dense(1, 3), Dense(3, 2))
k = KernelModel(flatten(SqExponentialKernel())...) ∘ FunctionTransform(neuralnet)

Then the rest of the example would work the same way and would be just as efficient as now. Clearly, this is a stupid example since SqExponentialKernel does not have any parameters, and it would be "cooler" if SqExponentialKernel() were replaced by a combination of kernels with different parameters. I would suggest using @functor only for high-level kernels such as TransformedKernel (which k is an example of) and transforms such as FunctionTransform, where we always want to propagate to nested elements. Proper kernels, with e.g. real-valued or positive parameters, would have to be cast to a KernelModel by users (as mentioned above, a more convenient syntax than KernelModel(flatten(...)...) would be nice) or already be designed as a KernelModel, and are otherwise ignored by functor and hence e.g. Flux.Params.
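A hypothetical version of the "cooler" case (assuming a KernelModel as sketched earlier and flatten implementations for the involved kernels):

neuralnet = Chain(Dense(1, 3), Dense(3, 2))
base = 2.0 * (SqExponentialKernel() ∘ ScaleTransform(0.5)) + Matern32Kernel()  # kernel with actual parameters
k = KernelModel(flatten(base)...) ∘ FunctionTransform(neuralnet)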

devmotion avatar Oct 18 '21 12:10 devmotion

Ahhh, I think I understand now! You're saying we provide a KernelModel interface containing the parameters and the reconstruction function, right? It is compatible with Functors because we just tell it that the flattened array representation is the trainable parameters. In this case, I would even argue that we don't have to implement @functor only for specific objects (like FunctionTransform), as we can just use flatten.

We could default flatten to try the Functors API for unknown structs/objects, such that flatten(neuralnet) would just call the Functors API, and then have

k = compose(SqExponentialKernel(), FunctionTransform(neuralnet))
model = KernelModel(flatten(k)...)
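For illustration, such a fallback might look roughly like this (not code from this thread; as the next reply points out, a catch-all method like this amounts to type piracy):

using Functors, ParameterHandling

# for structs without their own `flatten` method, delegate to Functors
function ParameterHandling.flatten(::Type{T}, x) where {T<:Real}
    children, reconstruct = Functors.functor(x)
    v, unflatten_children = ParameterHandling.flatten(T, children)
    return v, v̄ -> reconstruct(unflatten_children(v̄))
end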

theogf avatar Oct 18 '21 12:10 theogf

I was worried that such generic fallbacks introduce a) type piracy and b) surprising behaviour. I have also become more and more convinced that it would be better to error if someone calls flatten for a kernel for which no explicit flatten implementation exists, instead of assuming that the kernel contains no parameters (or implements functor). Otherwise it would be easy to miss some parameters without noticing it.

I also think it can be more efficient for larger models not to use only flatten here (i.e., not to move everything into a KernelModel): it would concatenate all parameters of the NN together with the parameters of the kernel, whereas functor would just return them separately (Flux might still combine them at some point, but we don't enforce it). In general, flatten will often be quite inefficient in nested models due to multiple break-ups of vectors and intermediate allocations, so one should really emphasize this in the documentation and recommend instead explicitly writing a function that builds the model if one wants to make it more efficient.

devmotion avatar Oct 18 '21 13:10 devmotion

Thanks for the explanation; not having a general fallback indeed seems safer and faster.

theogf avatar Oct 18 '21 13:10 theogf

a default implementation of flatten with e.g. an exp transformation for positive parameters seems convenient

How do you envision this? Should we still allow passing Reals as-is, or should all (trainable) fields be defined as ParameterHandling types?

Implementation-wise, just using PH types sounds easier, but it also sounds like an unnecessary extra cost if one does not do any optimization. I am particularly thinking of what the constructors should look like if we allow both... Would using a keyword argument like fixed=true lead to a ~~type unstable~~ constructor where the output cannot be inferred at compile time?

Here is a quick try to see how it would go

julia> struct Foo{T}
         a::T
         function Foo(x; fixed=true)
           if fixed
             return new{typeof(float(x))}(float(x)) # I am just putting this as a dummy example
           else
              return new{typeof(x)}(x)
           end
         end
       end

julia> Foo(2)
Foo{Float64}(2.0)

julia> Foo(2; fixed=false)
Foo{Int64}(2)

julia> @code_warntype Foo(2)
Variables
  #self#::Type{Foo}
  x::Int64

Body::Foo{Float64}
1 ─ %1 = $(QuoteNode(var"#Foo#3#4"()))
│   %2 = (%1)(true, #self#, x)::Foo{Float64}
└──      return %2

julia> @code_warntype Foo(2; fixed=false)
Variables
  #unused#::Core.Const(Core.var"#Type##kw"())
  @_2::NamedTuple{(:fixed,), Tuple{Bool}}
  @_3::Type{Foo}
  x::Int64
  fixed::Bool
  @_6::Bool

Body::Union{Foo{Float64}, Foo{Int64}}
1 ─ %1  = Base.haskey(@_2, :fixed)::Core.Const(true)
│         Core.typeassert(%1, Core.Bool)
│         (@_6 = Base.getindex(@_2, :fixed))
└──       goto #3
2 ─       Core.Const(:(@_6 = true))
3 ┄ %6  = @_6::Bool
│         (fixed = %6)
│   %8  = (:fixed,)::Core.Const((:fixed,))
│   %9  = Core.apply_type(Core.NamedTuple, %8)::Core.Const(NamedTuple{(:fixed,), T} where T<:Tuple)
│   %10 = Base.structdiff(@_2, %9)::Core.Const(NamedTuple())
│   %11 = Base.pairs(%10)::Core.Const(Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│   %12 = Base.isempty(%11)::Core.Const(true)
│         Core.typeassert(%12, Core.Bool)
└──       goto #5
4 ─       Core.Const(:(Base.kwerr(@_2, @_3, x)))
5 ┄ %16 = $(QuoteNode(var"#Foo#3#4"()))
│   %17 = fixed::Bool
│   %18 = (%16)(%17, @_3, x)::Union{Foo{Float64}, Foo{Int64}}
└──       return %18

theogf avatar Oct 18 '21 13:10 theogf

No, this was not what I had in mind or wanted to suggest. I do not want to allow ParameterHandling types in kernels: they are not even Reals and would cause undesired complexity IMO.

I just wanted to suggest using them in the flatten implementation. I view flatten mainly as a convenient (but, as mentioned, not necessarily efficient) mechanism for retrieving a vector of parameters and a function to reconstruct the kernel. The vectorized form (also the one obtained with @functor) is a bit useless in general if it is not clear to which domain the individual elements are (or should be) restricted, and hence I would suggest returning only unconstrained parameters. Here I think it would be convenient to make use of e.g. positive and bounded. A more concrete example:

using KernelFunctions
using ParameterHandling: ParameterHandling, positive, value_flatten

struct GaussianKernelWithLengthscale{T<:Real} <: SimpleKernel
    l::T
end

function ParameterHandling.flatten(::Type{T}, k::GaussianKernelWithLengthscale) where {T<:Real}
    # uses `exp` by default - if one wants to use e.g. `log1pexp` one has to implement the model explicitly
    lvec, unflatten_to_l = value_flatten(T, positive(k.l))
    function unflatten_to_gaussiankernelwithlengthscale(x::Vector{T})
        return GaussianKernelWithLengthscale(unflatten_to_l(x))
    end
    return lvec, unflatten_to_gaussiankernelwithlengthscale
end

# implementation of the kernel
...
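Hypothetical usage of this flatten implementation:

k = GaussianKernelWithLengthscale(1.5)
θ, unflatten = ParameterHandling.flatten(Float64, k)  # θ contains the unconstrained (roughly log-scale) parameter
k2 = unflatten(θ .+ 0.1)                              # reconstructs a kernel with an updated lengthscale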

devmotion avatar Oct 18 '21 13:10 devmotion

Ah, ok! But then if we start to define a default flatten for ScaleTransform using exp for example, how does one override it?

theogf avatar Oct 18 '21 13:10 theogf

One doesn't. That's the whole point of flatten IMO - it is just one opinionated way of flattening the parameters and usually not the most efficient one. It doesn't allow you to fix parameters before flattening the kernel. It isn't able to handle parameter dependencies between kernels (e.g., some parameters of kernels in a sum should always be the same).

If you want something more custom or complex, you just have to write the function that builds the model.
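For example, a hand-written builder (sketched here with ParameterHandling and made-up parameter names) can tie a single lengthscale to both kernels in a sum, which a generic flatten could not express:

using KernelFunctions, ParameterHandling

raw_params = (lengthscale = positive(1.5), variance = positive(2.0))
θ, unflatten = value_flatten(raw_params)

function build_kernel(params)
    base = SqExponentialKernel() + Matern32Kernel()  # both kernels share the lengthscale below
    return params.variance * with_lengthscale(base, params.lengthscale)
end

k = build_kernel(unflatten(θ))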

devmotion avatar Oct 18 '21 13:10 devmotion

If that's okay with you I would be happy to build the PR! I think I got all the elements, but since I am obviously not the most qualified person to do so I would understand if you (@devmotion) or @willtebbutt would like to do it!

theogf avatar Oct 18 '21 14:10 theogf

I already have something in a local branch; I had to play around a bit before I ended up with the suggestion above. But before finishing and polishing it, I wanted to discuss the ideas with you :slightly_smiling_face: If you think we should do this, I will clean up the local branch and open a PR.

devmotion avatar Oct 18 '21 14:10 devmotion

I am definitely in favor, we just need input from @st-- and @willtebbutt.

theogf avatar Oct 18 '21 14:10 theogf