Support variably sized observations with suitable embedding net
Is your feature request related to a problem? Please describe.
Handling variable-size observations, such as those used with permutation-invariant embedding networks, RNNs, or Transformers, currently requires padding inputs (e.g., with NaNs) to a fixed size. While padding works well for batching during training, at test time it would be preferable to also support tensors of varying lengths directly.
Unfortunately, the current `input_shape` checks prevent this, even when the underlying methods could handle variable-length inputs without issue.
As a workaround, it's necessary to manually override the inferred shapes to bypass these checks:

```python
import numpy as np
import torch

x_o = torch.tensor(np.array(x_o))
# Override the inferred shapes so the checks accept the variable-length x_o.
posterior._x_shape = (1, x_o.shape[0], x_o.shape[1])
posterior.posterior_estimator._condition_shape = x_o.shape
posterior.sample((n_samples,), x=x_o, show_progress_bars=False)
```
Describe the solution you'd like
Shape checks should only be enforced where a static shape is truly necessary. Specifically:
- If an embedding network is used, shape checks should apply to the output of the embedding network (see the sketch after this list).
- Eliminate redundant shape checks to avoid unnecessary constraints on variable-length inputs.
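As a rough illustration of the first point, the check could be applied to the embedding net's output rather than to the raw condition. This is a minimal sketch; the helper name `check_embedded_condition_shape` and the `expected_embedded_shape` argument are hypothetical, not existing sbi API:

```python
import torch
from torch import nn

def check_embedded_condition_shape(
    x_o: torch.Tensor, embedding_net: nn.Module, expected_embedded_shape: torch.Size
) -> None:
    """Validate the condition *after* embedding, so the raw length can vary."""
    embedded = embedding_net(x_o.unsqueeze(0))  # add a batch dimension of 1
    if embedded.shape[1:] != expected_embedded_shape:
        raise ValueError(
            f"Embedded condition has shape {tuple(embedded.shape[1:])}, "
            f"expected {tuple(expected_embedded_shape)}."
        )
```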
related to #218
More context for resolving this:
Right now, we enforce that the condition (e.g., `x_o`) has the correct `event_shape` by calling `x = reshape_to_batch_event(x, ...)` either in the `Posterior` or in the potential (which is quite messy). However, when the condition has a variable size and the embedding network can handle this, the reshaping fails.
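One way the reshaping could be relaxed is to fix only the *number* of event dimensions rather than their concrete sizes. The sketch below assumes that convention; it is an illustration, not the current `reshape_to_batch_event` implementation:

```python
import torch

def reshape_to_batch_event_relaxed(x: torch.Tensor, event_ndim: int) -> torch.Tensor:
    """Treat the last `event_ndim` dims as the event and flatten the rest into a batch.

    Only the number of event dimensions is enforced, so a condition of shape
    (set_size, feature_dim) passes through for any set_size.
    """
    if x.dim() == event_ndim:  # a single condition without a batch dimension
        return x.unsqueeze(0)
    event_shape = x.shape[x.dim() - event_ndim:]
    return x.reshape(-1, *event_shape)
```

For example, `reshape_to_batch_event_relaxed(torch.rand(17, 3), event_ndim=2)` returns a tensor of shape `(1, 17, 3)`, and the same call works for any set size in the first dimension.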
Suggested steps:
- Change the functionality of `reshape_to_batch_event` to allow for variable-sized conditions.
- Check whether `x_o` has a valid shape by checking the shape of the embedded input (e.g., add an `embedded_condition_shape` property to `Estimator` objects, as sketched below).
- Remove redundant shape checks: ideally, shape checks should happen in the same place regardless of the method.
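For the second step, an estimator could record the embedded shape once and expose it as a property. The class below is a hypothetical sketch of such an `embedded_condition_shape` property, not the actual `Estimator` interface:

```python
import torch
from torch import nn

class EstimatorSketch(nn.Module):
    """Hypothetical wrapper recording the condition shape after embedding."""

    def __init__(self, embedding_net: nn.Module, example_condition: torch.Tensor) -> None:
        super().__init__()
        self.embedding_net = embedding_net
        # Record the embedded shape once, from a representative condition.
        with torch.no_grad():
            self._embedded_condition_shape = embedding_net(
                example_condition.unsqueeze(0)
            ).shape[1:]

    @property
    def embedded_condition_shape(self) -> torch.Size:
        """Shape against which conditions are checked, independent of raw input length."""
        return self._embedded_condition_shape
```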