
Adding "count_include_pad" argument to flax.linen.pooling.avg_pool

Open • dslisleedh opened this pull request • 11 comments

What does this PR do?

The current flax.linen.pooling.avg_pool divides each window sum by the full window size, so padded elements are counted in the average. This PR adds a count_include_pad argument that controls whether padded elements are included in that count.
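To make the two modes concrete, here is a minimal sketch in JAX of one way an average pool can optionally exclude padding: divide the window sum by a per-window count obtained by pooling an all-ones array the same way. It is an illustration under stated assumptions, not necessarily the code in this PR: the name avg_pool_sketch is hypothetical, the input is assumed to have a leading batch axis and a trailing feature axis, only the string paddings "VALID" and "SAME" are handled, and count_include_pad defaults to True to keep the existing behavior.

```python
import math

import jax.numpy as jnp
from jax import lax


def avg_pool_sketch(x, window_shape, strides=None, padding="VALID",
                    count_include_pad=True):
    """Hypothetical average pool over the spatial axes of a (batch, ..., features) array.

    Assumes `padding` is the string "VALID" or "SAME".
    """
    strides = tuple(strides) if strides is not None else (1,) * len(window_shape)
    # Window/stride of 1 on the batch and feature axes: pool spatial axes only.
    dims = (1,) + tuple(window_shape) + (1,)
    strd = (1,) + strides + (1,)
    # Sum over each window; padded positions contribute 0 to the sum.
    window_sum = lax.reduce_window(x, 0.0, lax.add, dims, strd, padding)
    if count_include_pad:
        # Divide by the full window size, so padded positions count toward the mean.
        return window_sum / math.prod(window_shape)
    # Pool an all-ones array the same way to count the real elements per window.
    counts = lax.reduce_window(jnp.ones_like(x), 0.0, lax.add, dims, strd, padding)
    return window_sum / counts
```

With VALID padding no element is padded, so both modes give the same result; they only differ at the borders when SAME padding adds positions.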

Checklist

  • [x] This change is discussed in a GitHub issue.
  • [x] The documentation and docstrings adhere to the documentation guidelines.
  • [x] This change includes necessary high-coverage tests. (No quality testing = no merge!)

dslisleedh • Sep 09 '22 08:09

Codecov Report

Merging #2451 (d4fa16d) into main (8687673) will decrease coverage by 0.83%. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #2451      +/-   ##
==========================================
- Coverage   79.66%   78.83%   -0.84%     
==========================================
  Files          49       49              
  Lines        4982     5070      +88     
==========================================
+ Hits         3969     3997      +28     
- Misses       1013     1073      +60     

Impacted file                  Coverage  Diff coverage  Δ
flax/linen/pooling.py          86.48%    100.00%        +2.11%
flax/training/checkpoints.py   61.35%    0.00%          -11.86%
flax/errors.py                 87.20%    0.00%          -0.75%
flax/traverse_util.py          98.52%    0.00%          -0.50%
flax/linen/linear.py           97.50%    0.00%          ø
flax/core/lift.py              95.81%    0.00%          +<0.01%
flax/linen/transforms.py       94.06%    0.00%          +0.02%
flax/linen/module.py           92.75%    0.00%          +0.03%
flax/serialization.py          69.27%    0.00%          +1.39%


codecov-commenter • Sep 09 '22 12:09

Hey @dslisleedh, thanks for creating this PR! I think this looks good; however, before merging we should probably add a test that uses this flag.

cgarciae • Sep 13 '22 13:09

To @cgarciae,

Sorry, I tested it myself but didn't share the results in the PR. Here they are:

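import jax.numpy as jnp
from flax.linen.pooling import avg_pool

# Sweep 1-D, 2-D and 3-D all-ones inputs (batch of 1, 3 features).
for i, shape in enumerate([(1, 5, 3), (1, 5, 5, 3), (1, 5, 5, 5, 3)]):
    inputs = jnp.ones(shape=shape)
    for kernel_size in [(1,), (3,), (5,)]:
        k_s = kernel_size * (i + 1)  # e.g. (3,) -> (3, 3) for 2-D inputs
        for strides in [(1,), (2,), (3,)]:
            s = strides * (i + 1)
            for padding in ['VALID', 'SAME']:
                res = avg_pool(inputs, k_s, strides=s, padding=padding,
                               count_include_pad=False)
                # Averaging only real (non-padded) ones should give exactly 1.0.
                print(jnp.min(res), jnp.max(res))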

1.0 1.0   (the same line was printed for all 54 combinations of shape, kernel size, stride and padding)
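Because every input in the sweep above is 1.0, averaging only the real elements gives exactly 1.0 by construction, so the sweep checks the count_include_pad=False path but does not show how it differs from the default. A small illustration computed directly with jax.lax.reduce_window (not flax's code; the 5x5 input, 3x3 window and variable names are chosen here for illustration) shows the two divisors for the 2-D case with SAME padding:

```python
import jax.numpy as jnp
from jax import lax

# All-ones 2-D input (batch 1, 5x5 spatial, 3 features), as in the sweep above.
x = jnp.ones((1, 5, 5, 3))
window = (1, 3, 3, 1)   # 3x3 spatial window; batch/feature axes untouched
strides = (1, 1, 1, 1)

# With all-ones input, each window sum equals the number of real elements in it.
window_sum = lax.reduce_window(x, 0.0, lax.add, window, strides, "SAME")

# count_include_pad=True: divide by the full window size (3 * 3 = 9).
include_pad = window_sum / 9.0
# count_include_pad=False: divide by the real-element count (here window_sum itself).
exclude_pad = window_sum / window_sum

# include_pad: corners 4/9, edges 6/9, interior 1.0; exclude_pad: 1.0 everywhere.
```

So with the default behavior, the border outputs of an all-ones input fall below 1.0, which is exactly the effect this PR lets users switch off.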

This is my first time creating a PR, so please tell me if I missed anything. Thank you.

dslisleedh • Sep 13 '22 22:09

I just noticed there aren't any tests for pooling.py. I'll create an issue so we can add a test for this in the future.

cgarciae • Sep 14 '22 15:09

I found the PoolTest class in ./tests/linen/linen_test.py:

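from absl.testing import absltest
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np

class PoolTest(absltest.TestCase):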

  def test_pool_custom_reduce(self):
    x = jnp.full((1, 3, 3, 1), 2.)
    mul_reduce = lambda x, y: x * y
    y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

  def test_avg_pool(self):
    x = jnp.full((1, 3, 3, 1), 2.)
    pool = lambda x: nn.avg_pool(x, (2, 2))
    y = pool(x)
    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
    y_grad = jax.grad(lambda x: pool(x).sum())(x)
    expected_grad = jnp.array([
        [0.25, 0.5, 0.25],
        [0.5, 1., 0.5],
        [0.25, 0.5, 0.25],
    ]).reshape((1, 3, 3, 1))
    np.testing.assert_allclose(y_grad, expected_grad)

  def test_avg_pool_no_batch(self):
    x = jnp.full((3, 3, 1), 2.)
    pool = lambda x: nn.avg_pool(x, (2, 2))
    y = pool(x)
    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
    y_grad = jax.grad(lambda x: pool(x).sum())(x)
    expected_grad = jnp.array([
        [0.25, 0.5, 0.25],
        [0.5, 1., 0.5],
        [0.25, 0.5, 0.25],
    ]).reshape((3, 3, 1))
    np.testing.assert_allclose(y_grad, expected_grad)

  def test_max_pool(self):
    ...

When I ran these tests with this PR's changes applied, all of them passed:

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py 
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] IdsTest.test_hashable
[       OK ] IdsTest.test_hashable
[ RUN      ] NormalizationTest.test_batch_norm
I0915 01:20:42.378859 4314512704 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0915 01:20:42.379297 4314512704 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0915 01:20:42.379443 4314512704 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0915 01:20:42.379549 4314512704 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0915 01:20:42.380259 4314512704 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] NormalizationTest.test_batch_norm
[ RUN      ] NormalizationTest.test_batch_norm_complex
[       OK ] NormalizationTest.test_batch_norm_complex
[ RUN      ] NormalizationTest.test_batch_norm_multi_init
[       OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN      ] NormalizationTest.test_group_norm
[       OK ] NormalizationTest.test_group_norm
[ RUN      ] NormalizationTest.test_group_norm_raises
[       OK ] NormalizationTest.test_group_norm_raises
[ RUN      ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[       OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN      ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[       OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN      ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[       OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN      ] PoolTest.test_avg_pool
[       OK ] PoolTest.test_avg_pool
[ RUN      ] PoolTest.test_avg_pool_no_batch
[       OK ] PoolTest.test_avg_pool_no_batch
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
[ RUN      ] RecurrentTest.test_complex_input_gru
[       OK ] RecurrentTest.test_complex_input_gru
[ RUN      ] RecurrentTest.test_convlstm
[       OK ] RecurrentTest.test_convlstm
[ RUN      ] RecurrentTest.test_gru
[       OK ] RecurrentTest.test_gru
[ RUN      ] RecurrentTest.test_lstm
[       OK ] RecurrentTest.test_lstm
[ RUN      ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
  warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[       OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN      ] StochasticTest.test_dropout
[       OK ] StochasticTest.test_dropout
[ RUN      ] StochasticTest.test_dropout_rate_limits
[       OK ] StochasticTest.test_dropout_rate_limits
[ RUN      ] StochasticTest.test_dropout_rate_stats
[       OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 21 tests in 13.907s

OK

dslisleedh • Sep 14 '22 16:09

@dslisleedh, can you create one or more tests under PoolTest using this new argument? You can copy an existing test, rename it, and change it to use the new count_include_pad argument. This would help immensely, since these tests are run automatically before new code is merged.

cgarciae • Sep 19 '22 15:09

To @cgarciae

I added some tests to PoolTest; here is the code I tested:

from absl.testing import absltest, parameterized

from flax import linen as nn

import jax
import jax.numpy as jnp

import numpy as np

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()


class PoolTest(parameterized.TestCase):

    def test_pool_custom_reduce(self):
        x = jnp.full((1, 3, 3, 1), 2.)
        mul_reduce = lambda x, y: x * y
        y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool(self, count_include_pad):
        x = jnp.full((1, 3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
        y = pool(x)
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool_no_batch(self, count_include_pad):
        x = jnp.full((3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
        y = pool(x)
        np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    def test_max_pool(self):
        x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
        pool = lambda x: nn.max_pool(x, (2, 2))
        expected_y = jnp.array([
            [4., 5.],
            [7., 8.],
        ]).reshape((1, 2, 2, 1))
        y = pool(x)
        np.testing.assert_allclose(y, expected_y)
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0., 0., 0.],
            [0., 1., 1.],
            [0., 1., 1.],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)


if __name__ == '__main__':
  absltest.main()
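Why both parameterizations can share the same expected values: these tests use the default VALID padding, so no element is padded and the divisor is the full window size (4) in either mode. The expected_grad entries then follow from counting how many 2x2 windows contain each element of the 3x3 input (n below):

```latex
y_{pq} = \frac{1}{4} \sum_{(i,j) \in W_{pq}} x_{ij}, \qquad
\frac{\partial}{\partial x_{ij}} \sum_{p,q} y_{pq} = \frac{n_{ij}}{4}, \qquad
n = \begin{pmatrix} 1 & 2 & 1 \\ 2 & 4 & 2 \\ 1 & 2 & 1 \end{pmatrix}
\;\Rightarrow\;
\frac{n}{4} = \begin{pmatrix} 0.25 & 0.5 & 0.25 \\ 0.5 & 1 & 0.5 \\ 0.25 & 0.5 & 0.25 \end{pmatrix}
```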

When I tested the previous version of the code, the non-batched avg_pool case failed, so I fixed it in the PR. Thanks for pushing me to test it properly.

Below is the test result with the modified code.

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python pool_test.py 
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] PoolTest.test_avg_pool0 (count_include_pad=True)
I0920 20:09:02.610634 4371807552 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0920 20:09:02.610964 4371807552 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0920 20:09:02.611104 4371807552 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0920 20:09:02.611212 4371807552 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0920 20:09:02.611675 4371807552 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
----------------------------------------------------------------------
Ran 6 tests in 1.586s

OK

Thank you.

dslisleedh • Sep 20 '22 11:09

Awesome @dslisleedh! Can you commit the changes to the tests?

cgarciae • Sep 20 '22 14:09

To @cgarciae

Sure :)

dslisleedh • Sep 20 '22 15:09

@dslisleedh, can you add this test? None of the other tests use padding="SAME", which is the more interesting case.

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool_padding_same(self, count_include_pad):
        x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
        pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
        y = pool(x)
        if count_include_pad:
            expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
        else:
            expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
        np.testing.assert_allclose(y, expected_y)
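The expected values in this test follow from how SAME padding is laid out. A small sketch using jax.lax.padtype_to_pads and jax.lax.reduce_window (an illustration, not flax's implementation) shows that a 2x2 window with stride 1 on a 2x2 input pads one row and one column on the high side only, which yields window sums 10, 6, 7, 4 and real-element counts 4, 2, 2, 1:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
window = (1, 2, 2, 1)   # 2x2 spatial window
strides = (1, 1, 1, 1)

# (low, high) padding per axis chosen by "SAME": nothing on the low side,
# one extra row/column on the high side of each spatial axis.
pads = [(int(lo), int(hi))
        for lo, hi in lax.padtype_to_pads(x.shape, window, strides, "SAME")]

# Window sums (padded positions contribute 0) and real elements per window.
sums = lax.reduce_window(x, 0.0, lax.add, window, strides, "SAME")
counts = lax.reduce_window(jnp.ones_like(x), 0.0, lax.add, window, strides, "SAME")

# Spatial view: sums == [[10, 6], [7, 4]], counts == [[4, 2], [2, 1]].
```

Dividing sums by 4 reproduces the count_include_pad=True expectation, and dividing by counts reproduces the False one.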

cgarciae • Sep 20 '22 22:09

@cgarciae Oh, I forgot about that case. Thank you.

Here is the result with your test added:

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py 
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] IdsTest.test_hashable
[       OK ] IdsTest.test_hashable
[ RUN      ] NormalizationTest.test_batch_norm
I0921 22:23:30.726801 4309613888 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0921 22:23:30.727137 4309613888 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0921 22:23:30.727280 4309613888 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0921 22:23:30.727386 4309613888 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0921 22:23:30.727836 4309613888 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] NormalizationTest.test_batch_norm
[ RUN      ] NormalizationTest.test_batch_norm_complex
[       OK ] NormalizationTest.test_batch_norm_complex
[ RUN      ] NormalizationTest.test_batch_norm_multi_init
[       OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN      ] NormalizationTest.test_group_norm
[       OK ] NormalizationTest.test_group_norm
[ RUN      ] NormalizationTest.test_group_norm_raises
[       OK ] NormalizationTest.test_group_norm_raises
[ RUN      ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[       OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN      ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[       OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN      ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[       OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN      ] PoolTest.test_avg_pool0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
[ RUN      ] RecurrentTest.test_complex_input_gru
[       OK ] RecurrentTest.test_complex_input_gru
[ RUN      ] RecurrentTest.test_convlstm
[       OK ] RecurrentTest.test_convlstm
[ RUN      ] RecurrentTest.test_gru
[       OK ] RecurrentTest.test_gru
[ RUN      ] RecurrentTest.test_lstm
[       OK ] RecurrentTest.test_lstm
[ RUN      ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
  warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[       OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN      ] StochasticTest.test_dropout
[       OK ] StochasticTest.test_dropout
[ RUN      ] StochasticTest.test_dropout_rate_limits
[       OK ] StochasticTest.test_dropout_rate_limits
[ RUN      ] StochasticTest.test_dropout_rate_stats
[       OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 25 tests in 14.023s

OK

dslisleedh • Sep 21 '22 13:09