
Inconsistency in documented/expected shapes for estimators

Open · sethaxen opened this issue 9 months ago · 3 comments

🐛 Bug Description

ConditionalDensityEstimator and RatioEstimator on sbi main use different attribute names, argument orders, and shape expectations. In addition, RatioEstimator's documented shape expectations are incompatible with the shapes it actually enforces.

Details

While ConditionalDensityEstimator (CDE) and RatioEstimator (RE) do not share a common parent type, ideally they would still be as consistent as possible. I assume here that x in RE and input in CDE are roughly equivalent, and likewise theta in RE and condition in CDE.

Inconsistent attributes

CDE uses the attributes (input_shape, condition_shape):
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/base.py#L135
RE uses the attributes (theta_shape, x_shape):
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/ratio_estimators.py#L29-L30
Their order in the constructors is also reversed.
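For concreteness, a paraphrase of the two constructors as described by the linked lines (a sketch only, not sbi's actual signatures; it just shows the reversed ordering under the x ≈ input, theta ≈ condition mapping above):

```python
# Paraphrased sketch, not sbi's actual code: only the ordering of the
# shape arguments is the point here.
class ConditionalDensityEstimator:
    def __init__(self, net, input_shape, condition_shape):
        self.input_shape = input_shape          # x-like shape first
        self.condition_shape = condition_shape  # theta-like shape second

class RatioEstimator:
    def __init__(self, net, theta_shape, x_shape):
        self.theta_shape = theta_shape          # theta-like shape first
        self.x_shape = x_shape                  # x-like shape second
```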

Inconsistent shapes

CDE documents that the shape of input is (sample_dim, batch_dim, *input_shape) and the shape of condition is (batch_dim, *condition_shape). While the base class doesn't enforce this, at least some subclasses do, e.g. MixedDensityEstimator:
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/mixed_density_estimator.py#L74
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/mixed_density_estimator.py#L141

RE documents that the shape of x is (batch_dim, *x_shape) and that the shape of theta is (sample_dim, batch_dim, *theta_shape). Note that the two classes differ in which argument carries the sample_dim. However, RE actually enforces that x is (*batch_shape, *x_shape) and theta is (*batch_shape, *theta_shape), i.e. the two arguments share the same batch-shape prefix, which is incompatible with the documented shapes:
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/ratio_estimators.py#L126-L128
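To make the mismatch concrete, here is a minimal runnable sketch (the event shapes and the helper function are illustrative, not sbi code) contrasting the documented convention with the check RE actually performs:

```python
import torch

# Illustrative event shapes (hypothetical, not taken from sbi)
x_shape, theta_shape = (3,), (2,)
sample_dim, batch_dim = 5, 7

# Documented RE convention: only theta carries a sample dimension.
x_doc = torch.randn(batch_dim, *x_shape)                      # (7, 3)
theta_doc = torch.randn(sample_dim, batch_dim, *theta_shape)  # (5, 7, 2)

def shares_batch_prefix(theta, x):
    """Mimics what RE actually enforces: theta and x must have the
    same (*batch_shape,) prefix in front of their event shapes."""
    theta_batch = theta.shape[: theta.dim() - len(theta_shape)]
    x_batch = x.shape[: x.dim() - len(x_shape)]
    return theta_batch == x_batch

# The documented shapes fail the enforced check: (5, 7) != (7,)
assert not shares_batch_prefix(theta_doc, x_doc)
```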

Inconsistent argument order in methods

While CDE.log_prob and RE.unnormalized_log_ratio are not equivalent, one would expect their argument order to be consistent. However, the former takes (input, condition):
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/estimators/base.py#L155
while the latter takes (theta, x):
https://github.com/sbi-dev/sbi/blob/a1d75559b1e50b3d19c0eac1cabdfa6a23c16ea4/sbi/neural_nets/ratio_estimators.py#L156
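Under the x ≈ input, theta ≈ condition mapping, the "same" argument therefore lands in opposite positions. Stub signatures illustrating the reversal (hypothetical classes, not sbi's implementation):

```python
# Hypothetical stubs mirroring the documented method signatures.
class CDEStub:
    def log_prob(self, input, condition):
        # x-like argument first, theta-like argument second
        ...

class REStub:
    def unnormalized_log_ratio(self, theta, x):
        # theta-like argument first, x-like argument second
        ...
```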

📌 Additional Context

Torch distributions (and Pyro) implement both sample and log_prob, supporting arbitrary batch_shape and sample_shape (not just a single dimension each). While not strictly necessary, it would be nice if these methods supported the same shape conventions. This would in particular simplify #1491.
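For reference, a minimal example of the torch.distributions convention referred to above, where batch_shape and sample_shape can each span any number of dimensions:

```python
import torch
from torch.distributions import Normal

# batch_shape = (4, 3): a batch of independent Normal distributions.
dist = Normal(loc=torch.zeros(4, 3), scale=torch.ones(4, 3))

# sample_shape = (5, 2): a (5, 2)-shaped collection of draws from
# every distribution in the batch.
samples = dist.sample(sample_shape=torch.Size([5, 2]))
assert samples.shape == (5, 2, 4, 3)

# log_prob broadcasts over the leading sample dimensions.
assert dist.log_prob(samples).shape == (5, 2, 4, 3)
```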

sethaxen · Mar 20 '25

Regarding the RE inconsistency between the docstring shape conventions and what the code expects: I wonder if it would be useful to use jaxtyping to provide shapes in type hints. These can then be rendered automatically in the documentation, and the docstrings themselves no longer need to include shape information. Contrary to its name, jaxtyping has no JAX dependency and also supports torch.Tensor hints.

GPyTorch, for example, uses it in a few places, e.g.: https://github.com/cornellius-gp/gpytorch/blob/b017b9c3fe4de526f7a2243ce12ce2305862c90b/gpytorch/variational/nearest_neighbor_variational_strategy.py#L177-L184

While static type checkers can't verify that the shape conventions are followed, runtime type checkers like beartype can do so very quickly. sbi could either depend on one of these type checkers or simply run one in the test suite (a minimal sketch follows the links below). For beartype, see:

  • https://beartype.readthedocs.io/en/latest/faq/#pytorch-tensors
  • https://beartype.readthedocs.io/en/latest/faq/#what-does-near-real-time-even-mean-are-you-just-making-stuff-up
  • https://beartype.readthedocs.io/en/latest/faq/#how-do-i-only-type-check-while-running-my-test-suite
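For illustration, a minimal sketch of what this could look like (the function and axis names are hypothetical, not sbi's API; assumes jaxtyping's jaxtyped decorator with beartype as the runtime checker):

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def unnormalized_log_ratio(
    theta: Float[Tensor, "sample_dim batch_dim theta_dim"],
    x: Float[Tensor, "batch_dim x_dim"],
) -> Float[Tensor, "sample_dim batch_dim"]:
    # Dummy body; a real estimator would evaluate its network here.
    return theta.sum(-1) + x.sum(-1)

# Consistent shapes pass (sample_dim=5, batch_dim=7):
unnormalized_log_ratio(torch.randn(5, 7, 2), torch.randn(7, 3))

# A mismatched batch_dim (7 vs 6) raises at call time:
# unnormalized_log_ratio(torch.randn(5, 7, 2), torch.randn(6, 3))
```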

sethaxen · Mar 20 '25

If there's interest in the jaxtyping approach, I'd be happy to hack together a quick proof-of-concept.

sethaxen · Mar 20 '25

I made a quick attempt to prototype jaxtyping+beartype in sbi, and I don't think it's a good fit, for the following reasons:

  • jaxtyping currently doesn't support multiple variadic axes for a single tensor (e.g. "batch_dim *x_shape" is fine, but "*batch_shape *x_shape" is not; see the sketch below). Because string interpolation is allowed, one could, e.g. for RatioEstimator, interpolate self.x_shape into the string, but one would actually need f"*batch_shape {' '.join(map(str, self.x_shape))}", which would just be ugly in the documentation.
  • beartype successfully caught type mismatches, but the errors it raised were not very descriptive. It might make sense to run the checks in the test suite, but I wouldn't rely on them as the main shape validation for user-provided inputs. Here's an example error:
E   beartype.roar.BeartypeCallHintReturnViolation: Method sbi.neural_nets.ratio_estimators.RatioEstimator.combine_theta_and_x() return "tensor([[ 1.1846,  0.7002,  1.0900,  0.6601],
E           [-1.4753, -1.7527,  1.0900,  0.660...]])" violates type hint <class 'jaxtyping.Float[Tensor, 'sample_dim batch_dim combined_event_dim']'>, as this array has 2 dimensions, not the 3 expected by the type hint.
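To make the first bullet concrete, a minimal sketch of the variadic limitation (axis names are illustrative; assumes jaxtyping's jaxtyped decorator with beartype as the runtime checker):

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def one_variadic_axis(
    x: Float[Tensor, "*batch_shape x_dim"],
) -> Float[Tensor, "*batch_shape"]:
    # A single variadic axis per annotation is supported.
    return x.sum(-1)

one_variadic_axis(torch.randn(5, 7, 3))  # fine: batch_shape binds to (5, 7)

# Two variadic axes in one annotation, e.g.
#     Float[Tensor, "*batch_shape *x_shape"]
# are rejected by jaxtyping when the annotation is constructed.
```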

It might still be worth using jaxtyping for methods that don't require multiple variadic shapes; GPyTorch takes this piecemeal approach, using jaxtyping for a few operators. But it's not a silver bullet.

sethaxen · Mar 21 '25