dm-haiku
`hk.max_pool` window misalignment
When a window is defined as such:
```python
x = hk.max_pool(
    value=x,
    window_shape=(2, 2),
    strides=(2, 2),
    padding="VALID",
    channel_axis=-1,
)
```
The window ends up misaligned, and max-pooling is applied along the channel axis. I believe the issue is in the following lines of code:
https://github.com/deepmind/dm-haiku/blob/7964d01f1c0dd907c8ea016ad1d1cc7ae48ac05d/haiku/_src/pool.py#L46-L47
Pulling that out gives the following result:
```python
>>> import numpy as np
>>> from haiku._src.pool import _infer_shape
>>> shape = (2, 16, 16, 5)
>>> x = np.ones(shape)
>>> window_size = (2, 2)
>>> channel_axis = -1
>>> _infer_shape(x, window_size, channel_axis)
(1, 1, 2, 2)
```
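To see what that window does to the output, here is a minimal pure-Python sketch of the output shape of a VALID pooling op. `valid_pool_shape` is a hypothetical helper written for illustration, not part of Haiku; it assumes VALID padding, where each axis of size `n` with window `w` and stride `s` yields `(n - w) // s + 1` windows.

```python
def valid_pool_shape(shape, window, strides):
    """Output shape of a VALID pooling op (hypothetical helper, illustration only)."""
    assert len(shape) == len(window) == len(strides)
    # With VALID padding, each axis yields floor((n - w) / s) + 1 windows.
    return tuple((n - w) // s + 1 for n, w, s in zip(shape, window, strides))


shape = (2, 16, 16, 5)  # NHWC input from the example above

# The window _infer_shape builds from (2, 2): it is right-aligned,
# so it covers W and C instead of H and W.
print(valid_pool_shape(shape, (1, 1, 2, 2), (1, 1, 2, 2)))  # (2, 16, 8, 2)

# The intended window: pool H and W, leave batch and channels untouched.
print(valid_pool_shape(shape, (1, 2, 2, 1), (1, 2, 2, 1)))  # (2, 8, 8, 5)
```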
Cheers 🙂, Simon
Wow, this issue has been around for half a year? 😅
I just stumbled across the same issue; the following behaviour is super weird:
```python
import jax.numpy as jnp
import haiku as hk

a = jnp.zeros([1, 8, 8, 8])
pool1 = hk.max_pool(a, 2, 2, 'SAME')  # => (1, 4, 4, 8)
pool2 = hk.max_pool(a, [2, 2], [2, 2], 'SAME')  # => (1, 8, 4, 4)
```
Welp, I guess the docs do indeed specify that `channel_axis` is only used if `window_shape`/`strides` are integers. It's still super confusing, but it is documented.
I ran into the same confusion; it comes from `_infer_shape`.
Description
`MaxPool` calls `max_pool`, which in turn calls `_infer_shape`.
However, the `MaxPool` docstring describes the argument only as:
channel_axis: Axis of the spatial channels for which pooling is skipped.
whereas the `max_pool` docstring says:
channel_axis: Axis of the spatial channels for which pooling is skipped, used to infer `window_shape` or `strides` if they are an integer.
This means that if `window_shape` or `strides` are not ints, the `channel_axis` argument is ignored.
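Given this behaviour, one workaround is to always pass full-rank `window_shape`/`strides`. Below is a minimal sketch of a hypothetical helper (`full_rank_window` is not part of Haiku) that builds them from the spatial sizes and a `channel_axis`, assuming axis 0 is the only batch axis and every other non-channel axis is spatial.

```python
def full_rank_window(ndim, spatial, channel_axis=-1):
    """Expand spatial-only sizes to one entry per axis (hypothetical helper).

    Assumes axis 0 is the batch axis and all axes except `channel_axis`
    and the batch axis are spatial.
    """
    if not -ndim <= channel_axis < ndim:
        raise ValueError(f"Invalid channel axis {channel_axis} for rank {ndim}")
    channel_axis %= ndim  # normalise negative axes
    if channel_axis == 0:
        raise ValueError("channel_axis must not be the batch axis 0")
    if len(spatial) != ndim - 2:
        raise ValueError(f"expected {ndim - 2} spatial sizes, got {len(spatial)}")
    sizes = iter(spatial)
    # 1 on batch and channel axes, spatial sizes (in order) everywhere else.
    return tuple(1 if d in (0, channel_axis) else next(sizes) for d in range(ndim))


print(full_rank_window(5, (2, 2, 2)))               # (1, 2, 2, 2, 1)  NDHWC
print(full_rank_window(4, (2, 2), channel_axis=1))  # (1, 1, 2, 2)     NCHW
```

The result can be passed as both `window_shape` and `strides`, so the channel axis is never pooled regardless of how `_infer_shape` treats sequences.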
Reproduction
```python
import jax.numpy as jnp
import haiku as hk


@hk.testing.transform_and_run()
def f(x):
    max_pool = hk.MaxPool(
        window_shape=(2, 2, 2),
        strides=(2, 2, 2),
        padding="VALID",
        channel_axis=-1,
    )
    return max_pool(x)


x = jnp.ones((2, 4, 6, 8, 3))
print(f(x).shape)
```
This prints
(2, 4, 3, 4, 1)
To get the intended shape, we need to pass either the full shape (1, 2, 2, 2, 1) or a shape that omits only the batch axis, (2, 2, 2, 1). The docstring does not make this clear.
The current `_infer_shape` behaves as follows:
```python
from typing import Optional, Sequence, Tuple, Union

import jax.numpy as jnp


def _infer_shape(
    x: jnp.ndarray,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = -1,
) -> Tuple[int, ...]:
    """Infer shape for pooling window or strides."""
    if isinstance(size, int):
        if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
            raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
        if channel_axis and channel_axis < 0:
            channel_axis = x.ndim + channel_axis
        return (1,) + tuple(size if d != channel_axis else 1
                            for d in range(1, x.ndim))
    elif len(size) < x.ndim:
        # Assume additional dimensions are batch dimensions.
        return (1,) * (x.ndim - len(size)) + tuple(size)
    else:
        assert x.ndim == len(size)
        return tuple(size)


x = jnp.ones((2, 4, 6, 8, 3))
print(_infer_shape(x, size=(2, 2, 2)))
print(_infer_shape(x, size=(2, 2, 2, 1)))
print(_infer_shape(x, size=(1, 2, 2, 2, 1)))
```
This prints
(1, 1, 2, 2, 2)
(1, 2, 2, 2, 1)
(1, 2, 2, 2, 1)
Suggestion
At a minimum, we should update the docstring. Beyond that, the following two conditions should not be allowed to occur together:
- size is a sequence of integers
- channel_axis is defined
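The second option could look roughly like the following sketch. `infer_shape_strict` is a hypothetical name, it takes the array rank instead of the array, and defaulting `channel_axis` to `None` is an assumption: with the current default of `-1`, `channel_axis` is always "defined", so it could never be detected as explicitly passed.

```python
from typing import Optional, Sequence, Tuple, Union


def infer_shape_strict(
    ndim: int,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = None,
) -> Tuple[int, ...]:
    """Hypothetical stricter variant of _infer_shape (sketch, not Haiku's code)."""
    if isinstance(size, int):
        # Integer size: keep the current behaviour, with -1 as the implicit default.
        axis = -1 if channel_axis is None else channel_axis
        if not -ndim <= axis < ndim:
            raise ValueError(f"Invalid channel axis {axis} for rank {ndim}")
        axis %= ndim
        # The batch axis 0 and the channel axis are not pooled.
        return tuple(1 if d in (0, axis) else size for d in range(ndim))
    if channel_axis is not None:
        # Sequence size plus an explicit channel_axis: refuse instead of ignoring it.
        raise ValueError(
            "channel_axis is only used when window_shape/strides are integers; "
            "pass a full-rank sequence instead")
    if len(size) > ndim:
        raise ValueError(f"size {tuple(size)} has more entries than rank {ndim}")
    # Same right-aligned padding with leading 1s as the current behaviour.
    return (1,) * (ndim - len(size)) + tuple(size)


print(infer_shape_strict(5, 2))                # (1, 2, 2, 2, 1)
print(infer_shape_strict(5, (1, 2, 2, 2, 1)))  # (1, 2, 2, 2, 1)
try:
    infer_shape_strict(5, (2, 2, 2), channel_axis=-1)
except ValueError as err:
    print("rejected:", err)
```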