
Dimensions mismatch error when using `fit_surrogate_posterior()` with `sample_size` > 1

Open · ImScientist opened this issue · 1 comment

I am fitting a `LinearRegression` model with 2 trainable weights and 3 trajectories. There is no problem when I use variational inference to solve the problem independently for the 3 trajectories, but when I use a single pair of weights broadcast to all 3 trajectories, I get an error from `fit_surrogate_posterior(sample_size=50, ...)`:

ValueError: Dimensions must be equal, but are 3 and 50 for '{{node monte_carlo_variational_loss/expectation/make_component_state_space_models/design_matrix_linop/matmul/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](monte_carlo_variational_loss/expectation/make_component_state_space_models/design_matrix_linop/matmul/MatMul/a, monte_carlo_variational_loss/expectation/make_component_state_space_models/strided_slice)' with input shapes: [3,5,2], [50,2,1].
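The collision is visible in the shapes alone: `matmul` broadcasting cannot reconcile a leading batch dimension of 3 (from the design matrix) with one of 50 (the Monte Carlo sample dimension). A minimal NumPy sketch, with the shapes taken from the error message above:

```python
import numpy as np

design = np.zeros((3, 5, 2))    # [trajectories, timesteps, features], as in the error
weights = np.zeros((50, 2, 1))  # [sample_size, features, 1]

# The leading batch dims 3 and 50 are incompatible under broadcasting
# rules, so this matmul raises, mirroring the TF error above.
try:
    np.matmul(design, weights)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```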

Code demonstrating both cases is below:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfd = tfp.distributions
tfb = tfp.bijectors
sts = tfp.sts
dtype = tf.float32

# (3,5,2)
xs = np.array([[[0., 0.], [1., 1.], [2., 1.41], [3., 1.73], [4., 2.]],
               [[0., 0.], [1., 1.], [2., 1.41], [3., 1.73], [4., 2.]],
               [[0., 0.], [1., 1.], [2., 1.41], [3., 1.73], [4., 2.]]])

# (3,5)
ys = np.array([[2.62, 0.17, 0.82, -1.42, 3.59],
               [0.71, -0.42, 1.92, 0.01, 4.73],
               [-1.07, 1.73, 1.48, 3.24, 5.62]])

single_solution = False  # solve 3 independent problems
# single_solution = True  # single solution

if single_solution:
    # batch_shape = []
    weights_prior = tfd.Independent(distribution=tfd.Normal(loc=[0, 0], scale=5.),
                                    reinterpreted_batch_ndims=1,
                                    name='weights_prior'
                                    )
else:
    # batch_shape = [3]
    weights_prior = tfd.Independent(distribution=tfd.Normal(loc=[[0, 0],
                                                                 [0, 0],
                                                                 [0, 0]], scale=5.),
                                    reinterpreted_batch_ndims=1,
                                    name='weights_prior'
                                    )

# batch_shape = [] or [3]
lin_reg = sts.LinearRegression(
    design_matrix=tf.cast(xs, dtype=dtype),
    weights_prior=weights_prior,
    name='lin_reg')

# batch_shape = [] or [3]
model = sts.Sum(
    components=[lin_reg],
    constant_offset=tf.constant(0., dtype=dtype),
    observation_noise_scale_prior=tfd.LogNormal(loc=1, scale=2.),
    name='time_series')

variational_posteriors = tfp.sts.build_factored_surrogate_posterior(model=model)

sample_size = 50
num_variational_steps = 200
optimizer = tf.optimizers.Adam(learning_rate=.01)

observed_time_series = tf.cast(ys, dtype=dtype)


@tf.function(experimental_compile=True)
def train():
    elbo_loss_curve = tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=model.joint_log_prob(observed_time_series=observed_time_series),
        surrogate_posterior=variational_posteriors,
        optimizer=optimizer,
        sample_size=sample_size,
        num_steps=num_variational_steps)
    return elbo_loss_curve


elbo_loss_curve = train()

plt.figure(figsize=(7, 5))
plt.plot(elbo_loss_curve)
plt.yscale('log')
plt.show()

ImScientist · Jan 16 '22

It looks like when the Monte Carlo ELBO is calculated, the sample dimension of the weights surrogate posterior collides with the batch dimension of the data. A workaround for `single_solution` is to define `weights_prior` with the same batch rank as `design_matrix`, i.e. use `weights_prior = tfd.Independent(tfd.Normal(loc=[[0, 0]], ...` so that its batch shape is `[1]`.
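In shape terms, giving the prior batch shape `[1]` means each batch of sampled weights has shape `[50, 1, 2]`; after the trailing column dimension is appended for the matmul, the batch dims `[3]` and `[50, 1]` broadcast cleanly instead of colliding. A NumPy sketch of that arithmetic (shapes assumed from the error message above):

```python
import numpy as np

design = np.zeros((3, 5, 2))       # [trajectories, timesteps, features]
weights = np.zeros((50, 1, 2, 1))  # [sample_size, prior batch of 1, features, 1]

# Batch dims [3] and [50, 1] broadcast to [50, 3], so the matmul succeeds.
out = np.matmul(design, weights)
print(out.shape)  # (50, 3, 5, 1)
```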

I'll file a bug to handle this better automatically, and in the meantime update the documentation. Thanks for the clean repro.

emilyfertig · Jan 18 '22