Add support for multimodal data
🚀 Feature Request
Add support for multimodal x_o
sbi expects the data `x_o` to be a `Tensor`, which makes it difficult to work with multimodal data. In many real-world scientific applications, the data contain several modalities, e.g. 2d images, 1d signals, and scalar context variables, which should be provided together as `x_o`. Currently, one has to resort to tricks like `x_o = torch.cat([data.flatten(1) for data in data_list], 1)`, which is obviously not user-friendly. An ideal solution would support both `Tensor` and `dict[str, Tensor]`.
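For concreteness, here is a toy version of that workaround (shapes and names are made up for illustration):

```python
import torch

# Hypothetical multimodal observation: an image, a 1d signal, and scalar context.
image = torch.randn(10, 1, 28, 28)  # batch of 2d images
signal = torch.randn(10, 100)       # batch of 1d signals
context = torch.randn(10, 3)        # batch of scalar context features

# Current workaround: flatten each modality and concatenate along dim 1,
# losing all structure (the embedding net must know the offsets to undo this).
x_o = torch.cat([data.flatten(1) for data in [image, signal, context]], 1)
print(x_o.shape)  # torch.Size([10, 887])
```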
Describe the solution you'd like
Add a dedicated `Data` class to wrap `x_o` (a sketch follows the list):
- Support both `Tensor` and `dict[str, Tensor]` in the constructor (or a list of these for the iid case).
- A `to_model_input()` method that returns a `Tensor` or `dict[str, Tensor]` (so no changes to the current models are needed).
- Provide sbi-specific methods, including `.batch_size` and `.is_iid()` (instead of checking `x_o.shape[0]`).
- Add native concatenation for iid mode.
- Register the class as a pytree.
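A minimal sketch of what such a wrapper could look like (all names and details here are my assumptions, not a settled API):

```python
from typing import Union

import torch
from torch import Tensor

class Data:
    """Hypothetical wrapper around x_o; pytree registration omitted for brevity."""

    def __init__(self, x: Union[Tensor, dict[str, Tensor]]):
        self._x = x

    def to_model_input(self) -> Union[Tensor, dict[str, Tensor]]:
        # Models keep receiving exactly what they receive today.
        return self._x

    @property
    def batch_size(self) -> int:
        if isinstance(self._x, Tensor):
            return self._x.shape[0]
        # All modalities are assumed to share the leading batch dimension.
        return next(iter(self._x.values())).shape[0]

    def is_iid(self) -> bool:
        return self.batch_size > 1

    def cat(self, other: "Data") -> "Data":
        # Native concatenation for the iid case.
        if isinstance(self._x, Tensor):
            return Data(torch.cat([self._x, other._x], dim=0))
        return Data(
            {k: torch.cat([v, other._x[k]], dim=0) for k, v in self._x.items()}
        )
```

Keeping `to_model_input()` as a thin pass-through means existing embedding nets and density estimators would not need to change.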
Describe alternatives you've considered
We could use an existing solution such as tensordict. It seems that tensordict may be integrated into PyTorch in the future, but for now it would be an additional dependency and probably overkill.
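For reference, a multimodal `x_o` as a `TensorDict` would look roughly like this (assuming the `tensordict` package is installed):

```python
import torch
from tensordict import TensorDict

# One entry per modality, sharing the leading batch dimension.
x_o = TensorDict(
    {"image": torch.randn(10, 1, 28, 28), "signal": torch.randn(10, 100)},
    batch_size=[10],
)
print(x_o.batch_size)          # torch.Size([10])
x_cat = torch.cat([x_o, x_o])  # concatenation across the batch works natively
```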
📌 Additional Context
Already discussed with @janfb briefly, but we could talk more about the details. It seems to me that we could make a surgical addition with minimal changes and I would be happy to do that.
We can absolutely support this, but we could also just point people to the training interface for achieving this.
I'm sure power users can find ways to work around this, but I think we want sbi to be not only a research platform for ML, but also an accessible toolbox for scientists without much expertise in ML or the package internals. This issue is about improving the API for scientists (and a step towards further features like auto-selection of encoders based on data modality).
I just played around a bit, and the training interface does not support this. Some simple things that break are:
- `check_data_device`
- `get_numel`
- z-scoring (which can be circumvented by setting `build_nsf(..., z_score_y=None)`, which I am mostly fine with)
However, there are also some major blockers:
- our neural networks
- and even the nflows backend (and potentially other backends, which I did not check)
Overall, I think this would be a very cool feature, but it will involve a lot of work. If we do it, we should probably first enable it for the training interface and then consider whether we also want to support it in `.append_simulations`.
A quick idea: JAX has a `ravel_pytree` function which flattens any pytree (e.g., a list of lists of arrays, or a dictionary of arrays, ...). It also returns a function to undo the flattening.
Thus, a simple implementation for sbi could be:

```python
x: list[list[Tensor]] = ...
x_flat, unravel_fn = ravel_pytree(x)
# x_flat is a tensor, unravel_fn is a callable.
inference.append_simulations(theta, x_flat, unravel_fn=unravel_fn)
```
Internally, we then perform:

```python
x = unravel_fn(x_flat)
embedding = self.embedding_net(x)
```
I think this would be quite minimal to implement. However, I am not sure whether anything like `ravel_pytree` exists in PyTorch.
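For what it's worth, a rough analogue can be sketched on top of PyTorch's pytree utilities in `torch.utils._pytree` (a private module, so subject to change); this is a minimal sketch assuming all leaves are tensors:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def ravel_pytree(pytree):
    """Flatten a pytree of tensors into a single 1D tensor.

    Returns the flat tensor and a callable that restores the original
    structure, mimicking jax.flatten_util.ravel_pytree.
    """
    leaves, spec = tree_flatten(pytree)
    shapes = [leaf.shape for leaf in leaves]
    numels = [leaf.numel() for leaf in leaves]
    flat = torch.cat([leaf.reshape(-1) for leaf in leaves])

    def unravel_fn(flat_tensor):
        chunks = flat_tensor.split(numels)
        return tree_unflatten(
            [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)], spec
        )

    return flat, unravel_fn

# Example: a dict of differently shaped tensors round-trips correctly.
x = {"image": torch.randn(1, 28, 28), "signal": torch.randn(100)}
x_flat, unravel_fn = ravel_pytree(x)
assert torch.equal(unravel_fn(x_flat)["image"], x["image"])
```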