
Pool functions reduce over batch dimension and not last dimension

Open simonschoelly opened this issue 11 months ago • 1 comment

For a multidimensional tensor, the pooling functions seem to reduce over the batch dimension (the first one) but don't allow reducing over the last dimension. This might be intentional, but it is not at all clear from the documentation. I don't need a workaround, as I actually want to reduce over one of the middle dimensions, but I thought I should still report it.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux, 6.12.4-arch1-1
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
flax: 0.10.2
jax: 0.4.38
jaxlib: 0.4.38
  • Python version: 3.12.8
  • GPU/TPU model and memory: NVIDIA GeForce RTX 2060 6144MiB (there is another NVIDIA TITAN Xp GPU in that machine, but flax/jax uses the first one).
  • CUDA version (if applicable): 12.6.r12.6

Problem you have encountered:

When using a pooling function such as avg_pool or max_pool from either flax.nnx or flax.linen, specifying a window whose first dimension is larger than 1 causes the reduction to run over the batch dimension.

Furthermore, the window_shape and strides arguments don't accept a tuple spanning all of the input's dimensions, so one cannot pool over the last dimension.

What you expected to happen:

Either pooling should run separately for each batch entry, or one should be able to specify a window of the same dimensionality as the whole tensor.
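
For illustration, the first expectation can be emulated today with jax.vmap, which maps the pool over the leading axis so each entry is pooled independently. This is a minimal sketch with variable names of my own; inside the map each entry has shape (3, 4), so the window applies to the original middle dimension:

import jax
import jax.numpy as jnp
from flax import nnx

x = jnp.float32(range(24)).reshape((2, 3, 4))

# Map avg_pool over axis 0 so each (3, 4) entry is pooled on its own;
# window_shape=(2,) then reduces over the original middle dimension.
x_per_entry = jax.vmap(
    lambda xi: nnx.avg_pool(xi, window_shape=(2,), padding='VALID')
)(x)
print(x_per_entry.shape)  # (2, 2, 4) - batch dimension untouched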

Logs, error messages, etc:

Steps to reproduce:

import jax.numpy as jnp
from flax import nnx

# create a tensor of shape (2, 3, 4) - i.e. with a batch size of 2
x = jnp.float32(range(24)).reshape((2, 3, 4))

print(x[0,0,0]) # 0.0
print(x[1,0,0]) # 12.0

x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')

# This shows that the reduction happened over the batch dimension
print(x_reduced.shape) # (1, 3, 4)

# This shows that the first reduced value is the average of x[0,0,0] and x[1,0,0]
print(x_reduced[0,0,0]) # 6.0

# Trying to reduce over the last dimension raises an exception
x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')

The message of the exception is

AssertionError                            Traceback (most recent call last)
Cell In[23], line 1
----> 1 x_reduced = nnx.avg_pool(x, window_shape=(1, 1, 4), padding='VALID')

File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:97, in avg_pool(inputs, window_shape, strides, padding, count_include_pad)
     79 def avg_pool(
     80   inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
     81 ):
     82   """Pools the input by taking the average over a window.
     83 
     84   Args:
   (...)
     95     The average for each window slice.
     96   """
---> 97   y = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
     98   if count_include_pad:
     99     y = y / np.prod(window_shape)

File ~/myproject/.venv/lib/python3.12/site-packages/flax/linen/pooling.py:62, in pool(inputs, init, reduce_fn, window_shape, strides, padding)
     59   dims = (1,) + dims
     60   is_single_input = True
---> 62 assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'
     63 if not isinstance(padding, str):
     64   padding = tuple(map(tuple, padding))

AssertionError: len((2, 3, 4)) != len((1, 1, 4, 1))
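
Judging from the dims tuple in that assertion, pool seems to treat its input as (batch..., window dims..., features), aligning window_shape with the dimensions just before the last one. Under that reading (my inference from the traceback, not from the docs), a length-1 window pools the middle dimension and leaves the batch alone, which is a sketch of what I actually wanted:

import jax.numpy as jnp
from flax import nnx

x = jnp.float32(range(24)).reshape((2, 3, 4))

# With a length-1 window the input is read as (batch, spatial, features),
# so pooling applies to the middle dimension and the batch is preserved.
x_mid = nnx.avg_pool(x, window_shape=(2,), padding='VALID')
print(x_mid.shape)     # (2, 2, 4)
print(x_mid[0, 0, 0])  # 2.0 - average of x[0, 0, 0] (0.0) and x[0, 1, 0] (4.0)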

simonschoelly commented on Jan 21, 2025

Hi @simonschoelly, the issue is that pooling layers don't operate over the last dimension, since you most commonly reduce over the time / space dimensions. Maybe you could add a dummy features dimension:

x = jnp.float32(range(24)).reshape((2, 3, 4, 1))
x_reduced = nnx.avg_pool(x, window_shape=(2, 1), padding='VALID')
print(x_reduced.shape)  # (2, 2, 4, 1)
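
To pool over the last dimension of the original (2, 3, 4) tensor specifically, a sketch along the same lines: append the dummy axis at the end so the old last dimension becomes a window dimension (the expected outputs in the comments are my own arithmetic):

import jax.numpy as jnp
from flax import nnx

x = jnp.float32(range(24)).reshape((2, 3, 4))

# Append a dummy features axis so the original last dimension becomes
# a window dimension that avg_pool can reduce over.
x4 = x[..., None]  # shape (2, 3, 4, 1)
x_reduced = nnx.avg_pool(x4, window_shape=(1, 4), padding='VALID')
print(x_reduced.shape)        # (2, 3, 1, 1)
print(x_reduced[0, 0, 0, 0])  # 1.5 - average of x[0, 0, :] = [0, 1, 2, 3]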

cgarciae commented on Jan 23, 2025