Inconsistency in documented/expected shapes for estimators
🐛 Bug Description
ConditionalDensityEstimator and RatioEstimator on sbi's main branch use different argument names, argument orders, and shape expectations. Also, RatioEstimator's documented shape expectations are incompatible with what it actually enforces.
Details
While ConditionalDensityEstimator (CDE) and RatioEstimator (RE) do not share a common parent type, ideally they would still be as consistent as possible. I assume here that x in RE and input in CDE are roughly equivalent, and that theta in RE and condition in CDE are roughly equivalent.
Inconsistent attributes
CDE uses attributes (input_shape, condition_shape): https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/base.py#L135
RE uses attributes (theta_shape, x_shape): https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/ratio_estimators.py#L29-L30
The order of these arguments in the two constructors is also reversed.
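To make the naming divergence concrete, here is a rough paraphrase (not the exact sbi signatures; only the shape-argument names, their relative order, and the attribute names are taken from the code linked above, everything else is omitted):

```python
# Simplified paraphrase of the divergence -- not the exact sbi constructors.
# All other parameters (networks, embedding nets, etc.) are omitted here.


class ConditionalDensityEstimator:
    def __init__(self, input_shape, condition_shape):  # "input"-like argument first
        self.input_shape = input_shape
        self.condition_shape = condition_shape


class RatioEstimator:
    def __init__(self, theta_shape, x_shape):  # "theta"-like argument first
        # Different attribute names and, under the x <-> input /
        # theta <-> condition reading above, the opposite argument order.
        self.theta_shape = theta_shape
        self.x_shape = x_shape
```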
Inconsistent shapes
CDE documents that the shape of input is (sample_dim, batch_dim, *input_shape) and the shape of condition is (batch_dim, *condition_shape). While the base class doesn't enforce this, at least some subclasses do, e.g. MixedDensityEstimator:
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/mixed_density_estimator.py#L74
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/mixed_density_estimator.py#L141
RE documents that the shape of x is (batch_dim, *x_shape) and that the shape of theta is (sample_dim, batch_dim, *theta_shape). Note that, under the correspondence assumed above, the two classes differ in which of the arguments has a sample_dim. However, RE actually enforces that x is (*batch_shape, *x_shape) and theta is (*batch_shape, *theta_shape), i.e. the two arguments share the same batch prefix, which is incompatible with the documented shapes:
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/ratio_estimators.py#L126-L128
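To illustrate the clash with concrete numbers (shapes chosen arbitrarily for this sketch):

```python
import torch

batch_dim, sample_dim = 4, 10
theta_shape, x_shape = (3,), (2,)

# Shapes as documented in RE's docstrings:
theta = torch.randn(sample_dim, batch_dim, *theta_shape)  # (10, 4, 3)
x = torch.randn(batch_dim, *x_shape)                      # (4, 2)

# What the code actually enforces is a shared batch prefix, i.e.
# theta: (*batch_shape, *theta_shape) and x: (*batch_shape, *x_shape).
# The documented tensors above have leading dims (10, 4) vs (4,), so they
# cannot share a batch_shape prefix and fail that check.
print(theta.shape[:-len(theta_shape)])  # torch.Size([10, 4])
print(x.shape[:-len(x_shape)])          # torch.Size([4])
```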
Inconsistent argument order in methods
While CDE.log_prob and RE.unnormalized_log_ratio are not equivalent, one would expect their argument order to match. However, the former takes (input, condition):
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/base.py#L155
while the latter takes (theta, x):
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/ratio_estimators.py#L156
📌 Additional Context
Torch distributions (and Pyro) implement both sample and log_prob, supporting arbitrary batch_shape and sample_shape (not just a single dimension). While not necessary, it would be nice if these methods supported the same shape conventions. This would in particular simplify #1491.
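For reference, a minimal sketch of the torch.distributions convention (plain PyTorch, nothing sbi-specific):

```python
import torch
from torch.distributions import Normal

# batch_shape = (4,), event_shape = () for this distribution.
dist = Normal(loc=torch.zeros(4), scale=torch.ones(4))

# Arbitrary sample_shape, not just a single sample dimension:
samples = dist.sample((10, 2))      # shape: (10, 2, 4) = sample_shape + batch_shape
log_probs = dist.log_prob(samples)  # shape: (10, 2, 4)
```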
Regarding the inconsistency in shape conventions between the docstrings and what the code expects, I wonder if it would be useful to use jaxtyping to provide shapes in type hints. These can then be automatically rendered in the documentation, and the docstrings themselves don't need to include shape information. Contrary to its name, jaxtyping has no JAX dependency and also supports torch.Tensor hints.
GPyTorch, for example, uses it in a few places, e.g.: https://github.com/cornellius-gp/gpytorch/blob/b017b9c3fe4de526f7a2243ce12ce2305862c90b/gpytorch/variational/nearest_neighbor_variational_strategy.py#L177-L184
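For example, a shape-annotated signature for RE might look roughly like this (a sketch only; it uses the shapes from the current docstring, and runtime enforcement would still need a checker such as beartype, see below):

```python
from jaxtyping import Float
from torch import Tensor


def unnormalized_log_ratio(
    theta: Float[Tensor, "sample_dim batch_dim theta_dim"],
    x: Float[Tensor, "batch_dim x_dim"],
) -> Float[Tensor, "sample_dim batch_dim"]:
    """Sketch: the shape convention lives in the type hints, not the docstring."""
    ...
```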
While static type checkers can't check that the shape conventions are followed, runtime type checkers like beartype can do so very quickly. sbi could either depend on one of these type checkers or simply run one in the test suite (a sketch follows the links below). For beartype, see:
- https://beartype.readthedocs.io/en/latest/faq/#pytorch-tensors
- https://beartype.readthedocs.io/en/latest/faq/#what-does-near-real-time-even-mean-are-you-just-making-stuff-up
- https://beartype.readthedocs.io/en/latest/faq/#how-do-i-only-type-check-while-running-my-test-suite
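A sketch of the test-suite-only route, using jaxtyping's install_import_hook together with beartype (assuming this lives in e.g. tests/conftest.py so it runs before the tests import sbi):

```python
# conftest.py (sketch): enable runtime shape/type checking only when tests run.
from jaxtyping import install_import_hook

# Re-imports sbi modules with beartype-backed checking of jaxtyping annotations.
# Must be installed before any test module imports sbi.
install_import_hook("sbi", "beartype.beartype")
```

This keeps the runtime cost out of user code while still failing CI when a tensor violates its annotated shape.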
If there's interest in the jaxtyping approach, I'd be happy to hack together a quick proof-of-concept.
I made a quick attempt to prototype jaxtyping+beartype in SBI, and I don't think it's a good fit for the following reasons:
- jaxtyping currently doesn't support multiple variadic shapes for a single tensor (e.g. `batch_dim *x_shape` is fine, but `*batch_shape *x_shape` is not). Because string interpolation is allowed, one could e.g. for `RatioEstimator` interpolate `self.x_shape` into the string, but I think one would actually need `f"*batch_shape {' '.join(map(str, self.x_shape))}"`, which would just be ugly in the documentation.
- beartype successfully caught type mismatches, but the errors it raised were not very descriptive. It might make sense to have the runtime checking in the test suite, but I wouldn't rely on it as the main shape check for user-provided inputs. Here's an example error:
```
E   beartype.roar.BeartypeCallHintReturnViolation: Method sbi.neural_nets.ratio_estimators.RatioEstimator.combine_theta_and_x() return "tensor([[ 1.1846, 0.7002, 1.0900, 0.6601],
E           [-1.4753, -1.7527, 1.0900, 0.660...]])" violates type hint <class 'jaxtyping.Float[Tensor, 'sample_dim batch_dim combined_event_dim']'>, as this array has 2 dimensions, not the 3 expected by the type hint.
```
It might still be worth using jaxtyping for methods that don't require multiple variadic shapes; GPyTorch takes this piecemeal approach, using jaxtyping for just a few operators. But it's not a silver bullet.