
Improve tests using pytest parametrize

Open vfdev-5 opened this issue 3 years ago • 11 comments


The idea is to improve our testing code using pytest.mark.parametrize.

Currently, some tests are implemented as:

def test_something():

    def _test(a):
        # internal function doing some tests
        # For example, a dummy check below
        assert 0 < a < 5

    a = 1
    _test(a)

    a = 2
    _test(a)

We would like to implement that using pytest.mark.parametrize:


@pytest.mark.parametrize("a", [1, 2])
def test_something(a):
    # internal function doing some tests
    # For example, a dummy check below
    assert 0 < a < 5

Another example PR doing that for a test case: https://github.com/pytorch/ignite/pull/2521/files
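
As a side note, parametrize also covers tests that need several inputs per case: pass a comma-separated name string and one tuple of values per case. A minimal generic illustration (the names and values below are made up for illustration, not taken from the PR above):

import pytest


@pytest.mark.parametrize("a, b", [(1, 2), (3, 4)])
def test_sum_is_small(a, b):
    # one collected test per (a, b) tuple
    assert a + b < 10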

What to do:

  • [ ] Read CONTRIBUTING to see how to develop and execute a specific test: https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#run-tests
  • [ ] Explore the test codebase and identify tests to improve. Please note that some tests most probably can't be reimplemented with pytest.mark.parametrize, as they may generate random data. Let's skip all _test_distrib_* tests as they can be particularly complicated to port.
  • [ ] Once a test to improve is identified, comment here about it and propose a draft code snippet of how you would reimplement it using pytest.mark.parametrize.
  • [ ] In parallel, you can try to reimplement it locally and check that the new test still passes (a minimal sketch follows this list). Please make sure not to modify test assertions just to make the test pass, as such a PR won't be accepted.
  • [ ] Send a draft PR
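
To verify a conversion locally, note that each parametrize case is collected as its own test, so a converted test file can be run on its own. A minimal self-contained sketch, assuming the usual pytest workflow from CONTRIBUTING (the pytest.main call here is just a Python-side equivalent of running pytest -vvv on the file):

import pytest


@pytest.mark.parametrize("a", [1, 2])
def test_something(a):
    # each value of `a` shows up as its own test, e.g. test_something[1]
    assert 0 < a < 5


if __name__ == "__main__":
    # Equivalent to `pytest -vvv <this file>`; see CONTRIBUTING.md for the
    # project's recommended way to run a specific test.
    raise SystemExit(pytest.main(["-vvv", __file__]))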

Thanks!

vfdev-5 avatar Mar 22 '22 11:03 vfdev-5

Hi, I would like to implement this.

nmcguire101 avatar Mar 22 '22 19:03 nmcguire101

@nmcguire101 please follow the "What to do" guidelines above.

vfdev-5 avatar Mar 22 '22 21:03 vfdev-5

@vfdev-5 In https://github.com/pytorch/ignite/blob/3a286b1d13a4a0476b3ee8e8046f16465818c9f6/tests/ignite/metrics/gan/test_fid.py#L158 and https://github.com/pytorch/ignite/blob/3a286b1d13a4a0476b3ee8e8046f16465818c9f6/tests/ignite/metrics/gan/test_inception_score.py#L68, the _test function inside _test_distrib_integration could be improved. I could make this change in all the places where the _test_distrib_integration function is used. Draft code:

def _test_distrib_integration(device):

    from ignite.engine import Engine

    rank = idist.get_rank()
    torch.manual_seed(12)

    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())

    @pytest.mark.parametrize("metric_device", metric_devices)
    def _test(metric_device):
        ...  # rest of the draft omitted

divo12 avatar Mar 23 '22 11:03 divo12

@divo12 thanks for the suggestion. Sorry, I had to note in the issue that all _test_distrib_* tests can be difficult to port using pytest parametrize. Please take a look instead at https://github.com/pytorch/ignite/blob/master/tests/ignite/handlers/test_param_scheduler.py, for example: https://github.com/pytorch/ignite/blob/3a286b1d13a4a0476b3ee8e8046f16465818c9f6/tests/ignite/handlers/test_param_scheduler.py#L383 These tests should be easier to port.

vfdev-5 avatar Mar 23 '22 11:03 vfdev-5

@vfdev-5 check this one: https://github.com/pytorch/ignite/blob/3a286b1d13a4a0476b3ee8e8046f16465818c9f6/tests/ignite/handlers/test_checkpoint.py#L99 Draft code:

def test_checkpoint_default():
    model = DummyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    @pytest.mark.parametrize(
        "to_save, obj, name",
        [
            ({"model": model}, model.state_dict(), "model"),
            ({"model": model, "optimizer": optimizer}, {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "checkpoint"),
        ],
    )
    def _test(to_save, obj, name):
        ...  # rest of the draft omitted

divo12 avatar Mar 23 '22 11:03 divo12

@divo12 please check again the example in the issue description: https://github.com/pytorch/ignite/issues/2522#issue-1176643726 The point is also to remove the internal _test function and parametrize the test_checkpoint_default function itself.

EDIT: I think this one is rather tricky to parametrize.

vfdev-5 avatar Mar 23 '22 15:03 vfdev-5

Umm, I could shift this part of my code above 'test_checkpoint_default', parametrize 'test_checkpoint_default' with the parameters of '_test', and remove '_test':

    model = DummyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    @pytest.mark.parametrize(
        "to_save, obj, name",
        [
            ({"model": model}, model.state_dict(), "model"),
            ({"model": model, "optimizer": optimizer}, {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "checkpoint"),
        ],
    )

What is going to be the tricky part in this?

divo12 avatar Mar 23 '22 17:03 divo12

In this case you declared

    model = DummyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

as globals, which is not the best choice either, but maybe OK... You can see in this test file that there are many other tests using the same pattern with model and optimizer. What we can do is to define model and optimizer globally near class DummyModel(...) and use parametrize as you drafted...
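
A minimal sketch of that layout, assuming a placeholder DummyModel (the real one already exists in tests/ignite/handlers/test_checkpoint.py) and leaving the former _test body elided:

import pytest
import torch
import torch.nn as nn


class DummyModel(nn.Module):
    # Placeholder definition for this sketch; the real DummyModel is already
    # defined in tests/ignite/handlers/test_checkpoint.py.
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(1, 1)

    def forward(self, x):
        return self.net(x)


# Defined once at module level, next to DummyModel, as suggested above.
model = DummyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


@pytest.mark.parametrize(
    "to_save, obj, name",
    [
        ({"model": model}, model.state_dict(), "model"),
        (
            {"model": model, "optimizer": optimizer},
            {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
            "checkpoint",
        ),
    ],
)
def test_checkpoint_default(to_save, obj, name):
    # Body of the former inner _test goes here, using the parametrized
    # arguments directly.
    ...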

@divo12 please update the other tests and send a PR. Thanks!

vfdev-5 avatar Mar 23 '22 17:03 vfdev-5

@vfdev-5 I have sent a draft PR. Please review it.

divo12 avatar Mar 24 '22 07:03 divo12

Hi @vfdev-5, great idea to use pytest.mark.parametrize!

A simple fix in https://github.com/pytorch/ignite/blob/3a286b1d13a4a0476b3ee8e8046f16465818c9f6/tests/ignite/handlers/test_time_limit.py#L18-L39 would be to write it as


@pytest.mark.parametrize("n_iters, limit",[(20, 10),(5,10)])
def test_terminate_on_time_limit(n_iters, limit):
        started = time.time()
        trainer = Engine(_train_func)

        @trainer.on(Events.TERMINATE)
        def _():
            trainer.state.is_terminated = True

        trainer.add_event_handler(Events.ITERATION_COMPLETED, TimeLimit(limit))
        trainer.state.is_terminated = False

        trainer.run(range(n_iters))
        elapsed = round(time.time() - started)
        assert elapsed <= limit + 1
        assert trainer.state.is_terminated == (n_iters > limit)

and move the _train function outside. Let me know what you think. Thanks!
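
For completeness, a sketch of the imports and the moved-out helper this version assumes (the one-second sleep per iteration is assumed to match the original test's timing; double-check the exact body in test_time_limit.py):

import time

import pytest

from ignite.engine import Engine, Events
from ignite.handlers import TimeLimit


def _train_func(engine, batch):
    # Assumed per-iteration cost: roughly one second, so a 10-second TimeLimit
    # terminates a 20-iteration run early but lets a 5-iteration run finish.
    time.sleep(1)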

Ishan-Kumar2 avatar Mar 24 '22 10:03 Ishan-Kumar2

Looks good to me @Ishan-Kumar2, please send a PR :)

vfdev-5 avatar Mar 24 '22 10:03 vfdev-5

Let's close this one, as a lot of updates were already done and the description is outdated. Thanks to everyone who contributed to fixing this tracker issue.

vfdev-5 avatar Oct 17 '22 13:10 vfdev-5