
Test tolerances are too tight, resulting in 6 test failures

Open samuela opened this issue 3 years ago • 7 comments

I'm seeing the following failures:

=========================== short test summary info ============================
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv1DLSTM, dynamic_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv1DLSTM, static_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv2DLSTM, dynamic_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv2DLSTM, static_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv3DLSTM, dynamic_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv3DLSTM, static_unroll)
===== 6 failed, 2461 passed, 143 skipped, 57 warnings in 277.28s (0:04:37) =====

when running on an m6a.4xlarge EC2 instance (3rd generation AMD EPYC processors). For a full log see here.

It appears to me that the tests are working as intended, but the tolerances are too tight, e.g.:

>   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
E   AssertionError: 
E   Not equal to tolerance rtol=1e-07, atol=1e-05
E   
E   Mismatched elements: 91 / 108 (84.3%)
E   Max absolute difference: 0.001465
E   Max relative difference: 0.0834
E    x: array([[[ 0.1921  , -0.168   , -0.425   , -0.1724  ,  0.1691  ,
E            -0.11523 , -0.3555  ,  0.4094  ,  0.2556  ,  0.06256 ,
E             0.187   ,  0.4253  ],...
E    y: array([[[ 0.193   , -0.1675  , -0.424   , -0.1721  ,  0.1694  ,
E            -0.1148  , -0.3547  ,  0.4097  ,  0.2566  ,  0.0632  ,
E             0.1871  ,  0.426   ],...
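For reference, the failure mode above can be reproduced in miniature with plain NumPy (the values below are picked from the trace; this is an illustration, not Haiku's test code). A max absolute difference of ~1.5e-3 on values of magnitude ~0.2 blows well past a budget of `atol=1e-05` plus `rtol=1e-07`:

```python
import numpy as np

# Hypothetical pair of corresponding elements from the x/y arrays above.
x = np.float32(0.1921)
y = np.float32(0.1930)

try:
    # The tolerances used by the failing test: budget ~= 1e-5 + 1e-7 * |y|.
    np.testing.assert_allclose(x, y, rtol=1e-07, atol=1e-05)
    raised = False
except AssertionError:
    raised = True

# Tolerances more typical of float32 accumulation accept the difference.
np.testing.assert_allclose(x, y, rtol=1e-02, atol=1e-02)
print(raised)  # True: the tight tolerances reject a ~9e-4 difference
```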

To reproduce:

  • Check out https://github.com/NixOS/nixpkgs/commit/672f7edb4d7a14e6d2e8026f49c66be270818b0a on an m6a EC2 instance.
  • Run nix-build -A python3Packages.dm-haiku

This is with

  • Python 3.9.11
  • dm-haiku 0.0.6
  • jax 0.3.4
  • jaxlib 0.3.0
  • openblas 0.3.20
  • TF 2.8.0
  • Chex at https://github.com/deepmind/chex/commit/5adc10e0b4218f8ec775567fca38b68bbad42a3a.

samuela avatar Apr 11 '22 20:04 samuela

I'm also seeing two other failures when I skip the test_jit_Recurrent tests:

=========================== short test summary info ============================
FAILED haiku/_src/integration/numpy_inputs_test.py::NumpyInputsTest::test_numpy_and_jax_results_close_ConvND_np_inputs_not_np_params_close_over_params
FAILED haiku/_src/integration/numpy_inputs_test.py::NumpyInputsTest::test_numpy_and_jax_results_close_ConvND_np_inputs_np_params_close_over_params
===== 2 failed, 2449 passed, 143 skipped, 57 warnings in 290.09s (0:04:50) =====

It appears to be a dtype mismatch:

E     TypeError: lax.conv_general_dilated requires arguments to have the same dtypes, got float32, float16.
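As a sketch of what that error means (a simplified, hypothetical helper mimicking JAX's dtype check, not the real `lax.conv_general_dilated` implementation): the inputs stayed float32 while the generated parameters came out float16, and JAX refuses to mix them:

```python
import numpy as np

# Hypothetical stand-in for the dtype check inside lax.conv_general_dilated.
def check_conv_dtypes(lhs, rhs):
    if lhs.dtype != rhs.dtype:
        raise TypeError(
            "lax.conv_general_dilated requires arguments to have the same "
            f"dtypes, got {lhs.dtype}, {rhs.dtype}.")

x = np.ones((1, 8, 3), dtype=np.float32)   # inputs kept as float32
w = np.ones((3, 3, 4), dtype=np.float16)   # params generated as float16

try:
    check_conv_dtypes(x, w)
    mismatch = False
except TypeError:
    mismatch = True

# Casting one side to the other's dtype would avoid the error.
check_conv_dtypes(x, w.astype(x.dtype))
print(mismatch)  # True
```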

EDIT: these appear to be resolved by upgrading Chex to 0.1.2.

samuela avatar Apr 11 '22 20:04 samuela

Hi @samuela, our tests seem to pass fine on my workstation, GitHub Actions, and our internal CI, so I wonder if this is something specific to your setup (e.g. not using the latest versions of libraries like JAX, NumPy, etc., and thus getting different numerics).

Internally we run our tests on a variety of platforms (CPU, NVIDIA GPUs, TPUs) which all have fairly different numerics, so I am reasonably confident that our tolerances should be sufficient (although we have had to adjust them in the past as JAX/XLA changed things).
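To make the kind of adjustment involved concrete (a hypothetical helper with illustrative values, not Haiku's actual test settings), such suites often keep a per-dtype tolerance table so float32 comparisons can be loosened for platform jitter without touching float64 paths:

```python
import numpy as np

# Hypothetical per-dtype tolerance table; values are illustrative only.
TOLERANCES = {
    np.dtype(np.float64): dict(rtol=1e-07, atol=1e-09),
    np.dtype(np.float32): dict(rtol=1e-03, atol=1e-05),
    np.dtype(np.float16): dict(rtol=1e-02, atol=1e-03),
}

def assert_trees_close(a, b):
    tol = TOLERANCES[np.asarray(a).dtype]
    np.testing.assert_allclose(a, b, **tol)

# Ordinary cross-platform jitter passes under the float32 entry, while the
# max relative difference reported above (~8e-2) would still fail it.
assert_trees_close(np.float32(0.1921), np.float32(0.19211))
print("ok")
```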

Our CI testing is defined here and here are the versions of all the dependencies we used with Python 3.9: https://gist.github.com/tomhennigan/47d9de2356acca2069adc211ed2b2797.

tomhennigan avatar Apr 12 '22 07:04 tomhennigan

Looking through the test versions, it seems the suite is using jax v0.2.28 and jaxlib v0.1.76, which are about 2.5 months (and several minor versions) old. Current versions are v0.3.6 and v0.3.5 respectively. Perhaps that's the cause? Do you know if the test suite passes on the latest jax/jaxlib?

samuela avatar Apr 13 '22 02:04 samuela

Internally we're testing with JAX/XLA at HEAD, so I'm fairly confident the tests pass with the latest stable release too. I'll bump the versions we're using on GHA regardless in #370, since we should be running with something more recent (I'll stick with 0.3.5 since we have a corresponding jaxlib release).

tomhennigan avatar Apr 13 '22 19:04 tomhennigan

Huh interesting! Do you mind if I ask if you're running on Intel or AMD? Using MKL or OpenBLAS or something else?

I'm struggling to think of what else could affect this...

samuela avatar Apr 13 '22 20:04 samuela

Apologies for the delay. I believe the majority of our internal testing is done on Intel CPUs; on GitHub I'm not really sure (we use whatever machine the Actions runner lands on). I don't think XLA:CPU supports MKL, and I believe (although I'm not very familiar with the details) that it uses Eigen rather than a BLAS library, because XLA fuses multiple operations together.

I'm really not sure where the discrepancy is coming from...

tomhennigan avatar May 05 '22 09:05 tomhennigan

Interesting... I'm guessing it's an AMD vs. Intel discrepancy.

samuela avatar May 05 '22 19:05 samuela