Pool functions reduce over the batch dimension and not the last dimension
For a multidimensional tensor, the pooling functions seem to reduce over the batch dimension (the first one) but do not allow reducing over the last dimension. This might be on purpose, but it is not at all clear from the documentation. I don't need a workaround, as I actually want to reduce over one of the middle dimensions, but I thought I should still report it.
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux, 6.12.4-arch1-1
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
  flax: 0.10.2
  jax: 0.4.38
  jaxlib: 0.4.38
- Python version: 3.12.8
- GPU/TPU model and memory: NVIDIA GeForce RTX 2060, 6144 MiB (There is another NVIDIA TITAN Xp GPU on that machine, but flax/jax is using the first one).
- CUDA version (if applicable): 12.6.r12.6
Problem you have encountered:
When using a pooling function such as avg_pool or max_pool from either flax.nnx or flax.linen with a window whose size is > 1 in the first dimension, the reduction is done over the batch dimension.
Furthermore, the window_shape and strides arguments do not accept a tuple that spans all dimensions, so one cannot pool over the last dimension.
What you expected to happen:
Either pooling should run separately for each batch entry, or one should be able to specify a window with the same number of dimensions as the whole tensor.
Logs, error messages, etc:
Steps to reproduce:
import jax.numpy as jnp
from flax import nnx
# create a tensor of shape (2, 3, 4) - i.e. with a batch size of 2
x = jnp.float32(range(24)).reshape((2,3,4))
print(x[0,0,0]) # 0.0
print(x[1,0,0]) # 12.0
x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')
# This shows that the reduction happened over the batch dimension
print(x_reduced.shape) # (1, 3, 4)
# This shows that the first reduced value is the average of x[0,0,0] and x[1,0,0]
print(x_reduced[0,0,0]) # 6.0
# Trying to reduce over the last dimension raises an exception
x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')
The exception message is:
AssertionError Traceback (most recent call last)
Cell In[23], line 1
----> 1 x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')
File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:97, in avg_pool(inputs, window_shape, strides, padding, count_include_pad)
79 def avg_pool(
80 inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
81 ):
82 """Pools the input by taking the average over a window.
83
84 Args:
(...)
95 The average for each window slice.
96 """
---> 97 y = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
98 if count_include_pad:
99 y = y / np.prod(window_shape)
File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:62, in pool(inputs, init, reduce_fn, window_shape, strides, padding)
59 dims = (1,) + dims
60 is_single_input = True
---> 62 assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'
63 if not isinstance(padding, str):
64 padding = tuple(map(tuple, padding))
AssertionError: len((2, 3, 4)) != len((1, 1, 4, 1))
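The numbers in the assertion message are consistent with the window being padded with one entry per leading batch axis and one trailing entry for the last axis. Below is a minimal pure-Python sketch of that bookkeeping. It is reconstructed from the traceback lines above and the observed shapes, so it is an assumption about the behavior rather than a copy of the Flax source, and the helper name padded_window is hypothetical.

```python
def padded_window(input_shape, window_shape):
    """Sketch of how the window seems to be expanded before the reduction.

    Assumption: the last axis is treated as features, the axes covered by
    window_shape are spatial, and any remaining leading axes are batch axes.
    """
    ndim = len(input_shape)
    num_batch_dims = ndim - (len(window_shape) + 1)
    # One 1 per batch axis, then the window, then a 1 for the feature axis.
    # A negative count yields an empty tuple, as in the failing call.
    dims = (1,) * num_batch_dims + tuple(window_shape) + (1,)
    if num_batch_dims == 0:
        # No batch axis left: a singleton batch axis is prepended
        # (compare "dims = (1,) + dims" in the traceback).
        input_shape = (1,) + tuple(input_shape)
        dims = (1,) + dims
    return tuple(input_shape), dims


# First call from the report: axis 0 is treated as spatial, not as batch.
print(padded_window((2, 3, 4), (2, 1)))     # ((1, 2, 3, 4), (1, 2, 1, 1))
# Second call: ranks disagree (3 vs 4), which matches the AssertionError.
print(padded_window((2, 3, 4), (1, 1, 4)))  # ((2, 3, 4), (1, 1, 4, 1))
# With a trailing dummy feature axis, axis 0 becomes a real batch axis.
print(padded_window((2, 3, 4, 1), (2, 1)))  # ((2, 3, 4, 1), (1, 2, 1, 1))
```

Under this reading, a window with ndim - 1 entries leaves no batch axis, which would explain why the first axis was averaged in the first call.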
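Hi @simonschoelly, the issue is that pooling layers treat the last dimension as the features dimension and don't operate over it, as most commonly you reduce over the time / space dimensions. In your first call the window has two entries for a three-dimensional input, so the input is treated as unbatched and the first dimension is pooled as a spatial dimension. Maybe you could add a dummy features dimension: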
x = jnp.float32(range(24)).reshape((2, 3, 4, 1))
x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')
print(x_reduced.shape) # (2, 2, 4, 1)