
Flux compatible types that transform inputs

Open HamletWantToCode opened this issue 5 years ago • 4 comments

Based on our discussion here, we agreed to design some types that transform input variables. We would like these types to work seamlessly with Flux's API; the transformations in KernelFunctions.jl are good examples.

Could we achieve this by just letting Stheno.jl work with KernelFunctions.jl?

HamletWantToCode avatar Apr 10 '20 16:04 HamletWantToCode

Thanks for opening this!

Definitely, the plan is to get KernelFunctions.jl to play nicely with Stheno.jl, which will involve getting some interface improvements into KernelFunctions.jl, e.g. moving away from requiring inputs to be Matrix objects and getting rid of the obsdim kwarg. See the discussion here.

I'm not sure that the transformation objects in KernelFunctions.jl are really what I was getting at, though. They're more akin to Stheno's stretch.

Comment from here:

While Stheno and KernelFunctions have nice APIs for input transformation (e.g. Stretched, Scaled), I think it's also important to pay attention to kernel hyperparameters, such as the scaling factor and length scale: their values more or less need to be constrained during optimization. Therefore, besides types performing input transformations, we may also need types that impose constraints on hyperparameters.

This isn't quite right. The Stretched and Scaled kernels are the things that give every kernel a length scale and a variance. In general, you can always think of changing the length scale of a kernel as an input transformation, so there's really no need to provide length scales as such: the appropriate input transforms do everything you need. Similarly, the Scaled kernel handles the variance of a kernel. The stretch and * functions are just shorthands for these.
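To make this concrete, here is a minimal sketch (hypothetical function names, purely illustrative, not Stheno's API) of how a length scale and a variance arise from an input transformation and an output scaling of a unit kernel:

```julia
# Hypothetical stand-alone functions, purely illustrative (not Stheno's API).
eq(x, x′) = exp(-0.5 * (x - x′)^2)        # EQ kernel with unit length scale and variance

# A length scale ℓ is just the input transformation x -> x / ℓ (what Stretched does):
eq_lengthscale(x, x′, ℓ) = eq(x / ℓ, x′ / ℓ)

# A variance σ² is just an output scaling (what Scaled does):
eq_scaled(x, x′, σ²) = σ² * eq(x, x′)
```

So a "kernel with hyperparameters" is just the unit kernel composed with these transformations, which is why constraining the transformations' parameters is the right place to act.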

The thing I'm thinking about is precisely what you say about constraining kernel length scales and variances -- but here that corresponds to constraining the parameters of Stretched and Scaled (and the corresponding linear transformations in the CompositeGP case).

To do this, I was imagining having types along the following lines:

struct Positive{Tx, Tf} <: AbstractParameter
    x::Ref{Tx}
    f::Tf
    function Positive(y, f)
        x = inv(f)(y)  # map the constrained value to the unconstrained space
        return new{typeof(x), typeof(f)}(Ref(x), f)
    end
end

Positive(y) = Positive(y, Exp())

value(y::Positive) = y.f(y.x[])

where a Positive represents a positively-constrained value (or vector, or matrix?), with the constraint encoded by e.g. a Bijector from Bijectors.jl.

To construct it, one would pass in the value you want and the bijector to use to map from the unconstrained space to the constrained one; e.g. for a positive-valued quantity, you could use Exp.

The constructor maps the value from the positive reals to the whole real line using inv(f); then, whenever you ask for the value, it applies the bijector to the unconstrained value to produce the constrained one. Note the use of a Ref for Flux compatibility.

You could also have a default constructor that provides a reasonable fallback, i.e. the Exp function.

We could implement a few of these to handle various types of constraints, probably just using the things in Bijectors.jl. The nice thing about using Bijectors.jl is that you know that inv gives us the inverse bijection, so there would be no need to worry about that.
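As a sanity check, here's a self-contained sketch of the round trip, with a hand-rolled Exp/Log bijection pair standing in for Bijectors.jl's types (purely so the example runs on its own; a real implementation would depend on Bijectors.jl):

```julia
# Hand-rolled stand-ins for Bijectors.jl's Exp bijector and its inverse.
struct Log end
struct Exp end
(::Exp)(x) = exp(x)
(::Log)(x) = log(x)
Base.inv(::Exp) = Log()
Base.inv(::Log) = Exp()

struct Positive{Tx, Tf}
    x::Ref{Tx}  # unconstrained value, free for an optimiser to move around
    f::Tf       # bijection from the unconstrained space to the constrained one
end
Positive(y, f) = Positive(Ref(inv(f)(y)), f)
Positive(y) = Positive(y, Exp())  # reasonable default for positive quantities

value(p::Positive) = p.f(p.x[])

p = Positive(2.5)   # stored internally as log(2.5)
value(p)            # recovers 2.5
p.x[] = -3.0        # an optimiser may set any real number...
value(p)            # ...but the constrained value remains positive
```

The point is that no optimisation code ever needs to know about the constraint: it only ever touches the unconstrained `p.x[]`.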

I suspect that the reduction in total code won't be huge, but it will definitely abstract away some of the detail from the kernel implementations in Stheno.jl, and make it easier for users to understand how to parametrise things in a constrained way. We would obviously need to get these into KernelFunctions.jl in the near-ish future, but we could focus on a Stheno.jl-specific implementation for now and then move it over once we've worked out any awkward design issues.

Maybe we don't want to constrain the input transformation to be bijectors, but we can figure that detail out later.

willtebbutt avatar Apr 11 '20 21:04 willtebbutt

Hi @willtebbutt, sorry for the late response. I was busy with my project last week and only just got time to think carefully about your comment.

The thing I'm thinking about is precisely what you say about constraining kernel length scales and variances. To do this, I was imagining having types along the following lines.

Bijectors are much better than what I had in mind; after reading the documentation of Bijectors.jl, I think it's worth a try :)

So, to summarize, here is what I think we are going to have (in the case where the hyperparameters are scalars):

# define abstract parameter type
abstract type AbstractParameter{T} end
const AP{T} = AbstractParameter{T}

# define scalar parameter type: x stores the unconstrained value,
# f maps it back to the constrained space
struct ScalarParameter{T, Tf<:Bijector} <: AbstractParameter{T}
    x::Ref{T}
    f::Tf
end
ScalarParameter(y, f) = ScalarParameter(Ref(inv(f)(y)), f)

value(y::ScalarParameter) = y.f(y.x[])

# define vector parameter type ....

# modify transformation types to hold parameters rather than raw numbers
struct Stretched{Ta<:AP{<:Real}, Tk<:Kernel} <: Kernel
    a::Ta
    k::Tk
end

function ew(k::Stretched{<:AP{<:Real}}, x::AV{<:Real}, x′::AV{<:Real})
    return ew(k.k, value(k.a) .* x, value(k.a) .* x′)
end

Parameters passed to transformation types such as Stretched fall into two classes: constrained and unconstrained. For an unconstrained parameter, the :f field of ScalarParameter is taken to be the identity function, while for a constrained one, :f can be set to e.g. Exp.
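An end-to-end toy version of the above (all names are self-contained, hypothetical stand-ins for the real Stheno.jl and Bijectors.jl types, so the snippet runs on its own):

```julia
# Self-contained toy stand-ins; not the real Stheno.jl / Bijectors.jl types.
struct Log end
struct Exp end
(::Exp)(x) = exp(x)
(::Log)(x) = log(x)
Base.inv(::Exp) = Log()
Base.inv(::Log) = Exp()

struct ScalarParameter{T, Tf}
    x::Ref{T}  # unconstrained value
    f::Tf      # bijection back to the constrained space
end
ScalarParameter(y, f) = ScalarParameter(Ref(inv(f)(y)), f)
value(p::ScalarParameter) = p.f(p.x[])

abstract type Kernel end
struct EQ <: Kernel end
ew(::EQ, x, x′) = exp.(-0.5 .* (x .- x′).^2)   # elementwise kernel evaluations

struct Stretched{Ta, Tk<:Kernel} <: Kernel
    a::Ta   # stretch parameter, kept constrained via its bijection
    k::Tk
end
ew(k::Stretched, x, x′) = ew(k.k, value(k.a) .* x, value(k.a) .* x′)

# A positively-constrained stretch parameter with initial value 2.0:
a = ScalarParameter(2.0, Exp())
k = Stretched(a, EQ())
ew(k, [0.0], [1.0])   # inputs are stretched by 2 before hitting the EQ kernel
```

An unconstrained parameter would be the same construction with an identity function in place of Exp().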

Should we depend on Bijectors.jl or implement our own bijectors?

Also, I notice you guys are working on improving KernelFunctions.jl here 👍

HamletWantToCode avatar Apr 20 '20 15:04 HamletWantToCode

Yes, this is exactly what I had in mind!

Should we depend on Bijectors.jl or implement our own bijectors?

We should definitely depend on Bijectors.jl. Rolling our own would be bad, imho.

Also, I notice you guys are working on improving KernelFunctions.jl here 👍

Yup, I'm hoping it will be ready for use with Stheno.jl in the not-too-distant future!

willtebbutt avatar Apr 21 '20 21:04 willtebbutt

Cool! I will collect these into a PR after I finish my assignment this week :)

HamletWantToCode avatar Apr 22 '20 12:04 HamletWantToCode