probability
Enable batch support for `windowed_mean|variance`
This PR makes the functions windowed_mean and windowed_variance accept indices with batch dimensions.
Example:
import numpy as np
import tensorflow_probability as tfp

x = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=np.float32)
low_indices = [[0, 0, 0], [1, 0, 0], [2, 2, 0]]
high_indices = [[3, 3, 3], [1, 2, 3], [3, 2, 1]]
tfp.stats.windowed_mean(x, low_indices=low_indices, high_indices=high_indices, axis=1)
Now gives:
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[2. , 2. , 2. ],
       [0. , 1.5, 2. ],
       [3. , 0. , 1. ]], dtype=float32)>
Was previously failing with:
tensorflow.python.framework.errors_impl.InvalidArgumentError: required broadcastable shapes [Op:SelectV2]
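For reference, the expected values in the example can be cross-checked with a plain NumPy sketch. This is not the TFP implementation; the helper name `windowed_mean_reference` is hypothetical, and it assumes `high_indices >= low_indices` elementwise and that empty windows yield 0, as in the output shown above.

```python
import numpy as np

def windowed_mean_reference(x, low_indices, high_indices, axis):
    """Hypothetical NumPy reference: mean of x[low:high] per window along `axis`."""
    x = np.asarray(x, dtype=np.float64)
    low = np.asarray(low_indices)
    high = np.asarray(high_indices)
    # Prefix sums with a leading zero, so sum(x[low:high]) = csum[high] - csum[low].
    zero_shape = list(x.shape)
    zero_shape[axis] = 1
    csum = np.concatenate(
        [np.zeros(zero_shape, dtype=x.dtype), np.cumsum(x, axis=axis)], axis=axis)
    # Batched gather: each batch member uses its own low/high indices.
    sums = (np.take_along_axis(csum, high, axis=axis)
            - np.take_along_axis(csum, low, axis=axis))
    counts = high - low
    # Avoid dividing by zero for empty windows; report 0 for them instead.
    safe_counts = np.where(counts == 0, 1, counts)
    return np.where(counts == 0, 0.0, sums / safe_counts)

x = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=np.float32)
low_indices = [[0, 0, 0], [1, 0, 0], [2, 2, 0]]
high_indices = [[3, 3, 3], [1, 2, 3], [3, 2, 1]]
result = windowed_mean_reference(x, low_indices, high_indices, axis=1)
# result has rows [2, 2, 2], [0, 1.5, 2], [3, 0, 1], matching the output above.
```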
@axch I made changes in the code you authored, could you kindly have a look at this PR? Thanks
@nicolaspi thanks for the contribution! I am no longer an active maintainer of TFP, so I'm not really in a position to review your PR in detail (@jburnim please suggest someone?). On a quick look, though, I see a couple potential code style issues:
- Do we need the dependency on tf.experimental.numpy?
- Do we need the special case for rank-1 indices? Could we define a more uniform behavior instead?
- I'm guessing some of the shape munging already has relevant helpers defined elsewhere in TFP, but I don't remember off-hand
- TFP generally tries to handle static and dynamic TF shapes uniformly using prefer_static and tensorshape_util.
Thanks for your feedback!
- Do we need the dependency on tf.experimental.numpy?
We specifically need the take_along_axis function, which allows gathering the slices along each batch dimension. I replaced the 'experimental' import path with from tensorflow.python.ops import numpy_ops.
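To illustrate what take_along_axis provides here, a small sketch using NumPy's `np.take_along_axis`, which has the same gathering semantics; the array values are made up for illustration and are not taken from the PR.

```python
import numpy as np

# Two batch members, each with its own cumulative sums along axis 1.
cumsum = np.array([[0, 1, 3, 6],
                   [0, 10, 30, 60]])
# Each row of `idx` indexes into the matching row of `cumsum`.
idx = np.array([[3, 1],
                [2, 0]])
gathered = np.take_along_axis(cumsum, idx, axis=1)
# gathered == [[6, 1], [30, 0]]
```

A plain gather with a single index vector would apply the same positions to every row, which is why per-batch indices need this kind of gather.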
- Do we need the special case for rank-1 indices? Could we define a more uniform behavior instead?
There are two motivations for this case. First, backward compatibility: it is equivalent to the legacy non-batched usage. Second, it is the only case I can think of where the broadcast is unambiguous when rank(indices) < rank(x).
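One way to read the rank-1 case, sketched in NumPy: the index vector is placed along the reduction axis and repeated across all other dimensions, so every batch member sees the same windows. The helper `broadcast_rank1_indices` is hypothetical and only illustrates the broadcast; it is not the code in this PR.

```python
import numpy as np

def broadcast_rank1_indices(indices, x_shape, axis):
    """Hypothetical helper: expand rank-1 indices so they line up with `axis` of x."""
    indices = np.asarray(indices)
    if indices.ndim != 1:
        raise ValueError("expected rank-1 indices")
    axis = axis % len(x_shape)
    # Shape with the indices on `axis` and size-1 dimensions elsewhere.
    expanded_shape = [1] * len(x_shape)
    expanded_shape[axis] = indices.shape[0]
    # Same as x's shape, except `axis` holds one entry per window.
    target_shape = list(x_shape)
    target_shape[axis] = indices.shape[0]
    return np.broadcast_to(indices.reshape(expanded_shape), target_shape)

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
low = broadcast_rank1_indices([0, 1], x.shape, axis=1)
high = broadcast_rank1_indices([2, 3], x.shape, axis=1)
# low == [[0, 1], [0, 1]] and high == [[2, 3], [2, 3]]:
# the same two windows are applied to every row of x.
```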
- I'm guessing some of the shape munging already has relevant helpers defined elsewhere in TFP, but I don't remember off-hand
In any case, I modified the unit tests to also cover non-static shapes.
- TFP generally tries to handle static and dynamic TF shapes uniformly using prefer_static and tensorshape_util.
I used prefer_static wherever possible.
@jburnim can you please suggest a reviewer? CC @axch
I'll take a look at this.
Hi @SiegeLordEx, I have addressed your comments, could you have another look? Thanks