BayesFlow
batch_shape vs. batch_size in simulator and related objects
I was thinking about the `batch_shape` argument that we use across BayesFlow, especially in the simulator.
Currently, we use `batch_shape = (int,)` to allow for potentially multi-dimensional batch shapes. My feeling is that we will never need more than one batch dimension, so we could replace the `batch_shape = (int,)` pattern with `batch_size = int` throughout. That is, with the `batch_size` approach, we would only ever have, and allow for, a single batch dimension.
I was wondering what others think about this? Tagging @LarsKue and @stefanradev93. Also, @daniel-habermann, do you think we would need (or at least benefit from) multiple batch dims for multilevel models?
Having `batch_size` as the default for all flat (non-hierarchical) models is definitely nice. I can also see where `batch_shape` could come in handy for dynamic / hierarchical models.
Would this mean having separate `batch_size` and `batch_shape` arguments?
How would a multi-dimensional `batch_shape` come in handy for dynamic / hierarchical models?
The arguments for `batch_shape: tuple[int, ...]` over `batch_size: int` are:
- Consistency with ML libraries, e.g. PyTorch Distributions. These only allow a shape argument, not a single integer.
- Easier extension to multi-dimensional sampling, e.g. for the current `HierarchicalSimulator`:

```python
simulator = HierarchicalSimulator([fn1, fn2, fn3])
data = simulator.sample((32, 64, 128))
data["fn1_param"].shape == (32, ...)
data["fn2_param"].shape == (32, 64, ...)
data["fn3_param"].shape == (32, 64, 128, ...)
```
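To make those shape semantics concrete, here is a minimal self-contained sketch of hierarchical sampling. It is not the actual `HierarchicalSimulator` implementation; the function names (`fn1` etc.), the per-level event dimensions, and the helper `hierarchical_sample` are all hypothetical. The point is just that level `i` samples over the first `i + 1` batch dimensions:

```python
import numpy as np

def hierarchical_sample(fns, batch_shape):
    """Illustrative sketch (not the BayesFlow API): the i-th level
    draws samples over the first i + 1 entries of batch_shape."""
    data = {}
    for level, fn in enumerate(fns):
        # level 0 uses batch_shape[:1], level 1 uses batch_shape[:2], ...
        data[fn.__name__] = fn(batch_shape[: level + 1])
    return data

# hypothetical per-level sampling functions with made-up event dims
def fn1(shape):
    return np.random.normal(size=shape + (2,))  # e.g. 2 hyperparameters

def fn2(shape):
    return np.random.normal(size=shape + (3,))  # e.g. 3 group-level parameters

def fn3(shape):
    return np.random.normal(size=shape + (1,))  # e.g. 1 observation dim

data = hierarchical_sample([fn1, fn2, fn3], (32, 64, 128))
assert data["fn1"].shape == (32, 2)
assert data["fn2"].shape == (32, 64, 3)
assert data["fn3"].shape == (32, 64, 128, 1)
```

Under this reading, a tuple-valued `batch_shape` is what lets each level of the hierarchy pick up one additional batch dimension.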
NumPy approaches this problem by accepting an integer by default and additionally a shape tuple where possible. The argument is also intentionally given the ambiguous name `size`. I think we could follow this approach as well.
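For reference, this is how NumPy's `size` argument behaves in practice (using the `Generator` API):

```python
import numpy as np

rng = np.random.default_rng(0)

# `size` accepts a plain int for the common single-dimension case...
a = rng.normal(size=32)
assert a.shape == (32,)

# ...or a tuple when multiple batch dimensions are needed.
b = rng.normal(size=(32, 64))
assert b.shape == (32, 64)

# size=None (the default) returns a scalar.
c = rng.normal()
assert np.ndim(c) == 0
```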
I don't think `size` alone would be a good name because it is too ambiguous in our context, since there are so many different sizes floating around. But I do like `batch_size` as a name much more than `batch_shape`, for the following reasons:
- In 90%+ of the cases (I argue for 100%, but that is a different issue) we will only need a single batch dimension.
- We also call it `batch_size` in other parts of the library. So using `batch_size` everywhere would be consistent within the library.
- We can still allow the `Shape` type in addition to `int` until I have convinced you we don't actually ever need more than one batch dimension.
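Allowing both types could come down to a small normalization step at the entry point. This is only a sketch; the helper name `normalize_batch_shape` and the `Shape` alias are assumptions, not BayesFlow API:

```python
from typing import Tuple, Union

Shape = Tuple[int, ...]  # stand-in for a library-level Shape alias (assumption)

def normalize_batch_shape(batch_size: Union[int, Shape]) -> Shape:
    """Hypothetical helper: accept either an int (the common case) or a
    tuple of ints, and always work with a tuple internally."""
    if isinstance(batch_size, int):
        return (batch_size,)
    return tuple(batch_size)

assert normalize_batch_shape(32) == (32,)
assert normalize_batch_shape((32, 64)) == (32, 64)
```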
I will try to explain my take on the HierarchicalSimulator at a different place once I have organized my thoughts on this topic.
I second that. We could emphasize `batch_size` (int) for everyday users and additionally allow `batch_shape` (tuple, `None` by default) for power users or special cases, which would override `batch_size`.
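The override rule described here could be resolved in one place. A minimal sketch, assuming a hypothetical `resolve_shape` helper and the default values used in this discussion:

```python
from typing import Optional, Tuple

def resolve_shape(
    batch_size: int = 32,
    batch_shape: Optional[Tuple[int, ...]] = None,
) -> Tuple[int, ...]:
    """Hypothetical resolution rule: batch_shape, when given, overrides
    batch_size; otherwise batch_size defines a single batch dimension."""
    if batch_shape is not None:
        return tuple(batch_shape)
    return (batch_size,)

assert resolve_shape(batch_size=64) == (64,)
assert resolve_shape(batch_size=64, batch_shape=(8, 16)) == (8, 16)
```

This keeps the everyday call site as simple as `sample(batch_size=64)` while leaving the multi-dimensional door open.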