
Tests are broken on main as of 89c6be8

Open · AshisGhosh opened this issue 2 months ago · 1 comment

System Info

`main` / 89c6be8

Information

  • [X] One of the scripts in the examples/ folder of LeRobot
  • [ ] My own task or dataset (give details below)

Reproduction

Run

DATA_DIR="tests/data" python -m pytest -sv ./tests

Output:

================================================================= FAILURES =================================================================
_________________________________________ test_backward_compatibility[aloha-act-extra_overrides2] __________________________________________

env_name = 'aloha', policy_name = 'act', extra_overrides = ['policy.n_action_steps=10']

    @pytest.mark.parametrize(
        "env_name, policy_name, extra_overrides",
        [
            ("xarm", "tdmpc", []),
            (
                "pusht",
                "diffusion",
                ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
            ),
            ("aloha", "act", ["policy.n_action_steps=10"]),
        ],
    )
    # As artifacts have been generated on an x86_64 kernel, this test won't
    # pass if it's run on another platform due to floating point errors
    @require_x86_64_kernel
    def test_backward_compatibility(env_name, policy_name, extra_overrides):
        """
        NOTE: If this test does not pass, and you have intentionally changed something in the policy:
            1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
               include a report on what changed and how that affected the outputs.
            2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
               add the policies you want to update the test artifacts for.
            3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
            4. Check that this test now passes.
            5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.
            6. Remember to stage and commit the resulting changes to `tests/data`.
        """
        env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
        saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
        saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
        saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
        saved_actions = load_file(env_policy_dir / "actions.safetensors")
    
        output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
    
        for key in saved_output_dict:
            assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
        for key in saved_grad_stats:
>           assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
E           assert tensor(False)
E            +  where tensor(False) = <built-in method all of Tensor object at 0x7d36a6cda660>()
E            +    where <built-in method all of Tensor object at 0x7d36a6cda660> = tensor(False).all
E            +      where tensor(False) = <built-in method isclose of type object at 0x7d3855386760>(tensor(0.0026), tensor(0.0005), rtol=0.1, atol=1e-07)
E            +        where <built-in method isclose of type object at 0x7d3855386760> = torch.isclose

tests/test_policies.py:274: AssertionError
========================================================= short test summary info ==========================================================
FAILED tests/test_policies.py::test_backward_compatibility[aloha-act-extra_overrides2] - assert tensor(False)
================================================ 1 failed, 38 passed, 26 skipped in 14.17s =================================================
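The failing assertion compares one gradient statistic using the element-wise rule documented for `torch.isclose` (with `equal_nan=False`): `|input - other| <= atol + rtol * |other|`. The sketch below re-implements that rule in plain Python so the arithmetic is visible without installing torch. It uses the values printed in the traceback, which the tensor repr rounds to four decimal places. The helper name `isclose` is illustrative and is not part of LeRobot or torch. With `rtol=0.1` the allowed gap is about `5e-5`, while the actual gap is about `2.1e-3`. The current statistic is therefore roughly five times the stored one, so this is not a marginal floating-point tolerance issue.

```python
# Hypothetical helper mirroring the element-wise rule documented for
# torch.isclose (equal_nan=False): |input - other| <= atol + rtol * |other|
def isclose(input_val: float, other: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    return abs(input_val - other) <= atol + rtol * abs(other)


# Values as printed in the traceback (rounded to 4 decimals by the tensor repr).
current_grad_stat = 0.0026  # grad_stats[key] computed on this run
saved_grad_stat = 0.0005    # saved_grad_stats[key] loaded from tests/data

# Same tolerances as tests/test_policies.py uses for grad_stats.
diff = abs(current_grad_stat - saved_grad_stat)
allowed = 1e-7 + 0.1 * abs(saved_grad_stat)

print(f"|diff| = {diff:.7f}, allowed = {allowed:.7f}")
print(isclose(current_grad_stat, saved_grad_stat, rtol=0.1, atol=1e-7))  # prints False
```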

Expected behavior

All tests pass.

AshisGhosh, May 13 '24 02:05