Issue with unit tests on NVIDIA V100 (GPU)

Open DwarKapex opened this issue 9 months ago • 1 comments

Hi everyone.

I see failures when running the unit tests on an NVIDIA V100 (GPU). Here is the link for more details.

Briefly:

```
=========================== short test summary info ============================
FAILED opt/gemma/gemma/layers_test.py::EinsumTest::test_rmsnorm0 - AssertionE...
FAILED opt/gemma/gemma/modules_test.py::FeedForwardTest::test_ffw0 - Assertio...
FAILED opt/gemma/gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0
================== 3 failed, 13 passed, 2 warnings in 35.61s ===================
```

Some details:
1. test_rmsnorm0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:348)). This looks like a floating-point epsilon error. I don't think it's a good idea to compare the expected array of floats directly with the resulting one. Is it possible to allow some tolerance between the expected and calculated arrays, like `rtol=1e-5, atol=1e-5`?
2. test_ffw0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:415)) is similar to the previous one.
3. test_adds_positional_embeddings0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:486)). IMHO, JAX cannot digest it correctly on GPUs.
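
To illustrate the tolerance suggested in item 1, here is a minimal sketch using NumPy's `assert_allclose`. The arrays are made-up values, not the ones from the failing tests, and I have not checked how the gemma tests currently do their comparison, so treat this as an assumption about what the fix could look like rather than a patch.

```python
import numpy as np

# Hypothetical values: `expected` stands in for a hard-coded reference array,
# `actual` for a result that differs by float32 rounding noise on another backend.
expected = np.array([0.38430923, 0.76861846, 1.1529277], dtype=np.float32)
actual = expected + np.float32(2e-7)

# An exact comparison is brittle: any rounding difference makes it fail.
exact_match = np.array_equal(actual, expected)

# A tolerance-based comparison passes when
# |actual - expected| <= atol + rtol * |expected| holds element-wise.
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
tolerant_match = np.allclose(actual, expected, rtol=1e-5, atol=1e-5)
```

Whether `1e-5` is the right tolerance for float32 results on a V100 is a judgment call; it may need to be looser for ops with many accumulated multiply-adds.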

Thank you for your help! Hope it's fixable! =)

DwarKapex avatar May 15 '24 21:05 DwarKapex