`hk.max_pool` window misalignment

Open SimonBiggs opened this issue 4 years ago • 3 comments

When a window is defined as such:

x = hk.max_pool(
    value=x,
    window_shape=(2, 2),
    strides=(2, 2),
    padding="VALID",
    channel_axis=-1,
)

The window gets misaligned: max-pooling is applied over the last spatial axis and the channel axis instead of over the two spatial axes. I believe the issue is in the following lines of code:

https://github.com/deepmind/dm-haiku/blob/7964d01f1c0dd907c8ea016ad1d1cc7ae48ac05d/haiku/_src/pool.py#L46-L47

Pulling that out gives the following result:

>>> import numpy as np
>>> from haiku._src.pool import _infer_shape

>>> shape = (2,16,16,5)
>>> x = np.ones(shape)
>>> window_size = (2, 2)
>>> channel_axis = -1

>>> _infer_shape(x, window_size, channel_axis)
(1, 1, 2, 2)
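
To see what that inferred window does to the output, here is a rough sketch of the shape arithmetic for `VALID` pooling. The helper `valid_pool_shape` is hypothetical (it is not part of Haiku) and only mirrors the usual `(d - w) // s + 1` rule per axis, assuming window and strides are already full rank:

```python
from typing import Sequence, Tuple


def valid_pool_shape(
    in_shape: Sequence[int],
    window: Sequence[int],
    strides: Sequence[int],
) -> Tuple[int, ...]:
    """Output shape of VALID pooling; all arguments must be full rank."""
    assert len(in_shape) == len(window) == len(strides)
    return tuple((d - w) // s + 1 for d, w, s in zip(in_shape, window, strides))


shape = (2, 16, 16, 5)
inferred = (1, 1, 2, 2)  # what _infer_shape returned above
intended = (1, 2, 2, 1)  # pool over the two spatial axes only

print(valid_pool_shape(shape, inferred, inferred))  # (2, 16, 8, 2)
print(valid_pool_shape(shape, intended, intended))  # (2, 8, 8, 5)
```

With the inferred window the first spatial axis is left untouched and the channel axis shrinks from 5 to 2, which matches the misalignment described above.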

Cheers :slightly_smiling_face:, Simon

SimonBiggs, Apr 27 '21 12:04

Wow this issue has been around for half a year? :sweat_smile:

I just stumbled across the same issue, the following behaviour is super weird:

import jax.numpy as jnp
import haiku as hk

a = jnp.zeros([1, 8, 8, 8])
pool1 = hk.max_pool(a, 2, 2, 'SAME')            # => (1, 4, 4, 8)
pool2 = hk.max_pool(a, [2, 2], [2, 2], 'SAME')  # => (1, 8, 4, 4)

khdlr, Nov 11 '21 09:11

Welp, I guess the docs do indeed specify that channel_axis is only used if window_shape/strides are an integer. It's still super confusing, but actually documented.

khdlr, Nov 11 '21 13:11

I ran into the same confusion; it originates in _infer_shape.

Description

MaxPool is calling max_pool, which eventually calls _infer_shape.

But in MaxPool the argument description is only

channel_axis: Axis of the spatial channels for which pooling is skipped.

while in max_pool, the description is

channel_axis: Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

This means that, if window_shape or strides are not ints, the channel_axis argument is ignored.

Reproduction

import jax.numpy as jnp
import haiku as hk

@hk.testing.transform_and_run()
def f(x):
    max_pool = hk.MaxPool(
        window_shape=(
            2,
            2,
            2,
        ),
        strides=(
            2,
            2,
            2,
        ),
        padding="VALID",
        channel_axis=-1,
    )
    return max_pool(x)

x = jnp.ones((2, 4, 6, 8, 3))
print(f(x).shape)

This prints

(2, 4, 3, 4, 1)

To get the shape right, we need to pass either the full shape (1,2,2,2,1) or a shape that omits only the batch axis (2,2,2,1). The docstring does not make this clear.
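
As a possible user-side workaround, the full-rank tuple can be built from the spatial sizes before calling the pooling layer. The helper below is only a sketch: `full_rank_window` is a hypothetical name, not a Haiku function, and it assumes axis 0 is the batch axis and that exactly one other axis is the channel axis.

```python
from typing import Sequence, Tuple


def full_rank_window(
    ndim: int,
    spatial: Sequence[int],
    channel_axis: int = -1,
) -> Tuple[int, ...]:
    """Expand spatial-only sizes to full rank (hypothetical helper).

    Assumes axis 0 is the batch axis; batch and channel axes get size 1.
    """
    if not -ndim <= channel_axis < ndim:
        raise ValueError(f"Invalid channel axis {channel_axis} for rank {ndim}")
    channel_axis %= ndim  # normalise negative axes
    if channel_axis == 0:
        raise ValueError("channel_axis must differ from the batch axis 0")
    if len(spatial) != ndim - 2:
        raise ValueError(f"Expected {ndim - 2} spatial sizes, got {len(spatial)}")
    sizes = iter(spatial)
    return tuple(
        1 if d in (0, channel_axis) else next(sizes) for d in range(ndim)
    )


window = full_rank_window(ndim=5, spatial=(2, 2, 2), channel_axis=-1)
print(window)  # (1, 2, 2, 2, 1)
```

Passing this tuple as both window_shape and strides in the reproduction above should pool only the three spatial axes, so the (2, 4, 6, 8, 3) input would be expected to come out as (2, 2, 3, 4, 3).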

The current _infer_shape has the following behaviour:

from typing import Union, Optional, Sequence, Tuple
import jax.numpy as jnp

def _infer_shape(
    x: jnp.ndarray,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = -1,
) -> Tuple[int, ...]:
    """Infer shape for pooling window or strides."""
    if isinstance(size, int):
        if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
            raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
        if channel_axis and channel_axis < 0:
            channel_axis = x.ndim + channel_axis
        return (1,) + tuple(size if d != channel_axis else 1
                            for d in range(1, x.ndim))
    elif len(size) < x.ndim:
        # Assume additional dimensions are batch dimensions.
        return (1,) * (x.ndim - len(size)) + tuple(size)
    else:
        assert x.ndim == len(size)
        return tuple(size)

x = jnp.ones((2, 4, 6, 8, 3))
print(_infer_shape(x, size=(2,2,2)))
print(_infer_shape(x, size=(2,2,2,1)))
print(_infer_shape(x, size=(1,2,2,2,1)))

This prints

(1, 1, 2, 2, 2)
(1, 2, 2, 2, 1)
(1, 2, 2, 2, 1)

Suggestion

We should at least update the docstring. Alternatively, the following two conditions should not be allowed to hold together:

  • size is a sequence of integers
  • channel_axis is defined
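
One way to read that second option is a stricter variant that raises instead of silently ignoring channel_axis. This is only a sketch, not Haiku's implementation: `infer_shape_strict` is a hypothetical name, it takes the rank `ndim` rather than an array to stay dependency-free, and it assumes the default for channel_axis would become None so that sequence sizes still work without extra arguments.

```python
from typing import Optional, Sequence, Tuple, Union


def infer_shape_strict(
    ndim: int,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = None,
) -> Tuple[int, ...]:
    """Hypothetical stricter take on _infer_shape (not the Haiku code)."""
    if isinstance(size, int):
        if channel_axis is None:
            # No channel axis given: pool every non-batch axis.
            return (1,) + (size,) * (ndim - 1)
        if not -ndim <= channel_axis < ndim:
            raise ValueError(f"Invalid channel axis {channel_axis} for rank {ndim}")
        channel_axis %= ndim  # normalise negative axes
        return (1,) + tuple(
            1 if d == channel_axis else size for d in range(1, ndim)
        )
    if channel_axis is not None:
        # The combination flagged above: refuse rather than ignore.
        raise ValueError(
            "channel_axis is only used when size is an int; "
            "pass a size that already has 1 on the channel axis"
        )
    if len(size) > ndim:
        raise ValueError(f"size {tuple(size)} has more entries than rank {ndim}")
    # Missing leading dimensions are treated as batch dimensions.
    return (1,) * (ndim - len(size)) + tuple(size)


print(infer_shape_strict(5, 2, channel_axis=-1))  # (1, 2, 2, 2, 1)
print(infer_shape_strict(5, (2, 2, 2, 1)))        # (1, 2, 2, 2, 1)
```

Calling `infer_shape_strict(5, (2, 2, 2), channel_axis=-1)` raises a ValueError here, which would have surfaced the confusion in the reproduction immediately.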

mathpluscode, Jan 03 '23 13:01