Add support for multimodal data

StarostinV opened this issue 3 months ago · 4 comments

🚀 Feature Request

Add support for multimodal x_o

sbi expects the data x_o to be a Tensor, which makes it difficult to work with multimodal data. In many real-world scientific applications, the data contain multiple modalities, e.g. 2d images, 1d signals, and scalar context data, which should be provided together as x_o. Currently, one has to resort to tricks like x_o = torch.cat([data.flatten(1) for data in data_list], 1), which is clearly not user-friendly. An ideal solution would support both Tensor and dict[str, Tensor].
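
For concreteness, a minimal sketch of the current workaround (all shapes here are illustrative):

import torch

# Hypothetical multimodal observation: an image, a signal, and scalar context.
batch = 64
image = torch.randn(batch, 32, 32)   # 2d modality
signal = torch.randn(batch, 128)     # 1d modality
context = torch.randn(batch, 3)      # scalar context

# Workaround: flatten each modality and concatenate along dim 1. This loses
# the per-modality structure, so a custom embedding net has to slice the
# flat tensor back apart by hand.
data_list = [image, signal, context]
x_o = torch.cat([data.flatten(1) for data in data_list], 1)  # shape (64, 1155)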

Describe the solution you'd like

Add a dedicated Data class to wrap x_o (see the sketch after this list):

  • Support both Tensor and dict[str, Tensor] in the constructor (or a list of these for iid data).
  • A to_model_input() method that returns a Tensor or dict[str, Tensor] (so no changes to the current models are needed).
  • sbi-specific methods, including .batch_size and .is_iid() (instead of checking x_o.shape[0]).
  • Native concatenation for iid mode.
  • Registration as a pytree.
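
A rough sketch of what such a class could look like (all names and conventions are illustrative and up for discussion; pytree registration and device handling are omitted):

from __future__ import annotations

import torch
from torch import Tensor


class Data:
    def __init__(self, x: Tensor | dict[str, Tensor]):
        # Normalize to a dict internally; a bare Tensor gets a default key.
        self._is_dict = isinstance(x, dict)
        self._x = dict(x) if self._is_dict else {"x": x}

    def to_model_input(self) -> Tensor | dict[str, Tensor]:
        # Return the wrapped data unchanged, so current models keep working.
        return self._x if self._is_dict else self._x["x"]

    @property
    def batch_size(self) -> int:
        return next(iter(self._x.values())).shape[0]

    def is_iid(self) -> bool:
        # Placeholder convention: more than one row means iid trials.
        return self.batch_size > 1

    @classmethod
    def cat(cls, datas: list["Data"]) -> "Data":
        # Native concatenation for iid mode: concatenate each modality along dim 0.
        keys = datas[0]._x.keys()
        out = {k: torch.cat([d._x[k] for d in datas], dim=0) for k in keys}
        return cls(out if datas[0]._is_dict else out["x"])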

Describe alternatives you've considered

We could use an existing solution like tensordict. It seems that tensordict may be integrated into PyTorch in the future, but for now it would be an additional dependency and probably overkill.
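
For reference, a rough sketch of what the tensordict route could look like (untested against sbi; tensordict would be an extra dependency):

import torch
from tensordict import TensorDict

# A multimodal observation held in a single TensorDict with a shared batch dim.
x_o = TensorDict(
    {"image": torch.randn(64, 32, 32), "signal": torch.randn(64, 128)},
    batch_size=[64],
)
print(x_o.batch_size)  # torch.Size([64])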

📌 Additional Context

I already discussed this briefly with @janfb, but we could talk more about the details. It seems to me that this could be a surgical addition with minimal changes, and I would be happy to work on it.

StarostinV · Oct 07 '25

We can absolutely support this, but we could also just point people to the training interface for achieving this.

michaeldeistler · Oct 08 '25

> We can absolutely support this, but we could also just point people to the training interface for achieving this.

I'm sure power users can find ways to work around this, but I think we want sbi to be not only a research platform for ML, but also an accessible toolbox for scientists without much expertise in ML or the package internals. This issue is about improving the API for scientists (and a step towards further features like auto-selection of encoders based on data modality).

StarostinV · Oct 09 '25

I just played around a bit, and the training interface does not support this. Some simple things that break are:

However, there are also some major blockers:

Overall, I think this would be a very cool feature, but it will involve a lot of work. If we do it, we should probably first enable it for the training interface and then consider whether we want to also support it for .append_simulations.

michaeldeistler · Oct 14 '25

A quick idea: JAX has a ravel_pytree function (in jax.flatten_util) which flattens any pytree (e.g., a list of lists of arrays, or a dictionary of arrays, ...). It also returns a function to undo the flattening.

Thus, a simple implementation for sbi could be:

x: list[list[Tensor]] = ...
x_flat, unravel_fn = ravel_pytree(x)
# x_flat is a tensor, unravel_fn is a callable.

inference.append_simulations(theta, x_flat, unravel_fn=unravel_fn)

Internally, we then perform

x = unravel_fn(x_flat)
embedding = self.embedding_net(x)
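
On the model side, the embedding net would then dispatch per modality. A hypothetical sketch, assuming the dict[str, Tensor] layout from the original proposal (module names and sizes are made up):

import torch
from torch import nn, Tensor


class MultimodalEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_net = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 16))
        self.signal_net = nn.Linear(128, 16)

    def forward(self, x: dict[str, Tensor]) -> Tensor:
        # Encode each modality separately, then concatenate the embeddings.
        return torch.cat(
            [self.image_net(x["image"]), self.signal_net(x["signal"])], dim=-1
        )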

I think this would be quite minimal to implement. However, I am not sure if anything like ravel_pytree exists in PyTorch.
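
For what it's worth, something similar could be hand-rolled on top of torch.utils._pytree (a private module, so this is only a sketch; it also assumes all leaves share a dtype):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def ravel_pytree(pytree):
    # Flatten the pytree into leaves, then each leaf into a 1d tensor.
    leaves, treespec = tree_flatten(pytree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.numel() for leaf in leaves]
    flat = torch.cat([leaf.reshape(-1) for leaf in leaves])

    def unravel_fn(flat_tensor):
        # Undo the flattening: split, reshape, and rebuild the pytree.
        chunks = torch.split(flat_tensor, sizes)
        restored = [c.reshape(s) for c, s in zip(chunks, shapes)]
        return tree_unflatten(restored, treespec)

    return flat, unravel_fn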

michaeldeistler · Nov 05 '25