
Flax avg_pool also includes padding tokens when computing the mean

marcvanzee opened this issue 2 years ago • 1 comment

I am not sure if this is a bug, but it is definitely a possible cause of confusion: we currently implement pooling.avg_pool as avg_pool(x) = lax.reduce_window(lax.add, x) / prod(window_size). When padding is used, we always divide by the full window size, even when the window contains padding tokens.

Example:

xs = np.array([1, 2]).reshape((1, 2, 1))
avg_pool(xs, window_shape=(2,), strides=(1,), padding="SAME").reshape((2,))
# Result: [1.5, 1. ]

Is this what we want? The first result (1+2)/2=1.5 makes sense, but the second result 2/2=1. is a bit odd. Shouldn't we do 2/1=2?
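The behavior described above can be reproduced directly with lax.reduce_window (a sketch of the mechanism, not Flax's exact implementation; the function name is illustrative):

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def naive_avg_pool(x, window_shape, strides, padding):
    """Average pooling that divides by the full window size.

    Assumes NWC/NHWC-style layout: batch first, features last.
    """
    dims = (1,) + window_shape + (1,)
    full_strides = (1,) + strides + (1,)
    # Sum each window; padded positions contribute 0 to the sum.
    summed = lax.reduce_window(x, 0.0, lax.add, dims, full_strides, padding)
    # Always divide by the full window size, even where the window
    # overlaps padding -- this is the questionable behavior.
    return summed / np.prod(window_shape)

xs = jnp.array([1.0, 2.0]).reshape((1, 2, 1))
print(naive_avg_pool(xs, (2,), (1,), "SAME").reshape((2,)))  # [1.5 1. ]
```

The second window covers only the real value 2 plus one padding slot, yet the sum is still divided by 2, giving 1.0 instead of 2.0.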

Other frameworks do it as follows:

  • PyTorch also uses zero padding, but it offers a manual override option (count_include_pad): https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
  • TF excludes padding tokens from the average and does not provide an override option.

Personally I feel that including padding tokens with value 0 is wrong (it seems like an arbitrary constant). At the very least we should be explicit about our choice and document it.

A possible way to implement average pooling that counts only non-padding tokens is to do a second sum pool over a tensor of ones with the same shape as the input, padded with 0s. Returning sum_pool(x) / sum_pool(ones) then divides each window's sum by the number of real tokens it covers, correctly ignoring the padding.
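The count-ones trick can be sketched as follows (a minimal illustration of the proposal, not Flax's actual code; the function name is hypothetical):

```python
import jax.numpy as jnp
from jax import lax

def avg_pool_exclude_padding(x, window_shape, strides, padding):
    """Average pooling that divides by the number of non-padding tokens.

    Assumes NWC/NHWC-style layout: batch first, features last.
    """
    dims = (1,) + window_shape + (1,)
    full_strides = (1,) + strides + (1,)
    # Sum of values in each window; padded positions contribute 0.
    summed = lax.reduce_window(x, 0.0, lax.add, dims, full_strides, padding)
    # Pool a tensor of ones the same way: each output position now holds
    # the count of real (non-padding) tokens in that window.
    counts = lax.reduce_window(
        jnp.ones_like(x), 0.0, lax.add, dims, full_strides, padding)
    return summed / counts

xs = jnp.array([1.0, 2.0]).reshape((1, 2, 1))
print(avg_pool_exclude_padding(xs, (2,), (1,), "SAME").reshape((2,)))  # [1.5 2. ]
```

With this scheme the second window divides 2 by a count of 1, yielding 2.0, which matches the intuition in the example above.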

marcvanzee avatar Aug 16 '22 11:08 marcvanzee

Discussed this offline with @jheek and @cgarciae. We agreed that the current behavior is not desirable: we assume that padding tokens for avg_pool are 0s and include them when computing the average, but we do not document this anywhere. TensorFlow has chosen to implement this differently, namely by excluding the padding tokens, and similarly they do not document this in their APIs. PyTorch seems to have the best of both worlds: they allow the user to specify the behavior via a flag. This seems like something we could do as well.

marcvanzee avatar Aug 18 '22 12:08 marcvanzee

I'm going to close this since the changes in #2448 seem to address the main point, feel free to reopen.

levskaya avatar Oct 19 '22 00:10 levskaya