Liger-Kernel
make test-convergence fails with "Number of mismatched elements"
🐛 Describe the bug
I tried running make test-convergence on a single A100 GPU and got failures like this:
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] - AssertionError: Number of mismatched elements: 11
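For faster iteration, the failing cases can be re-run on their own with pytest's -k keyword filter (a sketch; the substring expression against the parametrized test IDs is my assumption about the most convenient way to slice them — "gemma1-32" is used rather than "gemma1" so that mini_gemma1.1 is not also selected):

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings -v -k "gemma3_text or gemma1-32" test/convergence/fp32/test_mini_models.py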
Reproduce
make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-8.4.1, pluggy-1.6.0
rootdir: /root/lanyun-fs/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.1, xdist-3.7.0
collecting ...
------------------------------------------------------------- live log collection -------------------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.1 available.
collected 17 items
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] PASSED [ 5%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 11%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 17%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] FAILED [ 23%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype4-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 29%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen3-32-0.0001-dtype5-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 35%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen3_moe-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 41%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype7-1e-05-0.1-1-0.1-0.005-1e-05] PASSED [ 47%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype8-1e-05-0.1-3-0.1-0.005-1e-05] PASSED [ 52%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_olmo2-32-0.0001-dtype9-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 58%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_glm4-32-0.0001-dtype10-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 64%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype11-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 70%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype12-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 76%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05] FAILED [ 82%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma2-32-0.0001-dtype15-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 94%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_granite3-32-0.0001-dtype16-1e-08-0.0001-0.05-0.0001-0.005-1e-05] PASSED [100%]
================================================================== FAILURES ===================================================================
___________________________ test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] ___________________________
model_name = 'mini_gemma3_text', num_steps = 32, lr = 0.0001, dtype = torch.float32, loss_atol = 1e-08, loss_rtol = 0.0001
logprobs_atol = 0.005, logprobs_rtol = 1e-05, param_atol = 0.005, param_rtol = 1e-05
@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_llava",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not LLAVA_AVAILABLE,
reason="LLaVa not available in this version of transformers",
),
),
pytest.param(
"mini_mllama",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not MLLAMA_AVAILABLE,
reason="Mllama not available in this version of transformers",
),
),
pytest.param(
"mini_gemma3_text",
32,
1e-4,
torch.float32,
1e-8,
1e-4,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not GEMMA3_AVAILABLE,
reason="Gemma3 not available in this version of transformers",
),
),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_qwen3",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN3_AVAILABLE,
reason="Qwen3 not available in this version of transformers",
),
),
pytest.param(
"mini_qwen3_moe",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN3_AVAILABLE,
reason="Qwen3 not available in this version of transformers",
),
),
pytest.param( # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
"mini_qwen2_vl",
32,
1e-4,
torch.float32,
1e-5, # 1e-8,
1e-1, # 1e-5,
1, # 5e-3,
1e-1, # 1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
),
# TODO: logits tolerances are significantly larger than the other tests, need to investigate
pytest.param( # qwen2_5_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
"mini_qwen2_5_vl",
32,
1e-4,
torch.float32,
1e-5, # 1e-8,
1e-1, # 1e-5,
3, # 5e-3,
1e-1, # 1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_5_VL_AVAILABLE,
reason="Qwen2.5-VL not available in this version of transformers",
),
),
pytest.param(
"mini_olmo2",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not OLMO2_AVAILABLE,
reason="OLMO2 not available in this version of transformers",
),
),
pytest.param(
"mini_glm4",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not GLM4_AVAILABLE,
reason="Glm4 not available in this version of transformers",
),
),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
# TODO: mixtral is flaky so disable the test for now
# ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-2, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_granite3",
32,
1e-4,
torch.float32,
1e-8,
1e-4,
5e-2, # 5e-3
1e-4, # 1e-5
5e-3,
1e-5,
marks=pytest.mark.skipif(
not GRANITE_AVAILABLE,
reason="Granite not available in this version of transformers",
),
),
],
)
def test_mini_model(
model_name,
num_steps,
lr,
dtype,
loss_atol,
loss_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
# Non-liger models should be initialized and tested first to avoid the module being overridden
expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
# Compare every step of the loss
assert_verbose_allclose(
torch.tensor([expected_output["loss"]]),
torch.tensor([actual_output["loss"]]),
atol=loss_atol,
rtol=loss_rtol,
)
# Compare the topk logprobs from evaluation step
if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None:
> assert_verbose_allclose(
expected_output["topk_logprobs"],
actual_output["topk_logprobs"],
atol=logprobs_atol,
rtol=logprobs_rtol,
)
test/convergence/fp32/test_mini_models.py:1067:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tensor1 = tensor([[[-9.1038e-03, -4.7121e+00, -1.5932e+01, ..., -1.6516e+01,
-1.6565e+01, -1.6593e+01],
[-9.... [-4.9854e-01, -4.1698e+00, -4.3570e+00, ..., -7.2072e+00,
-7.3368e+00, -7.3456e+00]]], device='cuda:0')
tensor2 = tensor([[[-9.1016e-03, -4.7124e+00, -1.5933e+01, ..., -1.6516e+01,
-1.6566e+01, -1.6593e+01],
[-9.... [-4.9861e-01, -4.1701e+00, -4.3570e+00, ..., -7.2076e+00,
-7.3366e+00, -7.3458e+00]]], device='cuda:0')
rtol = 1e-05, atol = 0.005, max_print = 5
def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
"""
Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
Parameters:
tensor1 (torch.Tensor): First tensor to compare.
tensor2 (torch.Tensor): Second tensor to compare.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
max_print (int): Maximum number of mismatched elements to print.
Raises:
AssertionError: If the tensors are not all close within the given tolerance.
"""
# Check if the shapes of the tensors match
if tensor1.shape != tensor2.shape:
raise AssertionError("Input tensors must have the same shape.")
# Calculate the difference between the tensors
diff = torch.abs(tensor1 - tensor2)
# Determine the tolerance
tolerance = atol + rtol * torch.abs(tensor2)
# Find tolerance mismatched elements
tol_mismatched = diff > tolerance
# Find nan mismatched elements
nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
# Find +inf mismatched elements
posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
# Find -inf mismatched elements
neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
# Find all mismatched elements
mismatched = torch.logical_or(
torch.logical_or(tol_mismatched, nan_mismatched),
torch.logical_or(posinf_mismatched, neginf_mismatched),
)
mismatched_indices = torch.nonzero(mismatched)
# Count the number of mismatched elements
num_mismatched = mismatched.sum().item()
# Check if all elements are close
all_close = num_mismatched == 0
# Raise AssertionError with detailed information if there are mismatches
if not all_close and num_mismatched >= 1:
mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
print_count = min(max_print, num_mismatched)
for index in mismatched_indices[:print_count]:
i = tuple(index.tolist())
mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
if num_mismatched > max_print:
mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
> raise AssertionError("\n".join(mismatch_details))
E AssertionError: Number of mismatched elements: 11
E Mismatch at index (1, 126, 17): tensor1[(1, 126, 17)] = -6.848010063171387, tensor2[(1, 126, 17)] = -6.8533782958984375
E Mismatch at index (13, 118, 3): tensor1[(13, 118, 3)] = -3.215667963027954, tensor2[(13, 118, 3)] = -3.221395492553711
E Mismatch at index (13, 118, 10): tensor1[(13, 118, 10)] = -5.487461090087891, tensor2[(13, 118, 10)] = -5.481538772583008
E Mismatch at index (13, 118, 13): tensor1[(13, 118, 13)] = -5.899518013000488, tensor2[(13, 118, 13)] = -5.906171798706055
E Mismatch at index (14, 125, 1): tensor1[(14, 125, 1)] = -3.3704092502593994, tensor2[(14, 125, 1)] = -3.361818313598633
E ... and 6 more mismatched elements.
test/utils.py:130: AssertionError
------------------------------------------------------------ Captured stdout call -------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 5.5114898681640625
Step 1, Loss: 0.7358521819114685
Step 2, Loss: 0.897828221321106
Step 3, Loss: 0.7267115712165833
Step 4, Loss: 0.6775891184806824
Step 5, Loss: 0.7813704609870911
Step 6, Loss: 0.6756245493888855
Step 7, Loss: 0.9666643738746643
Step 8, Loss: 0.6444056034088135
Step 9, Loss: 0.7450430393218994
Step 10, Loss: 0.7641883492469788
Step 11, Loss: 0.40847837924957275
Step 12, Loss: 0.7006322145462036
Step 13, Loss: 0.6850311756134033
Step 14, Loss: 0.6146025061607361
Step 15, Loss: 0.7307778000831604
Step 16, Loss: 0.7162883281707764
Step 17, Loss: 0.8030434250831604
Step 18, Loss: 0.7284414172172546
Step 19, Loss: 0.7638292908668518
Step 20, Loss: 0.6066721081733704
Step 21, Loss: 0.36790019273757935
Step 22, Loss: 0.48964276909828186
Step 23, Loss: 0.4520471692085266
Step 24, Loss: 0.45229068398475647
Step 25, Loss: 0.4304249584674835
Step 26, Loss: 0.3818521797657013
Step 27, Loss: 0.5118127465248108
Step 28, Loss: 0.6146227121353149
Step 29, Loss: 0.38547804951667786
Step 30, Loss: 0.7631076574325562
Step 31, Loss: 0.4806024432182312
Eval Loss: 0.618675172328949
Liger kernel patches have been reverted.
Step 0, Loss: 5.511490821838379
Step 1, Loss: 0.7358517646789551
Step 2, Loss: 0.897828221321106
Step 3, Loss: 0.7267120480537415
Step 4, Loss: 0.6775884628295898
Step 5, Loss: 0.7813707590103149
Step 6, Loss: 0.6756248474121094
Step 7, Loss: 0.9666653871536255
Step 8, Loss: 0.644402801990509
Step 9, Loss: 0.7450424432754517
Step 10, Loss: 0.7641878128051758
Step 11, Loss: 0.40847688913345337
Step 12, Loss: 0.7006286382675171
Step 13, Loss: 0.6850258111953735
Step 14, Loss: 0.6146000027656555
Step 15, Loss: 0.7307751178741455
Step 16, Loss: 0.7162874341011047
Step 17, Loss: 0.8030444979667664
Step 18, Loss: 0.7284409403800964
Step 19, Loss: 0.7638280391693115
Step 20, Loss: 0.6066715121269226
Step 21, Loss: 0.367902547121048
Step 22, Loss: 0.48964276909828186
Step 23, Loss: 0.45204880833625793
Step 24, Loss: 0.452290415763855
Step 25, Loss: 0.43042486906051636
Step 26, Loss: 0.3818501830101013
Step 27, Loss: 0.5118110179901123
Step 28, Loss: 0.6146328449249268
Step 29, Loss: 0.3854790925979614
Step 30, Loss: 0.7631104588508606
Step 31, Loss: 0.48059630393981934
Eval Loss: 0.6186762452125549
Liger kernel patches have been reverted.
_____________________________ test_mini_model[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05] ______________________________
model_name = 'mini_gemma1', num_steps = 32, lr = 0.0001, dtype = torch.float32, loss_atol = 1e-08, loss_rtol = 0.0001, logprobs_atol = 0.005
logprobs_rtol = 0.01, param_atol = 0.005, param_rtol = 1e-05
@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_llava",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not LLAVA_AVAILABLE,
reason="LLaVa not available in this version of transformers",
),
),
pytest.param(
"mini_mllama",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not MLLAMA_AVAILABLE,
reason="Mllama not available in this version of transformers",
),
),
pytest.param(
"mini_gemma3_text",
32,
1e-4,
torch.float32,
1e-8,
1e-4,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not GEMMA3_AVAILABLE,
reason="Gemma3 not available in this version of transformers",
),
),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_qwen3",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN3_AVAILABLE,
reason="Qwen3 not available in this version of transformers",
),
),
pytest.param(
"mini_qwen3_moe",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN3_AVAILABLE,
reason="Qwen3 not available in this version of transformers",
),
),
pytest.param( # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
"mini_qwen2_vl",
32,
1e-4,
torch.float32,
1e-5, # 1e-8,
1e-1, # 1e-5,
1, # 5e-3,
1e-1, # 1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
),
# TODO: logits tolerances are significantly larger than the other tests, need to investigate
pytest.param( # qwen2_5_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
"mini_qwen2_5_vl",
32,
1e-4,
torch.float32,
1e-5, # 1e-8,
1e-1, # 1e-5,
3, # 5e-3,
1e-1, # 1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_5_VL_AVAILABLE,
reason="Qwen2.5-VL not available in this version of transformers",
),
),
pytest.param(
"mini_olmo2",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not OLMO2_AVAILABLE,
reason="OLMO2 not available in this version of transformers",
),
),
pytest.param(
"mini_glm4",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not GLM4_AVAILABLE,
reason="Glm4 not available in this version of transformers",
),
),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
# TODO: mixtral is flaky so disable the test for now
# ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-2, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_granite3",
32,
1e-4,
torch.float32,
1e-8,
1e-4,
5e-2, # 5e-3
1e-4, # 1e-5
5e-3,
1e-5,
marks=pytest.mark.skipif(
not GRANITE_AVAILABLE,
reason="Granite not available in this version of transformers",
),
),
],
)
def test_mini_model(
model_name,
num_steps,
lr,
dtype,
loss_atol,
loss_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
# Non-liger models should be initialized and tested first to avoid the module being overridden
expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
# Compare every step of the loss
> assert_verbose_allclose(
torch.tensor([expected_output["loss"]]),
torch.tensor([actual_output["loss"]]),
atol=loss_atol,
rtol=loss_rtol,
)
test/convergence/fp32/test_mini_models.py:1058:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tensor1 = tensor([[1.0585, 0.7503, 0.9095, 0.7038, 0.6323, 0.7315, 0.7252, 0.8883, 0.6050,
0.7210, 0.7658, 0.4156, 0.67...336, 0.5679, 0.3464, 0.4418, 0.3993, 0.4071, 0.3988, 0.3726,
0.4775, 0.5491, 0.3414, 0.6705, 0.4464, 0.5609]])
tensor2 = tensor([[1.0586, 0.7503, 0.9095, 0.7038, 0.6323, 0.7316, 0.7248, 0.8882, 0.6049,
0.7208, 0.7655, 0.4154, 0.67...332, 0.5675, 0.3460, 0.4413, 0.3989, 0.4071, 0.3997, 0.3721,
0.4769, 0.5490, 0.3415, 0.6706, 0.4464, 0.5609]])
rtol = 0.0001, atol = 1e-08, max_print = 5
def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
"""
Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
Parameters:
tensor1 (torch.Tensor): First tensor to compare.
tensor2 (torch.Tensor): Second tensor to compare.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
max_print (int): Maximum number of mismatched elements to print.
Raises:
AssertionError: If the tensors are not all close within the given tolerance.
"""
# Check if the shapes of the tensors match
if tensor1.shape != tensor2.shape:
raise AssertionError("Input tensors must have the same shape.")
# Calculate the difference between the tensors
diff = torch.abs(tensor1 - tensor2)
# Determine the tolerance
tolerance = atol + rtol * torch.abs(tensor2)
# Find tolerance mismatched elements
tol_mismatched = diff > tolerance
# Find nan mismatched elements
nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
# Find +inf mismatched elements
posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
# Find -inf mismatched elements
neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
# Find all mismatched elements
mismatched = torch.logical_or(
torch.logical_or(tol_mismatched, nan_mismatched),
torch.logical_or(posinf_mismatched, neginf_mismatched),
)
mismatched_indices = torch.nonzero(mismatched)
# Count the number of mismatched elements
num_mismatched = mismatched.sum().item()
# Check if all elements are close
all_close = num_mismatched == 0
# Raise AssertionError with detailed information if there are mismatches
if not all_close and num_mismatched >= 1:
mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
print_count = min(max_print, num_mismatched)
for index in mismatched_indices[:print_count]:
i = tuple(index.tolist())
mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
if num_mismatched > max_print:
mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
> raise AssertionError("\n".join(mismatch_details))
E AssertionError: Number of mismatched elements: 23
E Mismatch at index (0, 6): tensor1[(0, 6)] = 0.725193202495575, tensor2[(0, 6)] = 0.7248455286026001
E Mismatch at index (0, 8): tensor1[(0, 8)] = 0.6050088405609131, tensor2[(0, 8)] = 0.6048830151557922
E Mismatch at index (0, 9): tensor1[(0, 9)] = 0.7209844589233398, tensor2[(0, 9)] = 0.7208169102668762
E Mismatch at index (0, 10): tensor1[(0, 10)] = 0.7657538652420044, tensor2[(0, 10)] = 0.7655386328697205
E Mismatch at index (0, 11): tensor1[(0, 11)] = 0.4156359136104584, tensor2[(0, 11)] = 0.4153878092765808
E ... and 18 more mismatched elements.
test/utils.py:130: AssertionError
------------------------------------------------------------ Captured stdout call -------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 1.0585311651229858
Step 1, Loss: 0.7502652406692505
Step 2, Loss: 0.9094656109809875
Step 3, Loss: 0.7037572860717773
Step 4, Loss: 0.6322704553604126
Step 5, Loss: 0.7315280437469482
Step 6, Loss: 0.725193202495575
Step 7, Loss: 0.8882756233215332
Step 8, Loss: 0.6050088405609131
Step 9, Loss: 0.7209844589233398
Step 10, Loss: 0.7657538652420044
Step 11, Loss: 0.4156359136104584
Step 12, Loss: 0.6751559376716614
Step 13, Loss: 0.6617658734321594
Step 14, Loss: 0.5803778767585754
Step 15, Loss: 0.7222874760627747
Step 16, Loss: 0.708590030670166
Step 17, Loss: 0.7739989757537842
Step 18, Loss: 0.6725180149078369
Step 19, Loss: 0.7336098551750183
Step 20, Loss: 0.5679132342338562
Step 21, Loss: 0.3464488685131073
Step 22, Loss: 0.4417855739593506
Step 23, Loss: 0.3993169963359833
Step 24, Loss: 0.4070657789707184
Step 25, Loss: 0.39878153800964355
Step 26, Loss: 0.37256762385368347
Step 27, Loss: 0.47752323746681213
Step 28, Loss: 0.5490694046020508
Step 29, Loss: 0.341403603553772
Step 30, Loss: 0.670497477054596
Step 31, Loss: 0.4463615119457245
Eval Loss: 0.5608715415000916
Liger kernel patches have been reverted.
Step 0, Loss: 1.058566927909851
Step 1, Loss: 0.7502727508544922
Step 2, Loss: 0.9094695448875427
Step 3, Loss: 0.7037727236747742
Step 4, Loss: 0.6322908997535706
Step 5, Loss: 0.7315629124641418
Step 6, Loss: 0.7248455286026001
Step 7, Loss: 0.8882305026054382
Step 8, Loss: 0.6048830151557922
Step 9, Loss: 0.7208169102668762
Step 10, Loss: 0.7655386328697205
Step 11, Loss: 0.4153878092765808
Step 12, Loss: 0.6748801469802856
Step 13, Loss: 0.661463737487793
Step 14, Loss: 0.5800580382347107
Step 15, Loss: 0.721951961517334
Step 16, Loss: 0.7082199454307556
Step 17, Loss: 0.773615300655365
Step 18, Loss: 0.6721071600914001
Step 19, Loss: 0.733183741569519
Step 20, Loss: 0.567453145980835
Step 21, Loss: 0.34597691893577576
Step 22, Loss: 0.44131916761398315
Step 23, Loss: 0.39894071221351624
Step 24, Loss: 0.40705054998397827
Step 25, Loss: 0.39973393082618713
Step 26, Loss: 0.3721126317977905
Step 27, Loss: 0.4769252836704254
Step 28, Loss: 0.5490176677703857
Step 29, Loss: 0.34152016043663025
Step 30, Loss: 0.6706250309944153
Step 31, Loss: 0.44644686579704285
Eval Loss: 0.5609042644500732
Liger kernel patches have been reverted.
=========================================================== short test summary info ===========================================================
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma3_text-32-0.0001-dtype3-1e-08-0.0001-0.005-1e-05-0.005-1e-05] - AssertionError: Number of mismatched elements: 11
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype13-1e-08-0.0001-0.005-0.01-0.005-1e-05] - AssertionError: Number of mismatched elements: 23
============================================= 2 failed, 15 passed, 1 warning in 156.15s (0:02:36) =============================================
make: *** [Makefile:23: test-convergence] Error 1
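For what it's worth, re-checking the reported numbers by hand shows how far outside tolerance they actually are. Below is a minimal sketch in plain Python (no GPU needed) of the elementwise rule used by assert_verbose_allclose in test/utils.py, i.e. an element mismatches when |t1 - t2| > atol + rtol * |t2|; the values are copied from the two failures above:

def is_mismatch(t1, t2, atol, rtol):
    # Same elementwise rule as assert_verbose_allclose: diff > atol + rtol * |t2|
    return abs(t1 - t2) > atol + rtol * abs(t2)

# mini_gemma3_text topk-logprobs mismatch at index (1, 126, 17):
# diff ~= 5.37e-3 vs tolerance ~= 5.07e-3, i.e. only ~6% over the limit.
print(is_mismatch(-6.848010063171387, -6.8533782958984375, atol=0.005, rtol=1e-5))  # True

# mini_gemma1 step-6 loss mismatch at index (0, 6):
# diff ~= 3.48e-4 vs tolerance ~= 7.25e-5, roughly 5x over the limit.
print(is_mismatch(0.725193202495575, 0.7248455286026001, atol=1e-8, rtol=1e-4))  # True

So the gemma3_text logprobs are only marginally outside the configured tolerances, while the gemma1 per-step losses drift further. Both look like fp32 numerical divergence between the Liger kernels and the Hugging Face reference (no NaN/inf mismatches are reported) rather than a shape or correctness problem, though I can't rule out a real regression.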
Versions
System Information
- OS: Linux 5.15.0-112-generic (#122-Ubuntu SMP Thu May 23 07:48:21 UTC 2024)
- Architecture: x86_64
Python Environment
- Python Version: 3.12.3 | packaged by Anaconda, Inc. | (main, May 6 2024, 19:46:43) [GCC 11.2.0]
- Python Implementation: CPython
Package Versions
- torch: 2.7.1
- liger-kernel: 0.5.10
- transformers: 4.53.0
- datasets: 3.6.0
- pytest: 8.4.1
- packaging: 25.0
- accelerate: Not installed
- numpy: 2.3.1
- safetensors: 0.5.3
- tokenizers: 0.21.2
CUDA Version
- Build cuda_12.1.r12.1/compiler.32688072_0