GPJax

feat: vector-valued GPs

Open · emilemathieu opened this issue 2 years ago · 7 comments

Hi all,

First, thanks for this nice library, which I've recently started using! :)

I'd be interested in vector-valued GPs, and from what I understand, this is not supported yet, right? Or am I missing something? I've passed a kernel function that returns matrices of shape [d x d]:

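import gpjax as gpx

kernel = RBFCurlFree()  # user-defined kernel returning a [d x d] matrix per input pair
gp = gpx.Prior(kernel=kernel)
gp = gp(dict(kernel=kernel.params, mean_function={}))
y = gp(x)  # x: the N input locations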

I believe the way to deal with this is to rearrange/reshape things into an [N*d, N*d] matrix, but I don't have much experience with how to easily handle vector-valued GPs.
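A minimal JAX sketch of the reshape described above, assuming a matrix-valued kernel that returns one [d, d] block per pair of inputs. Here the curl-free kernel is built as the covariance of the gradient of a scalar RBF GP; that this matches what `RBFCurlFree` computes is an assumption, and `rbf`, `curl_free`, `gram` and `flatten_gram` are hypothetical helpers, not GPJax API.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, lengthscale=1.0, variance=1.0):
    """Scalar RBF kernel k(x, y) for x, y of shape [d]."""
    return variance * jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / lengthscale**2)


def curl_free(x, y):
    """Matrix-valued kernel: covariance of the gradient field of an RBF GP.

    Entry (a, b) is d^2 k / (dx_a dy_b), so the result has shape [d, d]."""
    return jax.jacfwd(jax.jacrev(rbf, argnums=0), argnums=1)(x, y)


def gram(kernel, X, Y):
    """Evaluate a matrix-valued kernel on all pairs: [n, D], [m, D] -> [n, m, d, d]."""
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(Y))(X)


def flatten_gram(K):
    """Reshape [n, m, d, d] -> [n*d, m*d]; flat row i*d + a is point i, output a."""
    n, m, d, _ = K.shape
    # Bring the axes into the order [i, a, j, b] before merging them pairwise.
    return K.transpose(0, 2, 1, 3).reshape(n * d, m * d)


X = jnp.array([[0.0, 0.0], [1.0, 0.5], [-0.3, 2.0]])  # 3 points in 2-D
K_flat = flatten_gram(gram(curl_free, X, X))
print(K_flat.shape)  # (6, 6)
```

The transpose matters: reshaping [n, m, d, d] directly would interleave point and output indices incorrectly, whereas [i, a, j, b] order gives a proper symmetric [N*d, N*d] covariance that any standard multivariate normal can consume.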

Best, Emile

emilemathieu, Jul 20 '22

Hey Emile,

Thank you for your question and interest in GPJax. GPJax does not currently support this, but it is on our radar! @thomaspinder and I are currently working on a multi-output GP package, as a seamless extension of GPJax, to support functionality of this kind. We have decided to make this a separate package (currently called "MOGPJax") to keep GPJax a light, readable codebase, while offering users greater flexibility in defining a broad range of vector-valued models and scalable inference procedures, which in general differ from those for single-output models. We are in the early stages of development but expect to make our first release public soon (once we are happy with the core structure).

Cheers, Dan

daniel-dodd, Jul 21 '22

Thanks @Daniel-Dodd for your answer! Could you by any chance tell me a bit about the core idea behind making it a seamless extension? Is it by rearranging/reshaping the mean [b, n, d] -> [b, n*d] and the covariance [b, n, n, d, d] -> [b, n*d, n*d]?
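A sketch of the batched reshape in the question, assuming the covariance tensor is indexed as cov[batch, i, j, a, c] (point i, point j, output a, output c). The important detail is that the mean and the covariance must be flattened with the same ordering; here both use point-major order, so flat index i*d + a means point i, output a. `flatten_batch` and `unflatten` are hypothetical names, not MOGPJax API.

```python
import jax
import jax.numpy as jnp


def flatten_batch(mean, cov):
    """mean [b, n, d] -> [b, n*d]; cov [b, n, n, d, d] -> [b, n*d, n*d]."""
    b, n, d = mean.shape
    # A row-major reshape puts point i, output a at flat index i*d + a.
    flat_mean = mean.reshape(b, n * d)
    # Reorder axes to [b, i, a, j, c] so the covariance uses the same ordering.
    flat_cov = cov.transpose(0, 1, 3, 2, 4).reshape(b, n * d, n * d)
    return flat_mean, flat_cov


def unflatten(flat, n, d):
    """Map flat vectors [..., n*d] back to [..., n, d]."""
    return flat.reshape(flat.shape[:-1] + (n, d))


# Usage: a separable covariance cov[i, j, a, c] = A[i, j] * C[a, c], batched twice.
b, n, d = 2, 3, 2
mean = jnp.zeros((b, n, d))
A = jnp.array([[1.0, 0.5, 0.1], [0.5, 1.0, 0.5], [0.1, 0.5, 1.0]])  # [n, n]
C = jnp.array([[2.0, 0.3], [0.3, 1.0]])  # [d, d]
cov = jnp.broadcast_to(jnp.einsum("ij,ac->ijac", A, C), (b, n, n, d, d))
flat_mean, flat_cov = flatten_batch(mean, cov)

# One joint sample per batch element, then restore the [n, d] layout.
samples = jax.random.multivariate_normal(jax.random.PRNGKey(0), flat_mean, flat_cov)
print(unflatten(samples, n, d).shape)  # (2, 3, 2)
```

With point-major ordering, a separable covariance flattens to `jnp.kron(A, C)`, which is a convenient check that the mean and covariance indices line up.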

emilemathieu, Jul 21 '22

@Daniel-Dodd would you have any update regarding this MOGPJax package by any chance?

emilemathieu, Sep 01 '22

Hi @emilemathieu,

Apologies for the delay.

We are actively developing this package and working towards the first release. It will, however, take us more time; some of this depends on us completing GPJax's v0.5 release. At a bare minimum, the first public release of MOGPJax will have GPLVMs, isotopic conjugate GPs, and isotopic non-conjugate GPs, for which the user can do MAP estimation or MCMC inference as in GPJax. We currently have some rough implementations, but these need further work.

Thanks, Dan

daniel-dodd, Sep 19 '22

Hi @Daniel-Dodd and @thomaspinder,

First, thanks for this awesome library! In my research, I am also very interested in using GPs (more specifically SVGPs) with multiple outputs, and was wondering whether, while waiting for MOGPJax to be officially released, there is a repo I could fork to start using multi-output GPs. I could not find MOGPJax anywhere, so I guess it is not public yet.

Also, do you have a release date yet?

Thanks a lot, Thomas

thomascerbelaud, Nov 29 '22

Hi @thomascerbelaud,

Thank you for your kind words and interest.

Progress on this has fallen behind due to major refactoring work on GPJax versions v0.5 to v0.5.2, and the repository has not been updated in a while!

To get the ball rolling on this, and now that the JaxGaussianProcesses organisation is established, I will aim to make this repository public soon (hopefully over this weekend, but certainly by the end of next week), once the tests pass (some things broke during the GPJax refactoring)! That way we can develop in public, and anyone can fork and contribute.

Currently, MOGPJax has a GPLVM fitted via a MAP estimate, and we have a rough notebook implementation of a multi-output prior and conjugate posterior for isotopic datasets.
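An isotopic dataset observes every output at every input, so the targets stack into an [n, p] matrix. Below is a minimal dense sketch of a conjugate multi-output posterior for such data, assuming an intrinsic coregionalisation kernel k(x, x') B with a scalar RBF k and a PSD output matrix B. This is not the MOGPJax notebook code; `rbf_gram` and `icm_posterior` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf_gram(X, Z, lengthscale=1.0, variance=1.0):
    """Scalar RBF gram matrix between X [n, D] and Z [m, D] -> [n, m]."""
    sq = jnp.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq / lengthscale**2)


def icm_posterior(X, Y, Xtest, B, noise_variance=0.1):
    """Posterior of f ~ GP(0, k(x, x') B) given isotopic targets Y [n, p].

    Uses point-major ordering (flat index i*p + a), so the joint gram is
    kron(Kx, B). Returns the predictive mean [m, p] and the latent
    covariance [m*p, m*p] at Xtest."""
    n, p = Y.shape
    m = Xtest.shape[0]
    Kxx = jnp.kron(rbf_gram(X, X), B) + noise_variance * jnp.eye(n * p)
    Ksx = jnp.kron(rbf_gram(Xtest, X), B)
    Kss = jnp.kron(rbf_gram(Xtest, Xtest), B)
    L = jnp.linalg.cholesky(Kxx)
    alpha = cho_solve((L, True), Y.reshape(n * p))  # Kxx^{-1} vec(Y)
    mean = (Ksx @ alpha).reshape(m, p)
    V = solve_triangular(L, Ksx.T, lower=True)  # L^{-1} Kxs
    cov = Kss - V.T @ V
    return mean, cov


X = jnp.linspace(0.0, 1.0, 6)[:, None]
Y = jnp.stack([jnp.sin(3 * X[:, 0]), jnp.cos(3 * X[:, 0])], axis=1)  # [6, 2]
B = jnp.array([[1.0, 0.8], [0.8, 1.0]])  # correlated outputs
mean, cov = icm_posterior(X, Y, jnp.array([[0.25], [0.8]]), B)
print(mean.shape, cov.shape)  # (2, 2) (4, 4)
```

Setting B to the identity decouples the outputs, so each column of the mean reduces to an independent single-output GP posterior; off-diagonal entries of B are what let one output borrow strength from another.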

We plan to move the kernels out into a separate library, JaxKern, and are interested in designing a multi-output kernel abstraction. In addition, we plan in the near future to implement Kronecker linear operators in JaxLinOp, so that we can, e.g., think about abstractions for the linear model of coregionalisation.
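To illustrate why Kronecker operators help here: with a single-term coregionalisation kernel (the intrinsic coregionalisation model, the simplest case of the linear model of coregionalisation), the noisy gram matrix is kron(Kx, B) + σ²I, and it can be solved using the eigendecompositions of the two small factors instead of the full [n*p, n*p] matrix. This is a sketch in plain JAX, not the JaxLinOp API; `kron_solve` is a hypothetical name.

```python
import jax.numpy as jnp


def kron_solve(Kx, B, noise_variance, v):
    """Solve (kron(Kx, B) + noise_variance * I) z = v without forming kron(Kx, B).

    Kx: [n, n] and B: [p, p] symmetric PSD; v: [n*p] in point-major order (i*p + a).
    kron(Qx, Qb) holds the eigenvectors and outer(lam_x, lam_b) the eigenvalues."""
    n, p = Kx.shape[0], B.shape[0]
    lam_x, Qx = jnp.linalg.eigh(Kx)  # O(n^3)
    lam_b, Qb = jnp.linalg.eigh(B)  # O(p^3)
    V = v.reshape(n, p)
    W = Qx.T @ V @ Qb  # apply kron(Qx, Qb)^T
    W = W / (jnp.outer(lam_x, lam_b) + noise_variance)  # invert shifted eigenvalues
    return (Qx @ W @ Qb.T).reshape(n * p)  # apply kron(Qx, Qb)


# Usage: agree with a dense solve on a small problem.
x = jnp.linspace(0.0, 1.0, 5)
Kx = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.3**2)
B = jnp.array([[1.0, 0.5], [0.5, 2.0]])
v = jnp.arange(10.0)
dense = jnp.linalg.solve(jnp.kron(Kx, B) + 0.1 * jnp.eye(10), v)
assert jnp.allclose(kron_solve(Kx, B, 0.1, v), dense, rtol=1e-3, atol=1e-3)
```

The cost drops from O(n³p³) for a dense solve to O(n³ + p³) plus matrix products; a sum of several Kronecker terms (a full linear model of coregionalisation) no longer shares one eigenbasis, so it needs a different structured approach.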

We would certainly be super interested in supporting a multi-output SVGP framework!

daniel-dodd, Dec 08 '22

Hi @thomascerbelaud, this has now been made public: JaxGaussianProcesses/MOGPJax.

Note that this currently only has a GPLVM model. The JaxLinOp integration into GPJax has broken the multi-output prior and posterior code, particularly the gram/cross-covariance matrix construction, and I have not had time to fix this yet. I will open issues for all of these shortly.

It would be great to work towards a clean multi-output kernel abstraction, with a neat way to compute cross-covariances, gram inverses, etc., that takes advantage of the efficient structured operators in JaxLinOp. This would probably be a good starting point. Once we have this sorted, we'll have a good basis to start adding MOGP models to the library.
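One possible shape for such an abstraction, sketched with hypothetical classes (not JaxKern or JaxLinOp API): each multi-output kernel exposes `cross_covariance` and `gram` over stacked outputs, and a linear model of coregionalisation is a sum of separable terms k_q(x, x') B_q. Dense `jnp.kron` is used for clarity; a structured Kronecker operator could stand in for it to recover the efficiency discussed in this thread.

```python
from dataclasses import dataclass
from typing import Callable, Sequence

import jax.numpy as jnp

Array = jnp.ndarray


def rbf_gram(X, Z, lengthscale=1.0):
    """Scalar RBF gram matrix between X [n, D] and Z [m, D] -> [n, m]."""
    sq = jnp.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq / lengthscale**2)


@dataclass
class SeparableKernel:
    """Hypothetical multi-output kernel k(x, x') * B over p outputs."""

    base_gram: Callable[[Array, Array], Array]  # [n, D], [m, D] -> [n, m]
    B: Array  # [p, p] PSD output covariance

    def cross_covariance(self, X, Z):
        """[n*p, m*p] cross-covariance in point-major order (i*p + a)."""
        return jnp.kron(self.base_gram(X, Z), self.B)

    def gram(self, X):
        return self.cross_covariance(X, X)


@dataclass
class LCMKernel:
    """Hypothetical linear model of coregionalisation: sum_q k_q(x, x') * B_q."""

    components: Sequence[SeparableKernel]

    def cross_covariance(self, X, Z):
        return sum(c.cross_covariance(X, Z) for c in self.components)

    def gram(self, X):
        return self.cross_covariance(X, X)


# Usage: two latent processes with different lengthscales mixed into 2 outputs.
k_short = SeparableKernel(lambda X, Z: rbf_gram(X, Z, 0.2), jnp.array([[1.0, 0.9], [0.9, 1.0]]))
k_long = SeparableKernel(lambda X, Z: rbf_gram(X, Z, 2.0), jnp.array([[0.5, 0.0], [0.0, 2.0]]))
lcm = LCMKernel([k_short, k_long])
X = jnp.linspace(0.0, 1.0, 4)[:, None]
print(lcm.gram(X).shape)  # (8, 8)
```

Keeping `cross_covariance` as the single primitive means priors and posteriors only need one method, while a future structured backend could override `gram` to return an operator with fast solves instead of a dense matrix.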

daniel-dodd, Dec 16 '22