ignite

Introduce a variable skip_unrolling in class Metric

simeetnayan81 opened this issue 1 year ago • 10 comments

Fixes #2940

Description: Introduces a skip_unrolling variable in the Metric class, as discussed here: https://discord.com/channels/831462531327328276/1110662056622964860/1253769540710567977

Check list:

  • [x] New tests are added (if a new feature is added)
  • [x] New doc strings: description and/or example code are in RST format
  • [ ] Documentation is updated (if required)

simeetnayan81, Jun 25 '24 15:06

Should the tests be added at the end of the test_metric.py file?

simeetnayan81, Jun 27 '24 03:06

Yes, you can add them at the end of the file.

vfdev-5, Jun 27 '24 04:06

The skip_unrolling=False case is already covered by all the existing tests, so I have added a test for skip_unrolling=True. Please review and let me know if any changes are needed.

simeetnayan81, Jun 28 '24 05:06
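For context on what the two test cases exercise: as we understand the behaviour being toggled, when the transformed output is a pair of lists (y_pred_list, y_list), Metric by default "unrolls" it and calls update once per (y_pred, y) element, while skip_unrolling=True passes the output to update unchanged. The sketch below is a simplified, torch-free model of that dispatch, assuming plain numbers in place of tensors and a plain list check in place of ignite's own sequence check; ToyMetric and _is_list_of_numbers are hypothetical names, not ignite's code.

```python
from collections.abc import Sequence
from typing import Any, List


def _is_list_of_numbers(x: Any) -> bool:
    # Hypothetical, simplified stand-in for the real "list of tensors or numbers" check.
    return isinstance(x, list) and all(isinstance(v, (int, float)) for v in x)


class ToyMetric:
    """Hypothetical, torch-free model of how a metric dispatches engine output to update()."""

    def __init__(self, skip_unrolling: bool = False) -> None:
        self._skip_unrolling = skip_unrolling
        self.updates: List[Any] = []  # records the argument of every update() call

    def update(self, output: Any) -> None:
        self.updates.append(output)

    def iteration_completed(self, output: Any) -> None:
        if (
            not self._skip_unrolling
            and isinstance(output, Sequence)
            and all(_is_list_of_numbers(o) for o in output)
        ):
            # Unrolling path: expect exactly (y_pred_list, y_list) of equal length.
            if not (len(output) == 2 and len(output[0]) == len(output[1])):
                raise ValueError("expected (y_pred_list, y_list) of equal length")
            # One update() call per (y_pred, y) pair.
            for y_pred, y in zip(output[0], output[1]):
                self.update((y_pred, y))
        else:
            # No unrolling: the whole output reaches update() as-is.
            self.update(output)


unrolled = ToyMetric()
unrolled.iteration_completed(([0.2, 0.8], [0, 1]))
print(unrolled.updates)  # [(0.2, 0), (0.8, 1)]

whole = ToyMetric(skip_unrolling=True)
whole.iteration_completed(([0.2, 0.8], [0, 1]))
print(whole.updates)  # [([0.2, 0.8], [0, 1])]
```

A multi-output setup such as ([y_pred_a, y_pred_b], [y_a, y_b]) has exactly this shape, so the default path would split the heads into separate update calls; that is the kind of case skip_unrolling=True is meant to cover.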

@vfdev-5 Before adding the example to the docstring, I wanted to confirm that, to make skip_unrolling effective for the Loss metric, we might also need to change its constructor here: https://github.com/pytorch/ignite/blob/master/ignite/metrics/loss.py#L77

Prev:

def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
):
    super(Loss, self).__init__(output_transform, device=device)

Change to:

def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
    skip_unrolling: bool = False,
):
    super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

simeetnayan81, Jun 28 '24 09:06

@simeetnayan81 Yes, you are right: we need to add this new arg to every metric that defines its own constructor. Let's update the Loss metric here and update the other metrics in a follow-up PR.

vfdev-5, Jun 28 '24 10:06
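Why every subclass with its own constructor needs the argument can be shown with a minimal sketch: a subclass that does not accept and forward skip_unrolling rejects the keyword outright and always keeps the base-class default. BaseMetric, ForwardingLoss and NonForwardingMetric are hypothetical stand-ins (device is simplified to a string), not ignite classes.

```python
from typing import Any, Callable


class BaseMetric:
    """Hypothetical stand-in for a metric base class (constructor only)."""

    def __init__(
        self,
        output_transform: Callable[[Any], Any] = lambda x: x,
        device: str = "cpu",  # simplified: a real metric would also accept torch.device
        skip_unrolling: bool = False,
    ) -> None:
        self._output_transform = output_transform
        self._device = device
        self._skip_unrolling = skip_unrolling


class ForwardingLoss(BaseMetric):
    """Hypothetical subclass that exposes skip_unrolling and forwards it to the base."""

    def __init__(
        self,
        loss_fn: Callable[..., float],
        output_transform: Callable[[Any], Any] = lambda x: x,
        batch_size: Callable[[Any], int] = len,
        device: str = "cpu",
        skip_unrolling: bool = False,
    ) -> None:
        super().__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
        self._loss_fn = loss_fn
        self._batch_size = batch_size


class NonForwardingMetric(BaseMetric):
    """Hypothetical subclass whose constructor predates the new argument."""

    def __init__(self, output_transform: Callable[[Any], Any] = lambda x: x) -> None:
        # skip_unrolling is never passed on, so the base default (False) always applies,
        # and NonForwardingMetric(skip_unrolling=True) raises TypeError.
        super().__init__(output_transform)


loss = ForwardingLoss(loss_fn=lambda y_pred, y: 0.0, skip_unrolling=True)
print(loss._skip_unrolling)  # True
```

Forwarding by keyword (skip_unrolling=skip_unrolling) keeps the call robust if the base class later gains more positional parameters.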

Things to do in a follow-up PR:

  • [x] Add test for updated Loss class
  • [ ] Update other subclasses of Metric with the skip_unrolling arg as required, and add tests and docstrings

simeetnayan81, Jun 28 '24 11:06

Thanks for the updates and the TODO list. Can we do this point here?

Add test for updated Loss class

vfdev-5, Jun 28 '24 11:06

Alright @vfdev-5

simeetnayan81, Jun 28 '24 11:06

I have made the changes; the new test passes locally.

simeetnayan81, Jun 28 '24 15:06

The test is failing because the list[torch.Tensor, torch.Tensor] annotation is only supported on Python 3.9 and above. Let me modify it a bit.

simeetnayan81, Jun 29 '24 05:06
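One way to keep such a test working on Python 3.8 is to spell the type with typing.List / typing.Tuple instead of the built-in list, since subscripting built-in collections (PEP 585) only arrived in Python 3.9. The sketch below assumes plain floats in place of torch.Tensor; UnrolledOutput, split_batch and builtin_generics_supported are hypothetical names used for illustration only.

```python
import sys
from typing import List, Tuple

# Built-in generics such as list[int] were added in Python 3.9 (PEP 585).
# On 3.8 and older, evaluating list[...] raises
# "TypeError: 'type' object is not subscriptable", e.g. when a function
# annotation is evaluated at definition time.
# typing.List / typing.Tuple express the same types on every version.

# Hypothetical alias: a (predictions, targets) pair of per-sample lists.
UnrolledOutput = Tuple[List[float], List[float]]


def split_batch(batch: List[Tuple[float, float]]) -> UnrolledOutput:
    """Hypothetical helper: turn [(y_pred, y), ...] into ([y_pred, ...], [y, ...])."""
    preds = [p for p, _ in batch]
    targets = [t for _, t in batch]
    return preds, targets


def builtin_generics_supported() -> bool:
    """Return True if list[...] can be evaluated on the running interpreter."""
    try:
        list[int]  # evaluated only to probe for PEP 585 support
    except TypeError:
        return False
    return True


print(sys.version_info >= (3, 9), builtin_generics_supported())
```

If the annotation is only ever used as an annotation, adding from __future__ import annotations at the top of the module is another option, because annotations are then stored as strings and not evaluated at definition time.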