Support MatrixNormal distribution
Hi all, the MatrixNormal distribution provides a clean way to parameterize the distribution of n x k matrices X ~ MN_{nk}(M, U, V) in terms of its n x k mean M, its n x n row-wise covariance U, and its k x k column-wise covariance V.
This can be modeled equivalently as a classical multivariate normal by applying the vec operation to X and taking the Kronecker product of U and V, giving vec(X) ~ N(vec(M), V kron U). However, directly evaluating this likelihood ignores the computational advantages offered by the MatrixNormal formulation: computing the inverse of V kron U costs O((nk)^3), compared with O(n^3 + k^3) when the row and column covariances are handled separately.
It would be great to see this implemented in numpyro to exploit the covariance structure for efficient likelihood calculation. There are some corner cases that devolve into batched (or transposed batched) MVNs (i.e. when either U or V is the identity), but in its full generality it could be useful (e.g. for certain factor analysis methods or multivariate outcome models).
Implementation should be straightforward (I've got a non-batched version worked out), but I could see batching complicating the implementation details.
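For concreteness, here is a minimal, unbatched sketch of the kind of likelihood computation the factorization allows (the function name and signature are illustrative, not existing numpyro API, and this is not the code from this thread): two small Cholesky factorizations plus triangular solves replace forming the (nk) x (nk) Kronecker covariance.
import jax.numpy as jnp
import jax.scipy as jsp

def matrix_normal_logpdf(x, mean, row_cov, col_cov):
    # Unbatched sketch: log-density of X ~ MN(mean, row_cov, col_cov).
    n, k = mean.shape
    chol_row = jnp.linalg.cholesky(row_cov)  # n x n, lower triangular
    chol_col = jnp.linalg.cholesky(col_cov)  # k x k, lower triangular
    diff = x - mean
    # Mahalanobis term tr[V^-1 (X-M)' U^-1 (X-M)] via two triangular solves,
    # costing O(n^2 k + n k^2) instead of the O((nk)^3) of the full Kronecker inverse.
    a = jsp.linalg.solve_triangular(chol_row, diff, lower=True)  # n x k
    b = jsp.linalg.solve_triangular(chol_col, a.T, lower=True)   # k x n
    mahalanobis = jnp.sum(b**2)
    # log|V kron U| = k * log|U| + n * log|V|, read off the Cholesky diagonals
    half_log_det = k * jnp.sum(jnp.log(jnp.diag(chol_row))) + n * jnp.sum(
        jnp.log(jnp.diag(chol_col))
    )
    return -0.5 * (mahalanobis + n * k * jnp.log(2 * jnp.pi)) - half_log_det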
Hi @quattro, this is a nice distribution to have. Please don't worry much about batching stuff. We can iterate using PR comments.
Any progress on this?
I was able to implement a basic example. (MN=MatrixNormal, MVN=MultivariateNormal)
Sampling:
I've used the fact that if X is an (n x p) matrix with x_{ij} ~ N(0,1) and Y = loc + U @ X @ V', then Y ~ MN(loc, U @ U', V @ V').
Assume we have sample_shape=(s,), batch_shape=(). Then we can sample X from a Normal distribution with sample_shape = (s,) + event_shape, where event_shape = (n, p).
We then have to map the matrix multiplication along the (single) first axis, which can be achieved using vmap, e.g.
X = dist.Normal(0, 1).sample(rng, sample_shape=(s,) + event_shape)
Y = loc + jax.vmap(lambda x: U @ x @ V.T)(X)
I was able to come up with an extension for sample_shape=(s,) and batch_shape=(b,) by first using vmap to map the matrix multiplication along the (single) batch axis and then mapping this along the (single) sampling axis.
To my surprise this also worked for batch_shape=(b1, b2), but I cannot explain why.
Any suggestions on how I could implement arbitrary batch and sample shapes without chaining vmap for each of the dimensions?
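One possible simplification (a sketch, not tested against the code below): jnp.matmul, and hence the @ operator, already broadcasts over arbitrary leading dimensions. That is presumably also why the multi-dimensional batch case above happens to work, since the inner matmul broadcasts over the remaining batch axes. If the Cholesky factors keep their batch_shape + (n, n) and batch_shape + (p, p) shapes, no chained vmap is needed at all:
from jax import random

def sample_broadcast(rng_key, loc, tril_V, triu_U, sample_shape=()):
    # Hypothetical sketch: loc has shape batch_shape + (n, p),
    # tril_V has shape batch_shape + (n, n), triu_U has shape batch_shape + (p, p).
    event_shape = loc.shape[-2:]
    batch_shape = loc.shape[:-2]
    X = random.normal(rng_key, shape=sample_shape + batch_shape + event_shape)
    # matmul broadcasts the scale factors against X's extra leading sample dimensions.
    return loc + tril_V @ X @ triu_U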
log_prob:
Here I've used the above-mentioned relation vec(X) ~ MVN(vec(loc), V kron U) if X ~ MN(loc, U, V).
In addition I've used (L_U @ L_U') kron (L_V @ L_V') = (L_U kron L_V) @ (L_U' kron L_V') (a standard Kronecker product identity) to parameterize the MVN by its lower-triangular Cholesky factor.
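A quick numerical sanity check of that identity on a small made-up example:
import jax.numpy as jnp

L_U = jnp.array([[1.0, 0.0], [0.5, 2.0]])
L_V = jnp.array([[3.0, 0.0], [1.0, 0.5]])
lhs = jnp.kron(L_U @ L_U.T, L_V @ L_V.T)
rhs = jnp.kron(L_U, L_V) @ jnp.kron(L_U, L_V).T
assert jnp.allclose(lhs, rhs)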
Sampling
# Case 2: sample_shape=(s,), batch_shape=(b,)
def sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=()):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    batch_shape = loc.shape[:-2]
    assert loc.ndim == scale_columns.ndim == scale_rows.ndim
    triu_U = jsp.linalg.cholesky(scale_columns, lower=False)
    assert triu_U.shape == batch_shape + (p, p)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_V.shape == batch_shape + (n, n)
    X = random.normal(rng_key, shape=sample_shape + batch_shape + event_shape)
    assert X.shape == sample_shape + batch_shape + event_shape
    # X.shape = (n, p) & x_{i,j} ~ N(0,1) -> Y ~ MN(loc, tril_V @ tril_V', triu_U' @ triu_U)
    # with Y = loc + tril_V @ X @ triu_U (https://en.wikipedia.org/wiki/Matrix_normal_distribution)
    # Step 1: use vmap to map the matrix multiplication along the batch and then along the sample dims
    batch_map = jax.vmap(lambda x, y, z: x @ y @ z, in_axes=0, out_axes=0)
    sample_map = jax.vmap(lambda x: batch_map(tril_V, x, triu_U))
    Y = sample_map(X)
    assert Y.shape == sample_shape + batch_shape + event_shape
    return Y
log_prob
def log_prob1(x, loc, scale_rows, scale_columns):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    k = n * p
    batch_shape = loc.shape[:-2]
    sample_shape = x.shape[: x.ndim - loc.ndim]
    new_shape = (-1,) + batch_shape + (k,)
    loc_mvn = loc.reshape(new_shape)
    tril_U = jsp.linalg.cholesky(scale_columns, lower=True)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_U.shape == batch_shape + (p, p)
    assert tril_V.shape == batch_shape + (n, n)
    # If (x) is the Kronecker product then it holds that
    # (A (x) B) @ (C (x) D) = (A @ C) (x) (B @ D)
    # see (KRON 13 - https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf)
    # Note: reshape flattens row-major, i.e. it yields vec(X'), whose covariance is
    # scale_rows (x) scale_columns, so its Cholesky factor is tril_V (x) tril_U.
    tril_scale = jax.vmap(lambda x, y: jnp.kron(x, y))(tril_V, tril_U)
    assert tril_scale.shape == batch_shape + (k, k)
    # mvn = dist.MultivariateNormal(loc=loc_mvn, scale_tril=tril_scale)
    mvn = dist.MultivariateNormal(
        loc=jnp.squeeze(loc_mvn), scale_tril=jnp.squeeze(tril_scale)
    )
    assert mvn.event_shape == (k,)
    assert mvn.batch_shape == batch_shape
    log_prob_ = mvn.log_prob(x.reshape(new_shape))
    assert log_prob_.shape == sample_shape + batch_shape
    return log_prob_
Full Code
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpyro.distributions as dist
from jax import random
from numpyro.distributions import Distribution
def sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=()):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    batch_shape = loc.shape[:-2]
    assert loc.ndim == scale_columns.ndim == scale_rows.ndim
    triu_U = jsp.linalg.cholesky(scale_columns, lower=False)
    assert triu_U.shape == batch_shape + (p, p)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_V.shape == batch_shape + (n, n)
    X = random.normal(rng_key, shape=sample_shape + batch_shape + event_shape)
    assert X.shape == sample_shape + batch_shape + event_shape
    # X.shape = (n, p) & x_{i,j} ~ N(0,1) -> Y ~ MN(loc, tril_V @ tril_V', triu_U' @ triu_U)
    # with Y = loc + tril_V @ X @ triu_U (https://en.wikipedia.org/wiki/Matrix_normal_distribution)
    # X.ndim=4 > tril_V.ndim=triu_U.ndim=3
    # Step 1: use vmap to map the matrix multiplication along the batch and then along the sample dims
    batch_map = jax.vmap(lambda x, y, z: x @ y @ z, in_axes=0, out_axes=0)
    # assert batch_map(tril_V, X[0, ...], triu_U).shape == batch_shape + event_shape
    sample_map = jax.vmap(lambda x: batch_map(tril_V, x, triu_U))
    Y = sample_map(X)
    assert Y.shape == sample_shape + batch_shape + event_shape
    return Y
# Test 2:
sample_shape = (10,)
batch_shape = (2,) # type: ignore
event_shape = (2, 3)
rng_key = random.PRNGKey(435)
loc = jnp.arange(6).reshape(event_shape) * jnp.ones(batch_shape + (1, 1))
assert loc.shape == batch_shape + event_shape
tril_U = jnp.array([[1.0, 0, 0], [4.0, 1.0, 0], [0.4, 2.25, 1.0]])
scale_columns = jnp.matmul(tril_U, tril_U.T) * jnp.ones(batch_shape + (1, 1))
assert scale_columns.shape == batch_shape + (event_shape[1], event_shape[1])
tril_V = jnp.array([[4.0, 0.0], [1, 0.25]])
scale_rows = jnp.matmul(tril_V, tril_V.T) * jnp.ones(batch_shape + (1, 1))
assert scale_rows.shape == batch_shape + (event_shape[0], event_shape[0])
Y = sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=sample_shape)
assert Y.shape == sample_shape + batch_shape + event_shape
# Test 3:
sample_shape = (10,) # type: ignore
batch_shape = (2, 5, 2) # type: ignore
event_shape = (2, 3)
rng_key = random.PRNGKey(435)
loc = jnp.arange(6).reshape(event_shape) * jnp.ones(batch_shape + (1, 1))
assert loc.shape == batch_shape + event_shape
tril_U = jnp.array([[1.0, 0, 0], [4.0, 1.0, 0], [0.4, 2.25, 1.0]])
scale_columns = jnp.matmul(tril_U, tril_U.T) * jnp.ones(batch_shape + (1, 1))
assert scale_columns.shape == batch_shape + (event_shape[1], event_shape[1])
tril_V = jnp.array([[4.0, 0.0], [1, 0.25]])
scale_rows = jnp.matmul(tril_V, tril_V.T) * jnp.ones(batch_shape + (1, 1))
assert scale_rows.shape == batch_shape + (event_shape[0], event_shape[0])
Y = sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=sample_shape)
assert Y.shape == sample_shape + batch_shape + event_shape
sample_shape = (1_000,)
Y = sample2(
    jax.random.PRNGKey(125), loc, scale_rows, scale_columns, sample_shape=sample_shape
)
assert Y.shape == sample_shape + batch_shape + event_shape
# log_prob for sample_shape=(s,), batch_shape=() or (b,)
def log_prob1(x, loc, scale_rows, scale_columns):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    k = n * p
    batch_shape = loc.shape[:-2]
    sample_shape = x.shape[: x.ndim - loc.ndim]
    new_shape = (-1,) + batch_shape + (k,)
    loc_mvn = loc.reshape(new_shape)
    tril_U = jsp.linalg.cholesky(scale_columns, lower=True)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_U.shape == batch_shape + (p, p)
    assert tril_V.shape == batch_shape + (n, n)
    # If (x) is the Kronecker product then it holds that
    # (A (x) B) @ (C (x) D) = (A @ C) (x) (B @ D)
    # see (KRON 13 - https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf)
    # Note: reshape flattens row-major, i.e. it yields vec(X'), whose covariance is
    # scale_rows (x) scale_columns, so its Cholesky factor is tril_V (x) tril_U.
    if batch_shape == ():
        tril_scale = jnp.kron(tril_V, tril_U)
    else:
        tril_scale = jax.vmap(lambda x, y: jnp.kron(x, y))(tril_V, tril_U)
    assert tril_scale.shape == batch_shape + (k, k)
    # mvn = dist.MultivariateNormal(loc=loc_mvn, scale_tril=tril_scale)
    mvn = dist.MultivariateNormal(
        loc=jnp.squeeze(loc_mvn), scale_tril=jnp.squeeze(tril_scale)
    )
    assert mvn.event_shape == (k,)
    assert mvn.batch_shape == batch_shape
    log_prob_ = mvn.log_prob(x.reshape(new_shape))
    assert log_prob_.shape == sample_shape + batch_shape
    return log_prob_
# Test 4:
sample_shape = (10,)
batch_shape = (2,) # type: ignore
event_shape = (2, 3)
rng_key = random.PRNGKey(435)
loc = jnp.arange(6).reshape(event_shape) * jnp.ones(batch_shape + (1, 1))
assert loc.shape == batch_shape + event_shape
tril_U = jnp.array([[1.0, 0, 0], [4.0, 1.0, 0], [0.4, 2.25, 1.0]])
scale_columns = jnp.matmul(tril_U, tril_U.T) * jnp.ones(batch_shape + (1, 1))
assert scale_columns.shape == batch_shape + (event_shape[1], event_shape[1])
tril_V = jnp.array([[4.0, 0.0], [1, 0.25]])
scale_rows = jnp.matmul(tril_V, tril_V.T) * jnp.ones(batch_shape + (1, 1))
assert scale_rows.shape == batch_shape + (event_shape[0], event_shape[0])
Y = sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=sample_shape)
assert Y.shape == sample_shape + batch_shape + event_shape
x_log_prob = log_prob1(Y, loc, scale_rows, scale_columns)
assert x_log_prob.shape == sample_shape + batch_shape
@kaijennissen, I apologize, but I do not have enough time to finish out my initial code. I did take a peek at your definition, and it seems to me that your code, while correct, misses out on the speed savings offered by the Kronecker structure. There is no need to blow the definition up into the large multivariate normal when the inverse of the Kronecker product can be computed in a computationally cheaper way.
@quattro You are right. The current implementation is a simple translation of the MatrixNormal used in TensorFlow Probability. While an efficient version is straightforward without batching, I'm struggling with a batched version. Hopefully I will have some time in the near future to work on this.
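One possible route to a batched version (just a sketch, assuming an unbatched log-density such as the hypothetical matrix_normal_logpdf sketched earlier in this thread) is to let jnp.vectorize map it over arbitrary leading batch dimensions:
# Hypothetical: broadcast an unbatched log-density over leading batch/sample dims.
batched_logpdf = jnp.vectorize(
    matrix_normal_logpdf, signature="(n,k),(n,k),(n,n),(k,k)->()"
)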
@fehiepsi could this issue be closed now?
Yes, thanks for the ping.
Thanks for adding this distribution, @kaijennissen!