FMPE fails to sample with 2D conditions

Open • gmoss13 opened this issue 1 month ago • 1 comment

🐛 Bug Description

When the condition is more than 1-dimensional (e.g. an image instead of a vector), FMPE trains successfully but fails to sample because of a shape mismatch.

🔄 Steps to Reproduce

This is a contrived example: the single parameter sets the mean, and x is a 10x10 matrix whose entries are each sampled from a normal distribution with that mean.

import torch
from sbi.inference import FMPE
from sbi.utils import BoxUniform
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets import posterior_flow_nn


prior = BoxUniform(low=torch.tensor([-1.0]), high=torch.tensor([1.0]))


def simulator(theta):
    # theta has shape (batch, 1); each simulated x is a 10x10 matrix.
    mean = torch.zeros(10, 10)
    mean = mean.unsqueeze(0).expand(theta.shape[0], -1, -1)  # (batch, 10, 10)
    # theta.unsqueeze(-1) has shape (batch, 1, 1) and broadcasts over the matrix,
    # so every element of mean equals theta.
    mean = mean + theta.unsqueeze(-1)

    x = mean + 0.1 * torch.randn_like(mean)
    return x


theta = prior.sample((1000,))
x = simulator(theta)  # shape (1000, 10, 10)

embedding_net = CNNEmbedding(
    input_shape=(10, 10),
    in_channels=1,
    output_dim=5,
    kernel_size=3,
)
flow_estimator = posterior_flow_nn(
    model="mlp",
    embedding_net=embedding_net,
)

inference = FMPE(prior=prior, vf_estimator=flow_estimator)
_ = inference.append_simulations(theta, x).train(max_num_epochs=2)  # trains fine
posterior = inference.build_posterior()

# A single observation with a leading batch dimension: shape (1, 10, 10).
x_o = 0.1 * torch.randn(1, 10, 10) + 0.5

samples = posterior.sample((10,), x=x_o)  # fails here

The final sample call fails with the error RuntimeError: shape '[10, 1]' is invalid for input of size 100

✅ Expected Behavior

Sampling returns posterior samples without error.

...

📌 Additional Context

I came across this problem some time ago and traced it to these lines in ZukoNeuralODE:

https://github.com/sbi-dev/sbi/blob/06f13a8eb0c83dd1abb1cf4c33685ead60d133ec/sbi/samplers/ode_solvers/zuko_ode.py#L95-L98

Here the samples are expanded to the batch size using condition.shape[:-1] as the condition batch shape. That is only correct for 1D conditions: if the condition is 2D, the batch shape is condition.shape[:-2], and so on. My proposed solution passes the condition_shape of the ConditionalVectorFieldEstimator to ZukoNeuralODE so that it can expand to the correct batch shape here. I can make a quick PR for this if that sounds right.
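To illustrate the shape logic described above, here is a minimal sketch in plain PyTorch. It is not the actual ZukoNeuralODE code: the helpers batch_shape_buggy and batch_shape_fixed are hypothetical, and the sketch assumes, purely for illustration, that the initial noise is drawn with shape (num_samples, *batch_shape, theta_dim) and later reshaped to (num_samples, theta_dim) for a single observation. Only condition_shape corresponds to a name from the issue.

```python
import torch


def batch_shape_buggy(condition: torch.Tensor) -> torch.Size:
    # Hypothetical helper: treats only the last dimension as the event
    # (condition) dimension. Correct for vector conditions, wrong for 2D ones.
    return condition.shape[:-1]


def batch_shape_fixed(condition: torch.Tensor, condition_shape: torch.Size) -> torch.Size:
    # Hypothetical helper: strips exactly len(condition_shape) trailing event
    # dimensions. Slicing up to dim() - n also handles n == 0 correctly.
    num_event_dims = len(condition_shape)
    return condition.shape[: condition.dim() - num_event_dims]


num_samples, theta_dim = 10, 1
condition_shape = torch.Size([10, 10])  # shape of a single x in the example
x_o = torch.randn(1, 10, 10)            # one observation with a leading batch dim

for name, batch_shape in [
    ("buggy", batch_shape_buggy(x_o)),
    ("fixed", batch_shape_fixed(x_o, condition_shape)),
]:
    # Assumed noise layout: one draw per sample and per condition in the batch.
    noise = torch.randn(num_samples, *batch_shape, theta_dim)
    print(name, tuple(batch_shape), noise.numel())
    try:
        noise.reshape(num_samples, theta_dim)
        print("  reshape to (10, 1) succeeded")
    except RuntimeError as err:
        print("  reshape failed:", err)
```

Under these assumptions, the buggy path treats the (1, 10) prefix of x_o's shape as batch dimensions, draws 100 noise values instead of 10, and the reshape raises PyTorch's "shape '[10, 1]' is invalid for input of size 100" message. This is one way the numbers in the reported error can line up, not a trace of sbi's internals. The fixed path strips both trailing event dimensions, leaves a batch shape of (1,), and the reshape succeeds.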

gmoss13 • Nov 14 '25 16:11

Thanks for raising this, @gmoss13! Your proposed solution sounds good. 🚀

janfb • Nov 14 '25 19:11