Enable batch support for `windowed_mean|variance`

Open · nicolaspi opened this issue 3 years ago · 5 comments

This PR makes the functions `windowed_mean` and `windowed_variance` accept indices with batch dimensions.

Example:

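```python
import numpy as np
import tensorflow_probability as tfp

x = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=np.float32)
low_indices = [[0, 0, 0], [1, 0, 0], [2, 2, 0]]
high_indices = [[3, 3, 3], [1, 2, 3], [3, 2, 1]]
# Each row of the indices defines the windows for the matching row of x (axis=1).
tfp.stats.windowed_mean(x, low_indices=low_indices, high_indices=high_indices, axis=1)
```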

Now gives:

```
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[2. , 2. , 2. ],
       [0. , 1.5, 2. ],
       [3. , 0. , 1. ]], dtype=float32)>
```

Previously, this failed with:

```
tensorflow.python.framework.errors_impl.InvalidArgumentError: required broadcastable shapes [Op:SelectV2]
```
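To illustrate what "batched windows" means numerically, here is a minimal NumPy sketch that reproduces the values above using exclusive prefix sums and `np.take_along_axis`. It is an illustration under stated assumptions, not the TFP implementation: `batched_windowed_mean` is a hypothetical name, the indices are assumed to have the same rank as `x`, windows are treated as half-open `[low, high)`, and empty windows return 0 to match the output shown.

```python
import numpy as np


def batched_windowed_mean(x, low_indices, high_indices, axis=-1):
    """Hypothetical helper: mean of the window [low, high) along `axis`.

    Sketch only: assumes the index arrays have the same rank as `x` and
    broadcast against it on every dimension except `axis` (the requirement
    of np.take_along_axis). Empty windows (high <= low) return 0.
    """
    x = np.asarray(x, dtype=np.float64)
    low = np.asarray(low_indices)
    high = np.asarray(high_indices)
    axis = axis % x.ndim

    # Exclusive prefix sums along `axis`: csum[..., i, ...] is the sum of the
    # first i elements, so a window sum is csum[high] - csum[low].
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (1, 0)
    csum = np.pad(np.cumsum(x, axis=axis), pad_width)

    # Gather per batch member: each row of the indices reads its own row of csum.
    window_sum = (np.take_along_axis(csum, high, axis=axis)
                  - np.take_along_axis(csum, low, axis=axis))
    count = high - low

    # Guard the division, then zero out empty (or inverted) windows.
    return np.where(count > 0, window_sum / np.maximum(count, 1), 0.0)


# Same inputs as the example above.
x = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=np.float32)
low_indices = [[0, 0, 0], [1, 0, 0], [2, 2, 0]]
high_indices = [[3, 3, 3], [1, 2, 3], [3, 2, 1]]
# Expected to match the tf.Tensor values shown above.
print(batched_windowed_mean(x, low_indices, high_indices, axis=1))
```

The prefix-sum form keeps each window sum at two gathers regardless of window length; `windowed_variance` could follow the same pattern with a second prefix sum of `x**2`, at some cost in numerical stability.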

nicolaspi, Aug 08 '22

@axch I made changes in the code you authored, could you kindly have a look at this PR? Thanks

nicolaspi, Aug 09 '22

@nicolaspi thanks for the contribution! I am no longer an active maintainer of TFP, so I'm not really in a position to review your PR in detail (@jburnim please suggest someone?). On a quick look, though, I see a couple of potential code style issues:

  • Do we need the dependency on tf.experimental.numpy?
  • Do we need the special case for rank-1 indices? Could we define a more uniform behavior instead?
  • I'm guessing some of the shape munging already has relevant helpers defined elsewhere in TFP, but I don't remember off-hand
  • TFP generally tries to handle static and dynamic TF shapes uniformly using prefer_static and tensorshape_util.

axch, Aug 09 '22

Thanks for your feedback!

  • Do we need the dependency on tf.experimental.numpy?

We specifically need the `take_along_axis` function, which gathers slices along each batch dimension. I replaced the `experimental` import path with `from tensorflow.python.ops import numpy_ops`.
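The distinction that makes `take_along_axis` necessary can be seen in plain NumPy, whose API TF's NumPy module mirrors (NumPy is used here only to keep the sketch dependency-free): with `take_along_axis`, each batch row of the index array reads from its own row of the data, whereas `np.take` applies every index to every row.

```python
import numpy as np

# Two batch members, three samples each.
x = np.array([[10, 11, 12],
              [20, 21, 22]])
# Row i of `idx` lists the positions to read from row i of `x`.
idx = np.array([[2, 0],
                [1, 1]])

# Per-row gather along axis=1: result has the same shape as `idx`.
picked = np.take_along_axis(x, idx, axis=1)
print(picked)  # [[12 10]
               #  [21 21]]

# By contrast, np.take reuses the whole index array for every row,
# producing shape (2, 2, 2) instead of (2, 2).
print(np.take(x, idx, axis=1).shape)  # (2, 2, 2)
```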

  • Do we need the special case for rank-1 indices? Could we define a more uniform behavior instead?

There are two motivations for this case. First, backward compatibility: it is equivalent to the legacy, non-batched usage. Second, it is the only case I can think of where broadcasting is unambiguous when rank(indices) < rank(x).
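To make the rank-1 argument concrete, here is a minimal NumPy sketch of one way a rank-1 index vector could be lifted to the rank of `x`: its entries are placed along `axis` with size 1 on every other dimension, so broadcasting applies the same windows to every batch member, as in the non-batched usage. `expand_rank1_indices` is a hypothetical helper, not a TFP function. For an intermediate rank (say rank-2 indices with a rank-3 `x`) there is no single obvious way to align dimensions, which is the ambiguity described above.

```python
import numpy as np


def expand_rank1_indices(indices, x_ndim, axis):
    """Hypothetical helper: reshape a rank-1 index vector to rank `x_ndim`.

    The vector's entries go along `axis`; every other dimension has size 1,
    so the same indices broadcast over all batch members.
    """
    indices = np.asarray(indices)
    if indices.ndim != 1:
        return indices  # Already batched: leave unchanged.
    shape = [1] * x_ndim
    shape[axis % x_ndim] = indices.shape[0]
    return indices.reshape(shape)


x = np.arange(12.0).reshape(3, 4)  # 3 batch members, 4 samples each.
low = expand_rank1_indices([0, 1], x.ndim, axis=1)
print(low.shape)  # (1, 2)

# The size-1 batch dimension broadcasts, so every row reads positions 0 and 1.
print(np.take_along_axis(x, low, axis=1))
```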

  • I'm guessing some of the shape munging already has relevant helpers defined elsewhere in TFP, but I don't remember off-hand

In any case, I modified the unit tests to test against non-static shapes.

I used `prefer_static` wherever possible.

nicolaspi, Aug 10 '22

@jburnim can you please suggest a reviewer? CC @axch

nicolaspi, Sep 20 '22

I'll take a look at this.

SiegeLordEx, Sep 20 '22

Hi @SiegeLordEx, I have addressed your comments. Could you have a look? Thanks

nicolaspi, Oct 06 '22