
Support MatrixNormal distribution


Hi all, the MatrixNormal provides a clean way to parameterize the distribution of n x k matrices X ~ MN_{nk}(M, U, V) as a function of its n x k mean M, n x n row-wise covariance U and k x k column-wise covariance V.

This can be modeled equivalently as a classical multivariate normal by applying the vec operation to X and taking the Kronecker product of V and U, giving vec(X) ~ N(vec(M), V kron U). However, directly evaluating this likelihood ignores the computational advantages offered by the MatrixNormal formulation: computing the inverse of V kron U requires O((nk)^3) operations, compared with O(n^3 + k^3) for the MatrixNormal formulation.
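
For reference, a minimal unbatched sketch (not numpyro API, just illustrative) of the O(n^3 + k^3) evaluation, assuming X and M are n x k and U, V are passed as dense covariance matrices:

import jax.numpy as jnp
import jax.scipy as jsp

def matrix_normal_log_prob(X, M, U, V):
    # X, M: (n, k); U: (n, n) row covariance; V: (k, k) column covariance
    n, k = M.shape
    L_U = jnp.linalg.cholesky(U)
    L_V = jnp.linalg.cholesky(V)
    A = jsp.linalg.solve_triangular(L_U, X - M, lower=True)  # L_U^{-1} (X - M)
    B = jsp.linalg.solve_triangular(L_V, A.T, lower=True)    # L_V^{-1} (X - M)' L_U^{-T}
    quad = jnp.sum(B ** 2)  # = tr[V^{-1} (X - M)' U^{-1} (X - M)]
    half_logdet_U = jnp.sum(jnp.log(jnp.diag(L_U)))
    half_logdet_V = jnp.sum(jnp.log(jnp.diag(L_V)))
    return (
        -0.5 * (n * k * jnp.log(2 * jnp.pi) + quad)
        - k * half_logdet_U
        - n * half_logdet_V
    )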

It would be great to see this implemented in numpyro to exploit its variance/covariance structure for efficient likelihood calculation. There are some corner cases that devolve into batched (or transposed batched) MVNs (i.e., when either U or V is the identity), but in its full generality it could be useful (e.g., for certain factor analysis methods, or multivariate outcome models).

Implementation should be straightforward (I've got a non-batched version worked out), but I could see batching complicating the implementation details.

quattro avatar Oct 07 '21 20:10 quattro

Hi @quattro, this is a nice distribution to have. Please don't worry much about batching stuff. We can iterate using PR comments.

fehiepsi avatar Oct 08 '21 00:10 fehiepsi

Any progress on this?

kaijennissen avatar Feb 04 '22 13:02 kaijennissen

I was able to implement a basic example. (MN=MatrixNormal, MVN=MultivariateNormal)

Sampling: I've used the fact that if X is an (n x p) matrix with x_(ij) ~ N(0,1) and Y = loc + U @ X @ V', then Y ~ MN(loc, U@U', V@V'). Assume we have sample_shape=(s,), batch_shape=(). Then we can sample X from a Normal distribution with sample_shape = (s,) + event_shape where event_shape = (n, p). We then have to map the matrix multiplication along the (single) first axis, which can be achieved using vmap, e.g.

X = dist.Normal(0, 1).sample(rng, sample_shape=(s,) + event_shape)
Y = loc + jax.vmap(lambda x: U @ x @ V.T)(X)

I was able to come up with an extension for sample_shape=(s,) and batch_shape=(b,) by first using vmap to map the matrix multiplication along the (single) batch axis and then mapping this along the (single) sampling axis. To my surprise this also worked for batch_shape=(b1, b2), but I cannot explain why. Any suggestions on how I could implement arbitrary batch and sample shapes without chaining vmap for each of the dimensions?
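
(As an aside, a likely explanation, sketched below with made-up shapes: @ / jnp.matmul already broadcast over leading dimensions, so once one axis is handled by vmap, any remaining batch axes are handled by broadcasting.)

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
V_ = random.normal(key, (4, 2, 2))     # batched row-scale factors, batch_shape=(4,)
X_ = random.normal(key, (7, 4, 2, 3))  # sample_shape=(7,), batch_shape=(4,), event_shape=(2, 3)
U_ = random.normal(key, (4, 3, 3))     # batched column-scale factors
Y_ = V_ @ X_ @ U_                      # leading dims broadcast
assert Y_.shape == (7, 4, 2, 3)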

log_prob: Here I've used the above-mentioned relation vec(X) ~ MVN(vec(loc), V kron U) if X ~ MN(loc, U, V). In addition I've used (L_U @ L_U') kron (L_V @ L_V') = (L_U kron L_V)(L_U' kron L_V') (the mixed-product property of the Kronecker product, KRON 13 in https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf) to parametrize the MVN by its lower-triangular Cholesky factor.
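
A quick numerical check of that identity (a standalone sketch with made-up factors, separate from the implementation below):

import jax.numpy as jnp

L_U = jnp.array([[1.0, 0.0], [0.5, 2.0]])   # lower Cholesky factor of U
L_V = jnp.array([[3.0, 0.0], [1.0, 0.25]])  # lower Cholesky factor of V
U = L_U @ L_U.T
V = L_V @ L_V.T
# kron(L_U, L_V) is lower triangular and is a Cholesky factor of kron(U, V)
tril = jnp.kron(L_U, L_V)
assert jnp.allclose(tril @ tril.T, jnp.kron(U, V))
assert jnp.allclose(tril, jnp.tril(tril))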

Sampling
# Case 2: sample_shape=(s,), batch_shape=(b,)
def sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=()):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    batch_shape = loc.shape[:-2]
    assert loc.ndim == scale_columns.ndim == scale_rows.ndim

    triu_U = jsp.linalg.cholesky(scale_columns, lower=False)
    assert triu_U.shape == batch_shape + (p, p)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_V.shape == batch_shape + (n, n)

    X = random.normal(rng_key, shape=sample_shape + batch_shape + event_shape)
    assert X.shape == sample_shape + batch_shape + event_shape

    # X.shape = (n, p) & x_{i,j} ~ N(0,1) -> Y ~ MN(loc, tril_V@tril_V', triu_U'@triu_U)
    # with Y = loc + tril_V @ X @ triu_U (https://en.wikipedia.org/wiki/Matrix_normal_distribution)
    # Step 1: use vmap to map the matrix multiplication along the batch and then along the sample dims
    batch_map = jax.vmap(lambda x, y, z: x @ y @ z, in_axes=0, out_axes=0)
    sample_map = jax.vmap(lambda x: batch_map(tril_V, x, triu_U))
    Y = sample_map(X)
    assert Y.shape == sample_shape + batch_shape + event_shape

    return Y
log_prob
def log_prob1(x, loc, scale_rows, scale_columns):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    k = n * p
    batch_shape = loc.shape[:-2]

    new_shape = (-1,) + batch_shape + (k,)
    loc_mvn = loc.reshape(new_shape)

    tril_U = jsp.linalg.cholesky(scale_columns, lower=True)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_U.shape == batch_shape + (p, p)
    assert tril_V.shape == batch_shape + (n, n)

    # If (x) is the Kronecker-Product then it holds that
    # (A (x) B) @ (C (x) D) =(A @ C) (x) (B @ D)
    # see (KRON 13 - https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf)
    tril_scale = jax.vmap(lambda x, y: jnp.kron(x, y))(tril_U, tril_V)
    assert tril_scale.shape == batch_shape + (k, k)
    # mvn = dist.MultivariateNormal(loc=loc_mvn, scale_tril=tril_scale)
    mvn = dist.MultivariateNormal(
        loc=jnp.squeeze(loc_mvn), scale_tril=jnp.squeeze(tril_scale)
    )
    assert mvn.event_shape == (k,)
    assert mvn.batch_shape == batch_shape
    log_prob_ = mvn.log_prob(x.reshape(new_shape))
    assert log_prob_.shape == x.reshape(new_shape).shape[:-1]  # i.e. sample_shape + batch_shape
    return log_prob_

Full Code
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpyro.distributions as dist
from jax import random
from numpyro.distributions import Distribution


def sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=()):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    batch_shape = loc.shape[:-2]
    assert loc.ndim == scale_columns.ndim == scale_rows.ndim

    triu_U = jsp.linalg.cholesky(scale_columns, lower=False)
    assert triu_U.shape == batch_shape + (p, p)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_V.shape == batch_shape + (n, n)

    X = random.normal(rng_key, shape=sample_shape + batch_shape + event_shape)
    assert X.shape == sample_shape + batch_shape + event_shape

    # X.shape = (n,p) & x_{i,j} ~ N(0,1) -> Y ~ MN(loc,tril_V@tril_V' , triu_U'@triu_U)
    # with Y = loc + tril_V @ X @ triu_U (https://en.wikipedia.org/wiki/Matrix_normal_distribution)
    # X.ndim=4 > tril_U.ndim=triu_U.ndim=3
    # Step 1: use vmap to map the matrix multiplication along the batch and then along the sample dims
    batch_map = jax.vmap(lambda x, y, z: x @ y @ z, in_axes=0, out_axes=0)
    # assert batch_map(tril_V, X[0, ...], triu_U).shape == batch_shape + event_shape
    sample_map = jax.vmap(lambda x: batch_map(tril_V, x, triu_U))
    Y = sample_map(X)
    assert Y.shape == sample_shape + batch_shape + event_shape

    return Y


# Test 2:
sample_shape = (10,)
batch_shape = (2,)  # type: ignore
event_shape = (2, 3)

rng_key = random.PRNGKey(435)


loc = jnp.arange(6).reshape(event_shape) * jnp.ones(batch_shape + (1, 1))
assert loc.shape == batch_shape + event_shape

tril_U = jnp.array([[1.0, 0, 0], [4.0, 1.0, 0], [0.4, 2.25, 1.0]])
scale_columns = jnp.matmul(tril_U, tril_U.T) * jnp.ones(batch_shape + (1, 1))
assert scale_columns.shape == batch_shape + (event_shape[1], event_shape[1])

tril_V = jnp.array([[4.0, 0.0], [1, 0.25]])
scale_rows = jnp.matmul(tril_V, tril_V.T) * jnp.ones(batch_shape + (1, 1))
assert scale_rows.shape == batch_shape + (event_shape[0], event_shape[0])

Y = sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=sample_shape)
assert Y.shape == sample_shape + batch_shape + event_shape

# Test 3:
sample_shape = (10,)  # type: ignore
batch_shape = (2, 5, 2)  # type: ignore
event_shape = (2, 3)

rng_key = random.PRNGKey(435)


loc = jnp.arange(6).reshape(event_shape) * jnp.ones(batch_shape + (1, 1))
assert loc.shape == batch_shape + event_shape

tril_U = jnp.array([[1.0, 0, 0], [4.0, 1.0, 0], [0.4, 2.25, 1.0]])
scale_columns = jnp.matmul(tril_U, tril_U.T) * jnp.ones(batch_shape + (1, 1))
assert scale_columns.shape == batch_shape + (event_shape[1], event_shape[1])

tril_V = jnp.array([[4.0, 0.0], [1, 0.25]])
scale_rows = jnp.matmul(tril_V, tril_V.T) * jnp.ones(batch_shape + (1, 1))
assert scale_rows.shape == batch_shape + (event_shape[0], event_shape[0])

Y = sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=sample_shape)
assert Y.shape == sample_shape + batch_shape + event_shape


sample_shape = (1_000,)
Y = sample2(
    jax.random.PRNGKey(125), loc, scale_rows, scale_columns, sample_shape=sample_shape
)
assert Y.shape == sample_shape + batch_shape + event_shape


# sample_shape=(s,), batch_shape=() or (b,)
def log_prob1(x, loc, scale_rows, scale_columns):
    event_shape = loc.shape[-2:]
    n, p = event_shape
    k = n * p
    batch_shape = loc.shape[:-2]
    new_shape = (-1,) + batch_shape + (k,)
    loc_mvn = loc.reshape(new_shape)

    tril_U = jsp.linalg.cholesky(scale_columns, lower=True)
    tril_V = jsp.linalg.cholesky(scale_rows, lower=True)
    assert tril_U.shape == batch_shape + (p, p)
    assert tril_V.shape == batch_shape + (n, n)

    # If (x) is the Kronecker-Product then it holds that
    # (A (x) B) @ (C (x) D) =(A @ C) (x) (B @ D)
    # see (KRON 13 - https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf)

    if batch_shape == ():
        tril_scale = jnp.kron(tril_U, tril_V)
    else:
        tril_scale = jax.vmap(lambda x, y: jnp.kron(x, y))(tril_U, tril_V)
    assert tril_scale.shape == batch_shape + (k, k)
    # mvn = dist.MultivariateNormal(loc=loc_mvn, scale_tril=tril_scale)
    mvn = dist.MultivariateNormal(
        loc=jnp.squeeze(loc_mvn), scale_tril=jnp.squeeze(tril_scale)
    )
    assert mvn.event_shape == (k,)
    assert mvn.batch_shape == batch_shape
    log_prob_ = mvn.log_prob(x.reshape(new_shape))
    assert log_prob_.shape == x.reshape(new_shape).shape[:-1]  # i.e. sample_shape + batch_shape
    return log_prob_


# Test 4:
sample_shape = (10,)
batch_shape = (2,)  # type: ignore
event_shape = (2, 3)

rng_key = random.PRNGKey(435)


loc = jnp.arange(6).reshape(event_shape) * jnp.ones(batch_shape + (1, 1))
assert loc.shape == batch_shape + event_shape

tril_U = jnp.array([[1.0, 0, 0], [4.0, 1.0, 0], [0.4, 2.25, 1.0]])
scale_columns = jnp.matmul(tril_U, tril_U.T) * jnp.ones(batch_shape + (1, 1))
assert scale_columns.shape == batch_shape + (event_shape[1], event_shape[1])

tril_V = jnp.array([[4.0, 0.0], [1, 0.25]])
scale_rows = jnp.matmul(tril_V, tril_V.T) * jnp.ones(batch_shape + (1, 1))
assert scale_rows.shape == batch_shape + (event_shape[0], event_shape[0])

Y = sample2(rng_key, loc, scale_rows, scale_columns, sample_shape=sample_shape)
assert Y.shape == sample_shape + batch_shape + event_shape

x_log_prob = log_prob1(Y, loc, scale_rows, scale_columns)
assert x_log_prob.shape == sample_shape + batch_shape
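
As a rough additional check (my own sketch on top of the Test-4 setup above): for a MatrixNormal with row covariance U and column covariance V, E[X] = M and E[(X - M)(X - M)'] = tr(V) * U, so the empirical moments of sample2 draws should match.

draws = sample2(
    random.PRNGKey(0), loc, scale_rows, scale_columns, sample_shape=(50_000,)
)
resid = draws - loc
emp_row_cov = jnp.mean(resid @ jnp.swapaxes(resid, -1, -2), axis=0)
expected_row_cov = (
    jnp.trace(scale_columns, axis1=-2, axis2=-1)[..., None, None] * scale_rows
)
assert jnp.allclose(jnp.mean(draws, axis=0), loc, atol=0.5)
assert jnp.allclose(emp_row_cov, expected_row_cov, rtol=0.1)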

kaijennissen avatar Feb 06 '22 19:02 kaijennissen

@kaijennissen, I apologize, but I do not have enough time to finish out my initial code. I did take a peek at your definition, and it seems to me that your code, while correct, is missing out on the speed savings offered by the Kronecker structure. There is no need to blow the problem up into the large multivariate normal when the inverse of the Kronecker product can be computed in a (computationally) smaller way.
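
(For reference, a small sketch of the identities behind this, with toy matrices: the inverse and log-determinant of a Kronecker product factorise, so the (nk x nk) covariance never needs to be formed or inverted explicitly.)

import jax.numpy as jnp

U = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # n x n row covariance (toy values)
V = jnp.array([[1.0, 0.3], [0.3, 2.0]])  # k x k column covariance (toy values)
n, k = U.shape[0], V.shape[0]

# (V kron U)^{-1} = V^{-1} kron U^{-1}
assert jnp.allclose(
    jnp.linalg.inv(jnp.kron(V, U)),
    jnp.kron(jnp.linalg.inv(V), jnp.linalg.inv(U)),
    atol=1e-5,
)
# log|V kron U| = n * log|V| + k * log|U|
assert jnp.allclose(
    jnp.linalg.slogdet(jnp.kron(V, U))[1],
    n * jnp.linalg.slogdet(V)[1] + k * jnp.linalg.slogdet(U)[1],
)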

quattro avatar Oct 03 '22 21:10 quattro

@quattro You are right. The current implementation is a simple translation of the MatrixNormal used in TensorFlow Probability. While an efficient version is straightforward without batching, I'm struggling with a batched version. Hopefully I will have some time in the near future to work on this.

kaijennissen avatar Oct 09 '22 15:10 kaijennissen

@fehiepsi could this issue be closed now?

cgarciga avatar Jan 30 '23 13:01 cgarciga

Yes, thanks for the ping.

Thanks for adding this distribution, @kaijennissen!

fehiepsi avatar Jan 30 '23 14:01 fehiepsi