
Support variably sized observations with suitable embedding net

Open · manuelgloeckler opened this issue 1 year ago · 2 comments

Is your feature request related to a problem? Please describe.

Handling variable-size observations, such as those processed by permutation-invariant embedding networks, RNNs, or Transformers, currently requires padding the inputs (e.g., with NaNs) to a fixed size. Padding is convenient for batching during training, but at test time it is preferable to also accept tensors of varying length directly.

Unfortunately, the current input_shape checks prevent this, even when the underlying methods could handle variable-length inputs without issue.

As a workaround, it's necessary to manually override the inferred shapes to bypass these checks:

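```python
import numpy as np
import torch

# Overwrite the private shapes inferred at training time so the checks accept x_o's set size
x_o = torch.tensor(np.array(x_o))
posterior._x_shape = (1, x_o.shape[0], x_o.shape[1])
posterior.posterior_estimator._condition_shape = x_o.shape
posterior.sample((n_samples,), x=x_o, show_progress_bars=False)
```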
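For context, the NaN-padding approach mentioned above can be sketched as follows. This is a minimal illustration, not sbi code: `pad_with_nans` and `MeanPoolEmbedding` are hypothetical names. The embedding is permutation-invariant (a per-element MLP followed by a mean over the valid, non-NaN rows), so it gives the same result for a padded set and for the same set passed unpadded. That is why the fixed-size shape check, rather than the network, is what blocks variable-length inputs.

```python
import torch
from torch import nn


def pad_with_nans(sets, max_len):
    """Stack sets of shape (n_i, d) into (batch, max_len, d); missing rows are NaN."""
    d = sets[0].shape[1]
    out = torch.full((len(sets), max_len, d), float("nan"))
    for i, s in enumerate(sets):
        if s.shape[0] > max_len:
            raise ValueError(f"Set {i} has {s.shape[0]} rows, more than max_len={max_len}.")
        out[i, : s.shape[0]] = s
    return out


class MeanPoolEmbedding(nn.Module):
    """Hypothetical permutation-invariant embedding: MLP per element, mean over non-NaN rows."""

    def __init__(self, in_dim, out_dim, hidden=32):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x):
        # x: (batch, set_size, in_dim); set_size may differ from call to call
        mask = ~torch.isnan(x).any(dim=-1)            # True for real (non-padded) rows
        h = self.phi(torch.nan_to_num(x, nan=0.0))    # replace NaNs before the MLP
        h = h * mask.unsqueeze(-1)                     # zero out padded rows
        counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
        return h.sum(dim=1) / counts                   # mean over valid rows only
```

With this embedding, `emb(pad_with_nans([a, b], max_len=5))` and `emb(torch.randn(1, 7, 2))` both return tensors of shape `(batch, out_dim)`, whatever the set size.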

Describe the solution you'd like

Shape checks should only be enforced where a static shape is truly necessary. Specifically:

  • If an embedding network is used, shape checks should apply to the output of the embedding network.
  • Eliminate redundant shape checks to avoid unnecessary constraints on variable-length inputs.
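The first bullet could look roughly like the sketch below: the raw condition is passed through the embedding network, and only the embedded shape is compared against what the density estimator expects. `check_embedded_condition` is a hypothetical helper written for illustration; it is not part of sbi's API.

```python
import torch


def check_embedded_condition(embedding_net, x, embedded_shape):
    """Validate a batched condition by the shape of its embedding, not its raw shape.

    x: (batch, *raw_event_shape); raw_event_shape may vary between calls.
    embedded_shape: per-sample shape the downstream density estimator expects.
    """
    e = embedding_net(x)
    if tuple(e.shape[1:]) != tuple(embedded_shape):
        raise ValueError(
            f"Embedded condition has shape {tuple(e.shape[1:])}, "
            f"expected {tuple(embedded_shape)}."
        )
    return e  # reuse the embedding instead of recomputing it
```

For example, with a mean-pooling embedding `lambda x: x.mean(dim=1)` and `embedded_shape=(2,)`, inputs of shape `(1, 3, 2)` and `(4, 11, 2)` both pass, while `(1, 3, 5)` raises `ValueError` because the feature dimension is wrong.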

manuelgloeckler, Dec 10 '24 08:12

Related to #218.

janfb, Dec 13 '24 08:12

More context for resolving this:

Right now, we enforce that the condition (e.g., x_o) has the correct event_shape by calling x = reshape_to_batch_event(x, ...) either in the Posterior or in the potential, which is messy. However, when the condition has a variable size and the embedding network can handle it, this reshaping fails.
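To make the failure mode concrete, here is a simplified, hypothetical version of a strict batch/event reshape. It is not sbi's actual `reshape_to_batch_event`; it only assumes that the trailing dimensions must equal an `event_shape` fixed at training time, which is the behaviour described above.

```python
import torch


def strict_reshape_to_batch_event(x, event_shape):
    """Hypothetical strict reshape: trailing dims must equal event_shape exactly."""
    event_shape = torch.Size(event_shape)
    n = len(event_shape)
    trailing = x.shape[x.dim() - n:] if n > 0 else torch.Size([])
    if trailing != event_shape:
        raise ValueError(
            f"Trailing dims {tuple(trailing)} do not match event_shape {tuple(event_shape)}."
        )
    # Collapse all leading dims into a single batch dim
    return x.reshape(-1, *event_shape)
```

If training fixed `event_shape=(10, 2)` (sets padded to 10 elements), then an observation of shape `(10, 2)` becomes `(1, 10, 2)`, but an unpadded observation of shape `(7, 2)` raises `ValueError`, even though a permutation-invariant embedding could process it.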

Suggested steps:

  1. Change reshape_to_batch_event so that it allows variable-sized conditions.
  2. Check whether x_o has a valid shape by checking the shape of the embedded input (e.g., add an _embedded_condition_shape property to Estimator objects).
  3. Remove redundant shape checks: ideally, shape checks should happen in the same place regardless of the method.
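The three steps above could fit together roughly as in this sketch. Everything here is hypothetical and for illustration only: `reshape_to_batch_event_relaxed`, `SketchEstimator`, `embed_condition`, and `variable_dims` are invented names, and `_embedded_condition_shape` is only the property name proposed in step 2. Dimensions listed in `variable_dims` skip the size check (step 1), the embedded output is validated instead (step 2), and all condition checks live in one method (step 3).

```python
import torch
from torch import nn


def reshape_to_batch_event_relaxed(x, event_shape, variable_dims=()):
    """Add a batch dim if missing; skip size checks on event dims in variable_dims.

    Unlike a strict reshape, this assumes at most one leading batch dimension.
    """
    n = len(event_shape)
    if x.dim() == n:
        x = x.unsqueeze(0)  # a single, unbatched observation
    if x.dim() != n + 1:
        raise ValueError(f"Expected {n} or {n + 1} dims, got {x.dim()}.")
    for i, (got, want) in enumerate(zip(x.shape[1:], event_shape)):
        if i not in variable_dims and got != want:
            raise ValueError(f"Event dim {i} has size {got}, expected {want}.")
    return x


class MeanOverSet(nn.Module):
    """Toy permutation-invariant embedding: average over the set dimension."""

    def forward(self, x):
        return x.mean(dim=1)


class SketchEstimator(nn.Module):
    """Hypothetical estimator that checks conditions in a single place."""

    def __init__(self, embedding_net, condition_shape, embedded_condition_shape, variable_dims=()):
        super().__init__()
        self.embedding_net = embedding_net
        self._condition_shape = tuple(condition_shape)
        self._embedded_condition_shape = tuple(embedded_condition_shape)
        self._variable_dims = tuple(variable_dims)

    def embed_condition(self, x):
        x = reshape_to_batch_event_relaxed(x, self._condition_shape, self._variable_dims)
        e = self.embedding_net(x)
        if tuple(e.shape[1:]) != self._embedded_condition_shape:
            raise ValueError(
                f"Embedded shape {tuple(e.shape[1:])} != {self._embedded_condition_shape}."
            )
        return e
```

Usage: with `SketchEstimator(MeanOverSet(), condition_shape=(10, 2), embedded_condition_shape=(2,), variable_dims=(0,))`, an observation of shape `(7, 2)` embeds to `(1, 2)` and a batch of shape `(3, 12, 2)` embeds to `(3, 2)`. An input of shape `(7, 3)` is still rejected, because its feature dimension is not declared variable.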

gmoss13, Feb 26 '25 11:02