
make test-convergence fails with "Number of mismatched elements"

Open Dexterai opened this issue 5 months ago • 7 comments

🐛 Describe the bug

I ran make test-convergence on a single A100 GPU and got failures like this:

FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] - AssertionError: Number of mismatched elements: 11
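Each failing test id encodes the parametrize arguments in the order declared by @pytest.mark.parametrize in the test file. That order is model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol. For values it cannot render, pytest falls back to the argument name plus the parameter-set index, which is why torch.float32 shows up as dtype3 or dtype13.

The sketch below decodes the two failing ids back into named tolerances. decode_test_id is a hypothetical helper and is not part of Liger-Kernel. It assumes that no model name ends in a digit followed by "e", because it only avoids splitting on the "-" inside exponents such as "1e-08".

```python
import re

# Parameter order copied from @pytest.mark.parametrize in
# test/convergence/fp32/test_mini_models.py (as shown in the log).
PARAM_NAMES = [
    "model_name", "num_steps", "lr", "dtype",
    "loss_atol", "loss_rtol", "logprobs_atol", "logprobs_rtol",
    "param_atol", "param_rtol",
]


def decode_test_id(node_id: str) -> dict:
    """Hypothetical helper: map a parametrized pytest node id to named params."""
    # Keep only the text between the square brackets.
    params = node_id[node_id.index("[") + 1 : node_id.rindex("]")]
    # Split on '-' except when it follows "<digit>e", i.e. inside "1e-08".
    fields = re.split(r"(?<![0-9]e)-", params)
    if len(fields) != len(PARAM_NAMES):
        raise ValueError(f"expected {len(PARAM_NAMES)} fields, got {len(fields)}")
    return dict(zip(PARAM_NAMES, fields))


failing = [
    "test/convergence/fp32/test_mini_models.py::test_mini_model"
    "[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05]",
    "test/convergence/fp32/test_mini_models.py::test_mini_model"
    "[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05]",
]

for node_id in failing:
    p = decode_test_id(node_id)
    print(
        p["model_name"],
        "loss atol/rtol:", p["loss_atol"], p["loss_rtol"],
        "logprobs atol/rtol:", p["logprobs_atol"], p["logprobs_rtol"],
    )
```

You can pass any of these node ids to python -m pytest to rerun only that case. Quote the id, because the square brackets are otherwise interpreted by the shell.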

Reproduce

make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-8.4.1, pluggy-1.6.0
rootdir: /root/lanyun-fs/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.1, xdist-3.7.0
collecting ... 
------------------------------------------------------------- live log collection -------------------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.7.1 available.
collected 17 items                                                                                                                            

test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] PASSED    [  5%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED      [ 11%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED     [ 17%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] FAILED [ 23%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype4-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED      [ 29%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen3-32-0.0001-dtype5-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED      [ 35%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen3_moe-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED  [ 41%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype7-1e-05-0.1-1-0.1-0.005-1e-05] PASSED           [ 47%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype8-1e-05-0.1-3-0.1-0.005-1e-05] PASSED         [ 52%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_olmo2-32-0.0001-dtype9-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED      [ 58%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_glm4-32-0.0001-dtype10-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED      [ 64%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype11-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED      [ 70%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype12-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED   [ 76%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05] FAILED    [ 82%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma2-32-0.0001-dtype15-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED   [ 94%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_granite3-32-0.0001-dtype16-1e-08-0.0001-0.05-0.0001-0.005-1e-05] PASSED [100%]

================================================================== FAILURES ===================================================================
___________________________ test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] ___________________________

model_name = 'mini_gemma3_text', num_steps = 32, lr = 0.0001, dtype = torch.float32, loss_atol = 1e-08, loss_rtol = 0.0001
logprobs_atol = 0.005, logprobs_rtol = 1e-05, param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_gemma3_text",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GEMMA3_AVAILABLE,
                    reason="Gemma3 not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_qwen3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN3_AVAILABLE,
                    reason="Qwen3 not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_qwen3_moe",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN3_AVAILABLE,
                    reason="Qwen3 not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                1,  # 5e-3,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            # TODO: logits tolerances are significantly larger than the other tests, need to investigate
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                3,  # 5e-3,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_glm4",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GLM4_AVAILABLE,
                    reason="Glm4 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-2, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-2,  # 5e-3
                1e-4,  # 1e-5
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logprobs_atol,
        logprobs_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
        assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )
    
        # Compare the topk logprobs from evaluation step
        if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None:
>           assert_verbose_allclose(
                expected_output["topk_logprobs"],
                actual_output["topk_logprobs"],
                atol=logprobs_atol,
                rtol=logprobs_rtol,
            )

test/convergence/fp32/test_mini_models.py:1067: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensor1 = tensor([[[-9.1038e-03, -4.7121e+00, -1.5932e+01,  ..., -1.6516e+01,
          -1.6565e+01, -1.6593e+01],
         [-9....     [-4.9854e-01, -4.1698e+00, -4.3570e+00,  ..., -7.2072e+00,
          -7.3368e+00, -7.3456e+00]]], device='cuda:0')
tensor2 = tensor([[[-9.1016e-03, -4.7124e+00, -1.5933e+01,  ..., -1.6516e+01,
          -1.6566e+01, -1.6593e+01],
         [-9....     [-4.9861e-01, -4.1701e+00, -4.3570e+00,  ..., -7.2076e+00,
          -7.3366e+00, -7.3458e+00]]], device='cuda:0')
rtol = 1e-05, atol = 0.005, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 11
E           Mismatch at index (1, 126, 17): tensor1[(1, 126, 17)] = -6.848010063171387, tensor2[(1, 126, 17)] = -6.8533782958984375
E           Mismatch at index (13, 118, 3): tensor1[(13, 118, 3)] = -3.215667963027954, tensor2[(13, 118, 3)] = -3.221395492553711
E           Mismatch at index (13, 118, 10): tensor1[(13, 118, 10)] = -5.487461090087891, tensor2[(13, 118, 10)] = -5.481538772583008
E           Mismatch at index (13, 118, 13): tensor1[(13, 118, 13)] = -5.899518013000488, tensor2[(13, 118, 13)] = -5.906171798706055
E           Mismatch at index (14, 125, 1): tensor1[(14, 125, 1)] = -3.3704092502593994, tensor2[(14, 125, 1)] = -3.361818313598633
E           ... and 6 more mismatched elements.

test/utils.py:130: AssertionError
------------------------------------------------------------ Captured stdout call -------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 5.5114898681640625
Step 1, Loss: 0.7358521819114685
Step 2, Loss: 0.897828221321106
Step 3, Loss: 0.7267115712165833
Step 4, Loss: 0.6775891184806824
Step 5, Loss: 0.7813704609870911
Step 6, Loss: 0.6756245493888855
Step 7, Loss: 0.9666643738746643
Step 8, Loss: 0.6444056034088135
Step 9, Loss: 0.7450430393218994
Step 10, Loss: 0.7641883492469788
Step 11, Loss: 0.40847837924957275
Step 12, Loss: 0.7006322145462036
Step 13, Loss: 0.6850311756134033
Step 14, Loss: 0.6146025061607361
Step 15, Loss: 0.7307778000831604
Step 16, Loss: 0.7162883281707764
Step 17, Loss: 0.8030434250831604
Step 18, Loss: 0.7284414172172546
Step 19, Loss: 0.7638292908668518
Step 20, Loss: 0.6066721081733704
Step 21, Loss: 0.36790019273757935
Step 22, Loss: 0.48964276909828186
Step 23, Loss: 0.4520471692085266
Step 24, Loss: 0.45229068398475647
Step 25, Loss: 0.4304249584674835
Step 26, Loss: 0.3818521797657013
Step 27, Loss: 0.5118127465248108
Step 28, Loss: 0.6146227121353149
Step 29, Loss: 0.38547804951667786
Step 30, Loss: 0.7631076574325562
Step 31, Loss: 0.4806024432182312
Eval Loss: 0.618675172328949
Liger kernel patches have been reverted.
Step 0, Loss: 5.511490821838379
Step 1, Loss: 0.7358517646789551
Step 2, Loss: 0.897828221321106
Step 3, Loss: 0.7267120480537415
Step 4, Loss: 0.6775884628295898
Step 5, Loss: 0.7813707590103149
Step 6, Loss: 0.6756248474121094
Step 7, Loss: 0.9666653871536255
Step 8, Loss: 0.644402801990509
Step 9, Loss: 0.7450424432754517
Step 10, Loss: 0.7641878128051758
Step 11, Loss: 0.40847688913345337
Step 12, Loss: 0.7006286382675171
Step 13, Loss: 0.6850258111953735
Step 14, Loss: 0.6146000027656555
Step 15, Loss: 0.7307751178741455
Step 16, Loss: 0.7162874341011047
Step 17, Loss: 0.8030444979667664
Step 18, Loss: 0.7284409403800964
Step 19, Loss: 0.7638280391693115
Step 20, Loss: 0.6066715121269226
Step 21, Loss: 0.367902547121048
Step 22, Loss: 0.48964276909828186
Step 23, Loss: 0.45204880833625793
Step 24, Loss: 0.452290415763855
Step 25, Loss: 0.43042486906051636
Step 26, Loss: 0.3818501830101013
Step 27, Loss: 0.5118110179901123
Step 28, Loss: 0.6146328449249268
Step 29, Loss: 0.3854790925979614
Step 30, Loss: 0.7631104588508606
Step 31, Loss: 0.48059630393981934
Eval Loss: 0.6186762452125549
Liger kernel patches have been reverted.
_____________________________ test_mini_model[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05] ______________________________

model_name = 'mini_gemma1', num_steps = 32, lr = 0.0001, dtype = torch.float32, loss_atol = 1e-08, loss_rtol = 0.0001, logprobs_atol = 0.005
logprobs_rtol = 0.01, param_atol = 0.005, param_rtol = 1e-05

        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:1058: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensor1 = tensor([[1.0585, 0.7503, 0.9095, 0.7038, 0.6323, 0.7315, 0.7252, 0.8883, 0.6050,
         0.7210, 0.7658, 0.4156, 0.67...336, 0.5679, 0.3464, 0.4418, 0.3993, 0.4071, 0.3988, 0.3726,
         0.4775, 0.5491, 0.3414, 0.6705, 0.4464, 0.5609]])
tensor2 = tensor([[1.0586, 0.7503, 0.9095, 0.7038, 0.6323, 0.7316, 0.7248, 0.8882, 0.6049,
         0.7208, 0.7655, 0.4154, 0.67...332, 0.5675, 0.3460, 0.4413, 0.3989, 0.4071, 0.3997, 0.3721,
         0.4769, 0.5490, 0.3415, 0.6706, 0.4464, 0.5609]])
rtol = 0.0001, atol = 1e-08, max_print = 5

>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 23
E           Mismatch at index (0, 6): tensor1[(0, 6)] = 0.725193202495575, tensor2[(0, 6)] = 0.7248455286026001
E           Mismatch at index (0, 8): tensor1[(0, 8)] = 0.6050088405609131, tensor2[(0, 8)] = 0.6048830151557922
E           Mismatch at index (0, 9): tensor1[(0, 9)] = 0.7209844589233398, tensor2[(0, 9)] = 0.7208169102668762
E           Mismatch at index (0, 10): tensor1[(0, 10)] = 0.7657538652420044, tensor2[(0, 10)] = 0.7655386328697205
E           Mismatch at index (0, 11): tensor1[(0, 11)] = 0.4156359136104584, tensor2[(0, 11)] = 0.4153878092765808
E           ... and 18 more mismatched elements.

test/utils.py:130: AssertionError
------------------------------------------------------------ Captured stdout call -------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 1.0585311651229858
Step 1, Loss: 0.7502652406692505
Step 2, Loss: 0.9094656109809875
Step 3, Loss: 0.7037572860717773
Step 4, Loss: 0.6322704553604126
Step 5, Loss: 0.7315280437469482
Step 6, Loss: 0.725193202495575
Step 7, Loss: 0.8882756233215332
Step 8, Loss: 0.6050088405609131
Step 9, Loss: 0.7209844589233398
Step 10, Loss: 0.7657538652420044
Step 11, Loss: 0.4156359136104584
Step 12, Loss: 0.6751559376716614
Step 13, Loss: 0.6617658734321594
Step 14, Loss: 0.5803778767585754
Step 15, Loss: 0.7222874760627747
Step 16, Loss: 0.708590030670166
Step 17, Loss: 0.7739989757537842
Step 18, Loss: 0.6725180149078369
Step 19, Loss: 0.7336098551750183
Step 20, Loss: 0.5679132342338562
Step 21, Loss: 0.3464488685131073
Step 22, Loss: 0.4417855739593506
Step 23, Loss: 0.3993169963359833
Step 24, Loss: 0.4070657789707184
Step 25, Loss: 0.39878153800964355
Step 26, Loss: 0.37256762385368347
Step 27, Loss: 0.47752323746681213
Step 28, Loss: 0.5490694046020508
Step 29, Loss: 0.341403603553772
Step 30, Loss: 0.670497477054596
Step 31, Loss: 0.4463615119457245
Eval Loss: 0.5608715415000916
Liger kernel patches have been reverted.
Step 0, Loss: 1.058566927909851
Step 1, Loss: 0.7502727508544922
Step 2, Loss: 0.9094695448875427
Step 3, Loss: 0.7037727236747742
Step 4, Loss: 0.6322908997535706
Step 5, Loss: 0.7315629124641418
Step 6, Loss: 0.7248455286026001
Step 7, Loss: 0.8882305026054382
Step 8, Loss: 0.6048830151557922
Step 9, Loss: 0.7208169102668762
Step 10, Loss: 0.7655386328697205
Step 11, Loss: 0.4153878092765808
Step 12, Loss: 0.6748801469802856
Step 13, Loss: 0.661463737487793
Step 14, Loss: 0.5800580382347107
Step 15, Loss: 0.721951961517334
Step 16, Loss: 0.7082199454307556
Step 17, Loss: 0.773615300655365
Step 18, Loss: 0.6721071600914001
Step 19, Loss: 0.733183741569519
Step 20, Loss: 0.567453145980835
Step 21, Loss: 0.34597691893577576
Step 22, Loss: 0.44131916761398315
Step 23, Loss: 0.39894071221351624
Step 24, Loss: 0.40705054998397827
Step 25, Loss: 0.39973393082618713
Step 26, Loss: 0.3721126317977905
Step 27, Loss: 0.4769252836704254
Step 28, Loss: 0.5490176677703857
Step 29, Loss: 0.34152016043663025
Step 30, Loss: 0.6706250309944153
Step 31, Loss: 0.44644686579704285
Eval Loss: 0.5609042644500732
Liger kernel patches have been reverted.
=========================================================== short test summary info ===========================================================
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] - AssertionError: Number of mismatched elements: 11
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05] - AssertionError: Number of mismatched elements: 23
============================================= 2 failed, 15 passed, 1 warning in 156.15s (0:02:36) =============================================
make: *** [Makefile:23: test-convergence] Error 1
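assert_verbose_allclose in test/utils.py flags an element when |tensor1 - tensor2| > atol + rtol * |tensor2|, and it also flags NaN/inf mismatches. The pure-Python sketch below re-applies only that tolerance formula to the five mismatches printed for each failure. It uses the values and tolerances copied from the log and shows how far outside the bound each element lands. This is an illustration, not the project's test code, and excess_ratio is a hypothetical name.

```python
# Re-apply the tolerance rule used by assert_verbose_allclose (test/utils.py)
# to the mismatched elements printed in the log above.
# excess_ratio is a hypothetical helper written for this illustration.


def excess_ratio(a: float, b: float, atol: float, rtol: float) -> float:
    """Return |a - b| divided by the allowed tolerance atol + rtol * |b|.

    A result above 1.0 means the element would be flagged as a mismatch.
    """
    return abs(a - b) / (atol + rtol * abs(b))


# mini_gemma3_text: top-k logprobs, logprobs_atol=5e-3, logprobs_rtol=1e-5
gemma3_logprobs = [
    (-6.848010063171387, -6.8533782958984375),
    (-3.215667963027954, -3.221395492553711),
    (-5.487461090087891, -5.481538772583008),
    (-5.899518013000488, -5.906171798706055),
    (-3.3704092502593994, -3.361818313598633),
]

# mini_gemma1: per-step training loss, loss_atol=1e-8, loss_rtol=1e-4
gemma1_losses = [
    (0.725193202495575, 0.7248455286026001),
    (0.6050088405609131, 0.6048830151557922),
    (0.7209844589233398, 0.7208169102668762),
    (0.7657538652420044, 0.7655386328697205),
    (0.4156359136104584, 0.4153878092765808),
]

gemma3_ratios = [excess_ratio(a, b, atol=5e-3, rtol=1e-5) for a, b in gemma3_logprobs]
gemma1_ratios = [excess_ratio(a, b, atol=1e-8, rtol=1e-4) for a, b in gemma1_losses]

for name, ratios in (("mini_gemma3_text", gemma3_ratios), ("mini_gemma1", gemma1_ratios)):
    print(name, ", ".join(f"{r:.2f}x" for r in ratios))
```

For the printed elements, the Gemma 3 logprob mismatches land only about 1.1x to 1.7x outside the bound. The Gemma 1 loss mismatches land about 2x to 6x outside it.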

Versions

System Information

  • OS: Linux 5.15.0-112-generic (#122-Ubuntu SMP Thu May 23 07:48:21 UTC 2024)
  • Architecture: x86_64

Python Environment

  • Python Version: 3.12.3 | packaged by Anaconda, Inc. | (main, May 6 2024, 19:46:43) [GCC 11.2.0]
  • Python Implementation: CPython

Package Versions

  • torch: 2.7.1
  • liger-kernel: 0.5.10
  • transformers: 4.53.0
  • datasets: 3.6.0
  • pytest: 8.4.1
  • packaging: 25.0
  • accelerate: Not installed
  • numpy: 2.3.1
  • safetensors: 0.5.3
  • tokenizers: 0.21.2

CUDA Version

  • Build cuda_12.1.r12.1/compiler.32688072_0

Dexterai, Jun 27 '25 06:06