Dimensions mismatch error when using `fit_surrogate_posterior()` with `sample_size` > 1
I am training a `LinearRegression` model with 2 trainable weights and 3 trajectories. There is no problem when I use variational inference to solve the problem independently for the 3 trajectories, but when I use a single pair of weights broadcast across the 3 trajectories, I get an error from `fit_surrogate_posterior(sample_size=50, ...)`:
```
ValueError: Dimensions must be equal, but are 3 and 50 for '{{node monte_carlo_variational_loss/expectation/make_component_state_space_models/design_matrix_linop/matmul/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](monte_carlo_variational_loss/expectation/make_component_state_space_models/design_matrix_linop/matmul/MatMul/a, monte_carlo_variational_loss/expectation/make_component_state_space_models/strided_slice)' with input shapes: [3,5,2], [50,2,1].
```
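The failing op is a batched matmul between the design matrix, whose leading batch dimension is 3, and the sampled weights, whose leading dimension is the Monte Carlo sample size of 50; 3 and 50 are not broadcast-compatible. The collision (and the broadcasting that would resolve it) can be sketched in plain NumPy, with shapes taken from the error message:

```python
import numpy as np

# Shapes from the error message:
design = np.zeros((3, 5, 2))    # [data batch, timesteps, features]
weights = np.zeros((50, 2, 1))  # [MC sample size, features, 1]

# The leading batch dimensions 3 and 50 cannot broadcast against
# each other, so the batched matmul fails:
try:
    np.matmul(design, weights)
    raised = False
except ValueError:
    raised = True

# Inserting a size-1 dimension in each operand where the other has
# its batch/sample dimension lets broadcasting succeed:
result = np.matmul(design[np.newaxis, ...], weights[:, np.newaxis, ...])
print(raised, result.shape)  # True (50, 3, 5, 1)
```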
Code demonstrating both cases is provided below:
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfd = tfp.distributions
tfb = tfp.bijectors
sts = tfp.sts
dtype = tf.float32

# (3, 5, 2)
xs = np.array([[[0., 0.], [1., 1.], [2., 1.41], [3., 1.73], [4., 2.]],
               [[0., 0.], [1., 1.], [2., 1.41], [3., 1.73], [4., 2.]],
               [[0., 0.], [1., 1.], [2., 1.41], [3., 1.73], [4., 2.]]])
# (3, 5)
ys = np.array([[2.62, 0.17, 0.82, -1.42, 3.59],
               [0.71, -0.42, 1.92, 0.01, 4.73],
               [-1.07, 1.73, 1.48, 3.24, 5.62]])

single_solution = False  # solve 3 independent problems
# single_solution = True  # single solution

if single_solution:
    # batch_shape = []
    weights_prior = tfd.Independent(
        distribution=tfd.Normal(loc=[0., 0.], scale=5.),
        reinterpreted_batch_ndims=1,
        name='weights_prior')
else:
    # batch_shape = [3]
    weights_prior = tfd.Independent(
        distribution=tfd.Normal(loc=[[0., 0.],
                                     [0., 0.],
                                     [0., 0.]], scale=5.),
        reinterpreted_batch_ndims=1,
        name='weights_prior')

# batch_shape = [] or [3]
lin_reg = sts.LinearRegression(
    design_matrix=tf.cast(xs, dtype=dtype),
    weights_prior=weights_prior,
    name='lin_reg')

# batch_shape = [] or [3]
model = sts.Sum(
    components=[lin_reg],
    constant_offset=tf.constant(0., dtype=dtype),
    observation_noise_scale_prior=tfd.LogNormal(loc=1., scale=2.),
    name='time_series')

variational_posteriors = tfp.sts.build_factored_surrogate_posterior(model=model)

sample_size = 50
num_variational_steps = 200
optimizer = tf.optimizers.Adam(learning_rate=.01)
observed_time_series = tf.cast(ys, dtype=dtype)

@tf.function(experimental_compile=True)
def train():
    elbo_loss_curve = tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=model.joint_log_prob(
            observed_time_series=observed_time_series),
        surrogate_posterior=variational_posteriors,
        optimizer=optimizer,
        sample_size=sample_size,
        num_steps=num_variational_steps)
    return elbo_loss_curve

elbo_loss_curve = train()

plt.figure(figsize=(7, 5))
plt.plot(elbo_loss_curve)
plt.yscale('log')
plt.show()
```
It looks like when the Monte Carlo ELBO is computed, the sample dimension of the weights surrogate posterior collides with the batch dimension of the data. A workaround for `single_solution` is to define `weights_prior` so that it has the same batch rank as `design_matrix`, i.e. use `weights_prior = tfd.Independent(tfd.Normal(loc=[[0, 0]], ...` so that its batch shape is `[1]`.
I'll file a bug to handle this better automatically, and in the meantime update the documentation. Thanks for the clean repro.