Adding "count_include_pad" argument to flax.linen.pooling.avg_pool
What does this PR do?
The current version of flax.linen.pooling.avg_pool includes padded tokens in the window sum when computing the average. This PR adds a `count_include_pad` argument that controls whether padded tokens are counted.
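For context, the intended semantics can be sketched in plain NumPy. This is a hypothetical 1-D helper for illustration only, not the actual Flax implementation (which uses `jax.lax` reduce windows):

```python
import numpy as np

def avg_pool_1d(x, window, count_include_pad):
    """Toy 1-D average pool, stride 1, SAME-style zero padding on the right.

    With count_include_pad=True, each window sum is divided by the full
    window size; with False, it is divided by the number of non-padded
    elements actually inside the window.
    """
    n = len(x)
    padded = np.concatenate([x, np.zeros(window - 1)])
    out = []
    for i in range(n):
        win = padded[i:i + window]
        valid = min(window, n - i)  # positions that fall inside the real input
        denom = window if count_include_pad else valid
        out.append(win.sum() / denom)
    return np.array(out)

x = np.array([1.0, 2.0, 3.0, 4.0])
# Last window is [4, pad]: sum 4, divided by 2 vs. by 1.
print(avg_pool_1d(x, 2, count_include_pad=True))   # averages 1.5, 2.5, 3.5, 2.0
print(avg_pool_1d(x, 2, count_include_pad=False))  # averages 1.5, 2.5, 3.5, 4.0
```

Only the boundary windows differ between the two modes; interior windows contain no padding, so both settings agree there.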
Checklist
- [x] This change is discussed in a GitHub issue.
- [x] The documentation and docstrings adhere to the documentation guidelines.
- [x] This change includes necessary high-coverage tests. (No quality testing = no merge!)
Codecov Report
Merging #2451 (d4fa16d) into main (8687673) will decrease coverage by 0.83%. The diff coverage is 100.00%.
@@ Coverage Diff @@
## main #2451 +/- ##
==========================================
- Coverage 79.66% 78.83% -0.84%
==========================================
Files 49 49
Lines 4982 5070 +88
==========================================
+ Hits 3969 3997 +28
- Misses 1013 1073 +60
Impacted Files | Coverage Δ |
---|---|
flax/linen/pooling.py | 86.48% <100.00%> (+2.11%) :arrow_up: |
flax/training/checkpoints.py | 61.35% <0.00%> (-11.86%) :arrow_down: |
flax/errors.py | 87.20% <0.00%> (-0.75%) :arrow_down: |
flax/traverse_util.py | 98.52% <0.00%> (-0.50%) :arrow_down: |
flax/linen/linear.py | 97.50% <0.00%> (ø) |
flax/core/lift.py | 95.81% <0.00%> (+<0.01%) :arrow_up: |
flax/linen/transforms.py | 94.06% <0.00%> (+0.02%) :arrow_up: |
flax/linen/module.py | 92.75% <0.00%> (+0.03%) :arrow_up: |
flax/serialization.py | 69.27% <0.00%> (+1.39%) :arrow_up: |
Hey @dslisleedh, thanks for creating this PR! I think this looks good, however, before merging we should probably add a test using this flag.
To @cgarciae,
Sorry, I tested it myself but didn't share the results in the PR. Here they are.
for i, shape in enumerate([(1, 5, 3), (1, 5, 5, 3), (1, 5, 5, 5, 3)]):
    inputs = jnp.ones(shape=shape)
    for kernel_size in [(1,), (3,), (5,)]:
        k_s = kernel_size * (i + 1)
        for strides in [(1,), (2,), (3,)]:
            s = strides * (i + 1)
            for padding in ['VALID', 'SAME']:
                res = avg_pool(inputs, k_s, strides=s, padding=padding, count_include_pad=False)
                print(jnp.min(res), jnp.max(res))
1.0 1.0
(the line `1.0 1.0` repeats 54 times, once per shape × kernel × stride × padding combination; every configuration yields min = max = 1.0)
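An all-ones input is exactly why every line reads `1.0 1.0`: when padded positions are excluded from the divisor, each window average is (number of ones in the window) / (number of ones in the window) = 1.0, regardless of kernel size, stride, or padding mode. A quick NumPy sketch of that invariant (zero padding simulated on the right, helper logic hypothetical):

```python
import numpy as np

# With an all-ones input and count_include_pad=False semantics, every
# window average is sum / valid_count == valid_count / valid_count == 1.0,
# so min == max == 1.0 for any kernel size, stride, or padding mode.
ones = np.ones(5)
for window in (1, 3, 5):
    padded = np.concatenate([ones, np.zeros(window - 1)])  # SAME-style pad
    for i in range(len(ones)):
        win = padded[i:i + window]
        valid = np.count_nonzero(win)  # the ones are the non-padded entries
        assert win.sum() / valid == 1.0
print("min == max == 1.0 for all windows")
```

So the all-ones check only verifies that padding is excluded from the divisor; a non-constant input (like the SAME-padding test added later in this thread) is needed to distinguish the two modes per position.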
This is my first time creating a PR, so please tell me if I missed anything. Thank you.
Just noticed there aren't any tests for pooling.py. I'll create an issue; we can add a test for this in the future.
I found the PoolTest class in ./tests/linen/test_linen.py:
class PoolTest(absltest.TestCase):

    def test_pool_custom_reduce(self):
        x = jnp.full((1, 3, 3, 1), 2.)
        mul_reduce = lambda x, y: x * y
        y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

    def test_avg_pool(self):
        x = jnp.full((1, 3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2))
        y = pool(x)
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    def test_avg_pool_no_batch(self):
        x = jnp.full((3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2))
        y = pool(x)
        np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    def test_max_pool(self):
        ...
When I ran this code with the edits from this PR, there were no problems.
(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN ] IdsTest.test_hashable
[ OK ] IdsTest.test_hashable
[ RUN ] NormalizationTest.test_batch_norm
I0915 01:20:42.378859 4314512704 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0915 01:20:42.379297 4314512704 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0915 01:20:42.379443 4314512704 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0915 01:20:42.379549 4314512704 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0915 01:20:42.380259 4314512704 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[ OK ] NormalizationTest.test_batch_norm
[ RUN ] NormalizationTest.test_batch_norm_complex
[ OK ] NormalizationTest.test_batch_norm_complex
[ RUN ] NormalizationTest.test_batch_norm_multi_init
[ OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN ] NormalizationTest.test_group_norm
[ OK ] NormalizationTest.test_group_norm
[ RUN ] NormalizationTest.test_group_norm_raises
[ OK ] NormalizationTest.test_group_norm_raises
[ RUN ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN ] PoolTest.test_avg_pool
[ OK ] PoolTest.test_avg_pool
[ RUN ] PoolTest.test_avg_pool_no_batch
[ OK ] PoolTest.test_avg_pool_no_batch
[ RUN ] PoolTest.test_max_pool
[ OK ] PoolTest.test_max_pool
[ RUN ] PoolTest.test_pool_custom_reduce
[ OK ] PoolTest.test_pool_custom_reduce
[ RUN ] RecurrentTest.test_complex_input_gru
[ OK ] RecurrentTest.test_complex_input_gru
[ RUN ] RecurrentTest.test_convlstm
[ OK ] RecurrentTest.test_convlstm
[ RUN ] RecurrentTest.test_gru
[ OK ] RecurrentTest.test_gru
[ RUN ] RecurrentTest.test_lstm
[ OK ] RecurrentTest.test_lstm
[ RUN ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[ OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN ] StochasticTest.test_dropout
[ OK ] StochasticTest.test_dropout
[ RUN ] StochasticTest.test_dropout_rate_limits
[ OK ] StochasticTest.test_dropout_rate_limits
[ RUN ] StochasticTest.test_dropout_rate_stats
[ OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 21 tests in 13.907s
OK
@dslisleedh can you create one or more tests under PoolTest using this new argument? You can copy an existing test, rename it, and change it to use the new count_include_pad argument. This would help immensely, as these tests will be run automatically before merging new code.
To @cgarciae
I added some tests to PoolTest; here's the code I tested.
from absl.testing import absltest, parameterized
from flax import ids
from flax import linen as nn
import jax
from jax import random
from jax import test_util as jtu
from jax.nn import initializers
import jax.numpy as jnp
import numpy as np

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()


class PoolTest(parameterized.TestCase):

    def test_pool_custom_reduce(self):
        x = jnp.full((1, 3, 3, 1), 2.)
        mul_reduce = lambda x, y: x * y
        y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool(self, count_include_pad):
        x = jnp.full((1, 3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
        y = pool(x)
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool_no_batch(self, count_include_pad):
        x = jnp.full((3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
        y = pool(x)
        np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    def test_max_pool(self):
        x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
        pool = lambda x: nn.max_pool(x, (2, 2))
        expected_y = jnp.array([
            [4., 5.],
            [7., 8.],
        ]).reshape((1, 2, 2, 1))
        y = pool(x)
        np.testing.assert_allclose(y, expected_y)
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0., 0., 0.],
            [0., 1., 1.],
            [0., 1., 1.],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)


if __name__ == '__main__':
    absltest.main()
When I tested with the previous code, an error occurred in the non-batch avg_pool case, so I corrected the PR. Thanks for pushing me to test it thoroughly.
Below is the test result with the modified code.
(flax_test) idongheon@idongheon-ui-MacBookAir flax % python pool_test.py
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN ] PoolTest.test_avg_pool0 (count_include_pad=True)
I0920 20:09:02.610634 4371807552 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0920 20:09:02.610964 4371807552 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0920 20:09:02.611104 4371807552 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0920 20:09:02.611212 4371807552 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0920 20:09:02.611675 4371807552 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[ OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN ] PoolTest.test_max_pool
[ OK ] PoolTest.test_max_pool
[ RUN ] PoolTest.test_pool_custom_reduce
[ OK ] PoolTest.test_pool_custom_reduce
----------------------------------------------------------------------
Ran 6 tests in 1.586s
OK
Thank you.
Awesome @dslisleedh! Can you commit changes to the tests?
To @cgarciae
Sure :)
@dslisleedh can you add this test? None of the other tests used padding="SAME", which is the more interesting case.
@parameterized.parameters(
    {'count_include_pad': True},
    {'count_include_pad': False})
def test_avg_pool_padding_same(self, count_include_pad):
    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
    pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
    y = pool(x)
    if count_include_pad:
        expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
    else:
        expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
    np.testing.assert_allclose(y, expected_y)
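The expected values in this test can be checked by hand with a small NumPy sketch that is independent of Flax. The 2x2 input [[1, 2], [3, 4]] is zero-padded on the right and bottom, and each 2x2 window is averaged two ways:

```python
import numpy as np

# Hand-check of the expected SAME-padding values: four 2x2 windows over
# the padded image, with sums 10, 6, 7, 4 and valid counts 4, 2, 2, 1.
x = np.array([[1.0, 2.0], [3.0, 4.0]])
padded = np.pad(x, ((0, 1), (0, 1)))  # zero-pad one row and one column

for include_pad in (True, False):
    out = np.empty((2, 2))
    for i in range(2):
        for j in range(2):
            win = padded[i:i + 2, j:j + 2]
            valid = (2 - i) * (2 - j)  # positions inside the original image
            out[i, j] = win.sum() / (4 if include_pad else valid)
    print(include_pad, out.ravel())
# include_pad=True:  10/4, 6/4, 7/4, 4/4  ->  2.5, 1.5, 1.75, 1.0
# include_pad=False: 10/4, 6/2, 7/2, 4/1  ->  2.5, 3.0, 3.5,  4.0
```

Only the top-left window is padding-free, which is why its value (10/4) is the same in both branches of the test.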
@cgarciae Oh, I forgot that. Thank you.
Here is the result of running your code.
(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN ] IdsTest.test_hashable
[ OK ] IdsTest.test_hashable
[ RUN ] NormalizationTest.test_batch_norm
I0921 22:23:30.726801 4309613888 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0921 22:23:30.727137 4309613888 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0921 22:23:30.727280 4309613888 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0921 22:23:30.727386 4309613888 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0921 22:23:30.727836 4309613888 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[ OK ] NormalizationTest.test_batch_norm
[ RUN ] NormalizationTest.test_batch_norm_complex
[ OK ] NormalizationTest.test_batch_norm_complex
[ RUN ] NormalizationTest.test_batch_norm_multi_init
[ OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN ] NormalizationTest.test_group_norm
[ OK ] NormalizationTest.test_group_norm
[ RUN ] NormalizationTest.test_group_norm_raises
[ OK ] NormalizationTest.test_group_norm_raises
[ RUN ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[ OK ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[ RUN ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[ OK ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[ RUN ] PoolTest.test_max_pool
[ OK ] PoolTest.test_max_pool
[ RUN ] PoolTest.test_pool_custom_reduce
[ OK ] PoolTest.test_pool_custom_reduce
[ RUN ] RecurrentTest.test_complex_input_gru
[ OK ] RecurrentTest.test_complex_input_gru
[ RUN ] RecurrentTest.test_convlstm
[ OK ] RecurrentTest.test_convlstm
[ RUN ] RecurrentTest.test_gru
[ OK ] RecurrentTest.test_gru
[ RUN ] RecurrentTest.test_lstm
[ OK ] RecurrentTest.test_lstm
[ RUN ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[ OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN ] StochasticTest.test_dropout
[ OK ] StochasticTest.test_dropout
[ RUN ] StochasticTest.test_dropout_rate_limits
[ OK ] StochasticTest.test_dropout_rate_limits
[ RUN ] StochasticTest.test_dropout_rate_stats
[ OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 25 tests in 14.023s
OK