dm-haiku
Test tolerances are too tight, resulting in 6 test failures
I'm seeing the following failures:
=========================== short test summary info ============================
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv1DLSTM, dynamic_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv1DLSTM, static_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv2DLSTM, dynamic_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv2DLSTM, static_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv3DLSTM, dynamic_unroll)
FAILED haiku/_src/integration/jax_transforms_test.py::JaxTransformsTest::test_jit_Recurrent(Conv3DLSTM, static_unroll)
===== 6 failed, 2461 passed, 143 skipped, 57 warnings in 277.28s (0:04:37) =====
These failures occur when running on an m6a.4xlarge EC2 instance (3rd-generation AMD EPYC processors). For a full log, see here.
It appears to me that the tests are working as intended, but the tolerances are too tight, e.g.:
> return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
E AssertionError:
E Not equal to tolerance rtol=1e-07, atol=1e-05
E
E Mismatched elements: 91 / 108 (84.3%)
E Max absolute difference: 0.001465
E Max relative difference: 0.0834
E x: array([[[ 0.1921 , -0.168 , -0.425 , -0.1724 , 0.1691 ,
E -0.11523 , -0.3555 , 0.4094 , 0.2556 , 0.06256 ,
E 0.187 , 0.4253 ],...
E y: array([[[ 0.193 , -0.1675 , -0.424 , -0.1721 , 0.1694 ,
E -0.1148 , -0.3547 , 0.4097 , 0.2566 , 0.0632 ,
E 0.1871 , 0.426 ],...
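For reference, np.testing.assert_allclose(actual, desired, rtol, atol) passes only where |actual - desired| <= atol + rtol * |desired|. The sketch below plugs in the first five printed elements of x and y from the log above. These are rounded printouts, so the differences are approximate. It also makes an assumption the log does not confirm: that the arrays are float16, which the four-significant-digit printing hints at. If so, atol=1e-05 is smaller than the gap between adjacent float16 values near 0.4, so even a one-ulp difference would fail the check.

```python
import numpy as np

# First five elements of x (actual) and y (desired), copied from the log.
# These are rounded printouts, so the differences are approximate.
x = np.array([0.1921, -0.168, -0.425, -0.1724, 0.1691])
y = np.array([0.193, -0.1675, -0.424, -0.1721, 0.1694])

# Tolerances reported by the failing assertion.
rtol, atol = 1e-07, 1e-05

# Elementwise criterion used by np.testing.assert_allclose(actual, desired).
diff = np.abs(x - y)
allowed = atol + rtol * np.abs(y)
within = diff <= allowed
print("abs diff:", diff)
print("within tolerance:", within)  # all False for these elements

# Assumption: the outputs are float16. Gap between adjacent float16 values:
eps16 = np.finfo(np.float16).eps            # spacing just above 1.0
ulp_near_0_4 = np.spacing(np.float16(0.4))  # spacing just above 0.4
print("float16 eps:", eps16, "ulp near 0.4:", ulp_near_0_4)
print("atol smaller than one float16 ulp near 0.4:", atol < ulp_near_0_4)
```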
To reproduce:
- Check out https://github.com/NixOS/nixpkgs/commit/672f7edb4d7a14e6d2e8026f49c66be270818b0a on an m6a EC2 instance.
- Run:
nix-build -A python3Packages.dm-haiku
This is with:
- Python 3.9.11
- dm-haiku 0.0.6
- jax 0.3.4
- jaxlib 0.3.0
- openblas 0.3.20
- TF 2.8.0
- Chex at https://github.com/deepmind/chex/commit/5adc10e0b4218f8ec775567fca38b68bbad42a3a.
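When comparing environments, the Python package versions in a list like the one above can be collected with the standard library. The following is a minimal sketch. It assumes the packages are installed with distribution metadata (importlib.metadata, Python 3.8+) under the PyPI names listed in PACKAGES. OpenBLAS is a native library rather than a Python distribution, so it is not covered.

```python
import platform
from importlib import metadata

# PyPI distribution names for the packages listed above (assumed names;
# adjust if your environment packages them differently).
PACKAGES = ["dm-haiku", "jax", "jaxlib", "numpy", "chex", "tensorflow"]


def package_versions(names):
    """Map each distribution name to its installed version, or None."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    print("Python", platform.python_version())
    for name, version in package_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```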
I'm also seeing two other failures when I skip the test_jit_Recurrent tests:
=========================== short test summary info ============================
FAILED haiku/_src/integration/numpy_inputs_test.py::NumpyInputsTest::test_numpy_and_jax_results_close_ConvND_np_inputs_not_np_params_close_over_params
FAILED haiku/_src/integration/numpy_inputs_test.py::NumpyInputsTest::test_numpy_and_jax_results_close_ConvND_np_inputs_np_params_close_over_params
===== 2 failed, 2449 passed, 143 skipped, 57 warnings in 290.09s (0:04:50) =====
It appears to be a dtype mismatch:
E TypeError: lax.conv_general_dilated requires arguments to have the same dtypes, got float32, float16.
EDIT: these appear to be resolved by upgrading Chex to 0.1.2.
Hi @samuela, our tests seem to pass fine on my workstation, GitHub Actions and our internal CI, so I wonder if this is something specific to your setup (e.g. not using the latest versions of libraries like JAX, NumPy, etc., and thus getting different numerics).
Internally we run our tests on a variety of platforms (CPU, NVIDIA GPUs, TPUs) which all have fairly different numerics, so I am reasonably confident that our tolerances should be sufficient (although we have had to adjust them in the past as JAX/XLA changed things).
Our CI testing is defined here, and here are the versions of all the dependencies we used with Python 3.9: https://gist.github.com/tomhennigan/47d9de2356acca2069adc211ed2b2797.
Looking through those versions, it seems the tests use jax v0.2.28 and jaxlib v0.1.76, which are about 2.5 months and a few minor versions old. The current versions are v0.3.6 and v0.3.5, respectively. Perhaps that's the cause? Do you know whether the test suite passes on the latest jax/jaxlib?
Internally we're testing with JAX/XLA at HEAD, so I'm fairly confident they pass with the latest stable release too. I'll bump the versions we're using on GitHub Actions regardless in #370, since we should be running with something more recent (I'll stick with 0.3.5 since there is a corresponding jaxlib release).
Huh, interesting! Do you mind if I ask whether you're running on Intel or AMD? Using MKL, OpenBLAS or something else?
I'm struggling to think of what else could affect this...
Apologies for the delay. I believe the majority of our internal testing is done on Intel CPUs, but on GitHub I'm not really sure (we use whatever the Actions runner lands on). I don't think XLA:CPU supports MKL, and I believe (although I am not very familiar with it) that it uses Eigen rather than a BLAS library, because XLA fuses multiple operations together.
I'm really not sure where the discrepancy is coming from.
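The point about fusion matters because floating-point addition is not associative. A compiler that fuses or reorders operations, or that generates different vector code for different CPU instruction sets, can change results in the last bits. The thread does not establish whether this explains the discrepancy here. The NumPy sketch below only illustrates the general effect in float16, using two values chosen for the purpose.

```python
import numpy as np

# Two float16 values: 1.0 and half of one float16 ulp at 1.0 (2**-11).
a = np.float16(1.0)
b = np.float16(2.0 ** -11)

# Same three terms, two different association orders.
left = (a + b) + b   # each step is a tie that rounds (to even) back to 1.0
right = a + (b + b)  # b + b = 2**-10 is exact, so one full ulp is added

print(left, right, left == right)  # the two orders disagree
```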
Interesting... I'm guessing it's an AMD vs. Intel discrepancy.
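To test the AMD vs. Intel hypothesis, each test run could record which CPU vendor it landed on, including GitHub Actions runners. The following is a minimal sketch. It assumes a Linux x86 host, where /proc/cpuinfo exposes a vendor_id field such as "GenuineIntel" or "AuthenticAMD". The helper names are hypothetical, and the sketch returns None on other platforms.

```python
from pathlib import Path


def cpu_vendor(cpuinfo_text):
    """Return the first vendor_id in /proc/cpuinfo-style text, or None."""
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "vendor_id":
            return value.strip()  # e.g. "GenuineIntel" or "AuthenticAMD"
    return None


def host_cpu_vendor(path="/proc/cpuinfo"):
    """Read the vendor of the current machine (Linux only); None elsewhere."""
    try:
        return cpu_vendor(Path(path).read_text())
    except OSError:
        return None


if __name__ == "__main__":
    print("CPU vendor:", host_cpu_vendor() or "unknown")
```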