
[RFC] "Unrolled" representation of multi-task MVN

Open Balandat opened this issue 5 years ago • 7 comments

Currently, the interleaved/non-interleaved representation of MTMVN creates a bunch of headaches (in particular regarding indexing, scalarization, and performance). It would be nice to have a higher-level API that abstracts away from these details (see #1055).

In this RFC, the covariance matrix is represented as an "unrolled" tensor (e.g. n x t x n x t) instead of an nt x nt covariance matrix. The constructor accepts a covariance tensor of shape batch_shape x shape, where shape has either three or four elements and can be any combination of "n" and "t" with at least one and at most two occurrences of each (currently only n==2 is supported). Internally, the covariance is consistently represented as n x t x n x t, or as t x n x n in the case of cross-task independence, in order to simplify indexing / scalarization operations.
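For illustration, here is a minimal sketch (plain torch, no LazyTensors) of how an existing nt x nt multi-task covariance could be unrolled into the n x t x n x t form; the helper name `unroll_covariance` and the `interleaved` flag are illustrative, not the actual constructor API:

```python
# A minimal sketch (plain torch, no LazyTensors) of unrolling an existing
# nt x nt multi-task covariance into the proposed n x t x n x t form. The
# helper name `unroll_covariance` and the `interleaved` flag are illustrative.
import torch

def unroll_covariance(covar: torch.Tensor, n: int, t: int, interleaved: bool = True) -> torch.Tensor:
    """Reshape a (..., n*t, n*t) covariance into a (..., n, t, n, t) tensor."""
    batch_shape = covar.shape[:-2]
    if interleaved:
        # interleaved: flat index is i * t + j (data point i, task j)
        return covar.reshape(*batch_shape, n, t, n, t)
    # non-interleaved: flat index is j * n + i, so reshape to (..., t, n, t, n)
    # and swap the task / data-point dimensions
    return covar.reshape(*batch_shape, t, n, t, n).transpose(-4, -3).transpose(-2, -1)

# toy example with n=3 data points and t=2 tasks
n, t = 3, 2
A = torch.randn(n * t, n * t)
covar = A @ A.T + torch.eye(n * t)  # make it positive definite
print(unroll_covariance(covar, n, t).shape)  # torch.Size([3, 2, 3, 2])
```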

The main benefits here are:

  1. It makes indexing much easier: e.g. if we want the cross-point covariance at some task j, we can just index the internal covariance as _covar[..., :, j, :, j] and we are done. With the current representation (especially considering the interleaving option) this is a huge headache. (Note that this RFC does not include any indexing functions.)
  2. We can represent MTMVNs with independent outputs much more efficiently, without necessarily having to do a ton of LazyTensor acrobatics. Importantly, this should make sampling and computing log probs much faster in that case (see the sketch after this list).
  3. We can easily scalarize multi-output posteriors (in fact, this is related to indexing)
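
A rough sketch of points 1 and 2 with plain tensors (no batch dims; names are illustrative, not the PR's API):

```python
# Rough sketch of points 1 and 2 with plain tensors; names are illustrative.
import torch
from torch.distributions import MultivariateNormal

n, t = 4, 3

# Point 1: with an unrolled n x t x n x t covariance, the cross-point covariance
# for task j is a single indexing call.
A = torch.randn(n * t, n * t)
covar = (A @ A.T + torch.eye(n * t)).reshape(n, t, n, t)
j = 1
cross_point_covar_j = covar[:, j, :, j]  # n x n

# Point 2: with independent outputs, storing the covariance as a t x n x n batch
# lets sampling / log_prob use a batched MVN directly, instead of building a
# block-diagonal nt x nt matrix.
B = torch.randn(t, n, n)
task_covars = B @ B.transpose(-1, -2) + torch.eye(n)  # t x n x n
mean = torch.zeros(n, t)
indep_mvn = MultivariateNormal(mean.T, task_covars)  # batch of t MVNs over n points
value = torch.randn(n, t)
log_prob = indep_mvn.log_prob(value.T).sum(-1)  # sum over the independent tasks
sample = indep_mvn.rsample().T  # back to n x t
```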

Currently, this is very early work, and so there are major limitations with this:

  1. This PR assumes proper tensors throughout (no support for LazyTensor yet)
  2. No caching of matrix decompositions (yet)
  3. The class subclasses MultivariateNormal (and hence the torch MultivariateNormal), but it does not ensure that the full interface is implemented; a lot of stuff will probably just break.
  4. No support for covariance_matrix and lazy_covariance_matrix interface. I am actually not sure that we'll need this though.

The notebook included in this PR demos the basic usage of this.

Balandat avatar Mar 23 '20 00:03 Balandat

Now as a next step - would it make sense to get this new MMVN implemented for non-lazy tensors, and focus on adding LTs later down the road?

I think for non-lazy tensors the functionality is essentially there; the rest is just some grunt work (especially caching the Cholesky factorizations / allowing for linear CG computations). So including support at least for KroneckerProductLazyTensor should be part of the next step.
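
For reference, a quick sketch of how a Kronecker-structured covariance K_data ⊗ K_task maps onto the unrolled representation (plain torch only; an actual implementation would presumably stay lazy via KroneckerProductLazyTensor rather than materializing anything):

```python
# Sketch of how a Kronecker covariance K_data ⊗ K_task maps onto the unrolled
# representation (plain torch; an actual implementation would presumably stay
# lazy via KroneckerProductLazyTensor instead of materializing anything).
import torch

n, t = 5, 3
Ld, Lt = torch.randn(n, n), torch.randn(t, t)
K_data = Ld @ Ld.T + torch.eye(n)  # n x n covariance across data points
K_task = Lt @ Lt.T + torch.eye(t)  # t x t covariance across tasks

# unrolled[i, j, k, l] = K_data[i, k] * K_task[j, l]
unrolled = torch.einsum('ik,jl->ijkl', K_data, K_task)  # n x t x n x t

# consistency check against the interleaved nt x nt Kronecker product
# (torch.kron is available in recent PyTorch releases)
assert torch.allclose(unrolled.reshape(n * t, n * t), torch.kron(K_data, K_task))
```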

Re the covariance_matrix and lazy_covariance_matrix APIs: Should we still allow these to be interleaved or non-interleaved? It seems that this doesn't really matter much in this new representation and we can just agree on a convention. Any preferences?

Balandat avatar Mar 27 '20 00:03 Balandat

@Balandat regarding order= naming, this might be a good opportunity to follow conventions set in PyTorch for named tensors? If we start building something like this in to LazyTensor, then we could just name the dimensions of the covariance matrix appropriately.

This would also give us the opportunity to name the dimensions of *MultivariateNormal's outputs (e.g., the dimensions of a sample).

jacobrgardner avatar Mar 27 '20 15:03 jacobrgardner

Good call @jacobrgardner. I was an early adopter of NamedTensors, but upstream development on them has been somewhat de-prioritized, so I didn't do much with them. But I think it makes a ton of sense here (and more generally for other stuff as well).

I'll make sure to comply with this convention. In fact, we can even just accept NamedTensors as input args as well, without requiring the order to be specified.
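
Purely illustrative (not actual API), but something along these lines, where the names on the input tensor carry the dimension order themselves:

```python
# Purely illustrative: the names on the input covariance tensor could carry the
# dimension order, so no separate order / shape argument is needed. (Names on a
# tensor must be unique, hence 'n2' / 't2' for the second occurrences.)
import torch

n, t = 3, 2
covar = torch.randn(n, t, n, t).refine_names('n1', 't1', 'n2', 't2')
print(covar.names)  # ('n1', 't1', 'n2', 't2') -- the constructor could read the order off this
```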

Balandat avatar Mar 27 '20 20:03 Balandat

Also a lot of stuff is currently not supported with NamedTensors (e.g. permute), so while we can align the naming interface, we can't really do much with them until there is better support for general operations.

Balandat avatar Mar 28 '20 01:03 Balandat

Added a lazy version for the task-independent case. Not sure what level of generality we want / need to support for the lazies, at some point things become really clunky. One thing I will definitely add is a Kronecker version.

Balandat avatar Apr 19 '20 18:04 Balandat

@gpleiss, @jacobrgardner, @dme65 I finally got back to this.

To recall, I drafted a MultioutputMultivariateNormal distribution (subclasses MultivariateNormal, but for now completely separate from MultitaskMultivariateNormal). This is an abstract API for a multi-output MVN of shape batch_shape x n x m, where n is the number of data points and m is the number of outputs. Internally, the representation can be whatever is most suitable to make computations / indexing more efficient / simpler.

The main limitation right now is that I haven't really hooked this up to the lazies. But at least for the existing functionality I came up with something that allows for quite a nice indexing API (this all works, see the notebook in the PR). Any thoughts on this?

Multi-output MVN subsetting API:

$ mtmvn
MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2)

# select batch <- this returns a non-batched MOMVN
$ mtmvn[0]
MultioutputMultivariateNormal(n=3, m=2)

# select range of batches <- this is a no-op
$ mtmvn[:]
MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2)

# select data point <- returns a batched MVN (single-output)
$ mtmvn[..., 0, :]
MultivariateNormal(batch_shape=(2,), n=2)

# select range of data points <- just subsets the data (n)
$ mtmvn[..., :2, :]
MultioutputMultivariateNormal(batch_shape=(2,), n=2, m=2)

# select single output <- returns a standard MVN
$ mtmvn[..., 0]
MultivariateNormal(batch_shape=(2,), n=3)

# select range of outputs <- just subsets the tasks (m)
$ mtmvn[..., :2]
MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2)

# mixed indexing: select specific data point and output from a specific batch
# This is a degenerate MVN with one data point
$ mtmvn[0, 1, 0]
MultivariateNormal(n=1)

# mixed indexing but respect batch dims
$ mtmvn[..., 1, 0]
MultivariateNormal(batch_shape=(2,), n=1)
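
For illustration, a very rough sketch (plain tensors, placeholder names rather than the actual PR code) of how a couple of these subsetting rules could be dispatched internally:

```python
# Very rough sketch of how a couple of the subsetting rules above could be
# dispatched internally, assuming an internal mean of shape batch_shape x n x m
# and an unrolled covariance of shape batch_shape x n x m x n x m (plain
# tensors; function and variable names are placeholders, not the PR code).
import torch
from torch.distributions import MultivariateNormal

def select_output(mean: torch.Tensor, covar: torch.Tensor, j: int) -> MultivariateNormal:
    """mtmvn[..., j]: a standard MVN over the n data points for output j."""
    return MultivariateNormal(mean[..., j], covar[..., :, j, :, j])

def select_point(mean: torch.Tensor, covar: torch.Tensor, i: int) -> MultivariateNormal:
    """mtmvn[..., i, :]: a standard MVN over the m outputs at data point i."""
    return MultivariateNormal(mean[..., i, :], covar[..., i, :, i, :])

# toy example matching the shapes in the demo: batch_shape=(2,), n=3, m=2
b, n, m = 2, 3, 2
mean = torch.zeros(b, n, m)
A = torch.randn(b, n * m, n * m)
covar = (A @ A.transpose(-1, -2) + torch.eye(n * m)).reshape(b, n, m, n, m)
select_output(mean, covar, 0)  # batch_shape (2,), event_shape (3,)
select_point(mean, covar, 1)   # batch_shape (2,), event_shape (2,)
```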

Balandat avatar Dec 20 '20 00:12 Balandat

@neerajprad you might be interested in this, too

Balandat avatar Dec 20 '20 01:12 Balandat