ignite
Introduce a variable skip_unrolling in class Metric
Fixes #2940
Description: Introduce a skip_unrolling variable in the Metric class, as discussed on Discord: https://discord.com/channels/831462531327328276/1110662056622964860/1253769540710567977
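A simplified, self-contained sketch of the idea follows. It is not ignite's actual implementation, and ToyMetric and _is_list_of_numbers are hypothetical names. By default, a metric that receives an output (y_pred, y) in which both elements are lists "unrolls" them and calls update once per (y_pred_i, y_i) pair. With skip_unrolling=True, the output is passed to update unchanged, for instance when y_pred is itself a multi-output pair that should not be split.

```python
from typing import Any, List


def _is_list_of_numbers(x: Any) -> bool:
    # Hypothetical stand-in that roughly mimics ignite's "list of tensors or numbers" check.
    return isinstance(x, list) and all(isinstance(v, (int, float)) for v in x)


class ToyMetric:
    """Records every payload passed to update(), to show what unrolling does."""

    def __init__(self, skip_unrolling: bool = False) -> None:
        self._skip_unrolling = skip_unrolling
        self.updates: List[Any] = []

    def update(self, output: Any) -> None:
        self.updates.append(output)

    def iteration_completed(self, output: Any) -> None:
        if (
            not self._skip_unrolling
            and isinstance(output, (list, tuple))
            and len(output) == 2
            and all(_is_list_of_numbers(o) for o in output)
        ):
            # Unroll: one update() call per (y_pred_i, y_i) pair.
            for y_pred_i, y_i in zip(output[0], output[1]):
                self.update((y_pred_i, y_i))
        else:
            # skip_unrolling=True (or non-list output): pass the output through as-is.
            self.update(output)


unrolled = ToyMetric()
unrolled.iteration_completed(([0.1, 0.9], [0, 1]))
print(unrolled.updates)  # [(0.1, 0), (0.9, 1)]

kept = ToyMetric(skip_unrolling=True)
kept.iteration_completed(([0.1, 0.9], [0, 1]))
print(kept.updates)  # [([0.1, 0.9], [0, 1])]
```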
Check list:
- [x] New tests are added (if a new feature is added)
- [x] New doc strings: description and/or example code are in RST format
- [ ] Documentation is updated (if required)
Should the tests be added to the end of the test_metric.py file?
Yes, you can add them at the end of the file.
skip_unrolling = False is already covered by all the existing tests. I have added a test for the skip_unrolling = True case. Kindly review and let me know if any changes are needed.
@vfdev-5 Before adding the example to the docstring, I wanted to confirm: to make skip_unrolling effective for the Loss metric, we might also need to change this: https://github.com/pytorch/ignite/blob/master/ignite/metrics/loss.py#L77
Prev:
```python
def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
):
    super(Loss, self).__init__(output_transform, device=device)
```
Change to:
```python
def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
    skip_unrolling: bool = False,
):
    super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
```
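Why the constructor change is needed can be sketched with toy classes (ToyMetricBase, LossWithoutForwarding and LossWithForwarding are hypothetical and only mimic the shape of ignite's Metric and Loss). A subclass that defines its own __init__ without accepting and forwarding the flag always falls back to the base-class default, and callers cannot even pass the argument.

```python
from typing import Callable


class ToyMetricBase:
    # Hypothetical stand-in for ignite's Metric base class.
    def __init__(self, output_transform: Callable = lambda x: x, skip_unrolling: bool = False) -> None:
        self._output_transform = output_transform
        self._skip_unrolling = skip_unrolling


class LossWithoutForwarding(ToyMetricBase):
    # Old-style signature: skip_unrolling is never passed up, so the base default (False) is used.
    def __init__(self, loss_fn: Callable, output_transform: Callable = lambda x: x) -> None:
        super().__init__(output_transform)
        self._loss_fn = loss_fn


class LossWithForwarding(ToyMetricBase):
    # New-style signature: the flag is accepted and forwarded to the base class.
    def __init__(
        self,
        loss_fn: Callable,
        output_transform: Callable = lambda x: x,
        skip_unrolling: bool = False,
    ) -> None:
        super().__init__(output_transform, skip_unrolling=skip_unrolling)
        self._loss_fn = loss_fn


loss = LossWithForwarding(lambda y_pred, y: y_pred - y, skip_unrolling=True)
print(loss._skip_unrolling)  # True

try:
    LossWithoutForwarding(lambda y_pred, y: y_pred - y, skip_unrolling=True)
except TypeError as err:
    # The old constructor rejects the new keyword argument.
    print("TypeError:", err)
```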
@simeetnayan81 Yes, you are right: we need to add this new arg to all metrics that define their own constructor. Let's update the Loss metric here and the other metrics in a follow-up PR.
Things to do in a follow-up PR:
- [x] Add test for updated Loss class
- [ ] Update other subclasses of Metric with the skip_unrolling arg as required; add tests and docstrings
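For the follow-up item, one possible way to find subclasses that still need the argument is to inspect each subclass's constructor signature. This is a hypothetical sketch on toy classes, not an existing ignite utility; it assumes constructors list skip_unrolling explicitly, so a subclass that forwards it through **kwargs would be reported as missing.

```python
import inspect
from typing import List


class ToyMetric:
    # Hypothetical base class standing in for ignite's Metric.
    def __init__(self, skip_unrolling: bool = False) -> None:
        self._skip_unrolling = skip_unrolling


class UpdatedMetric(ToyMetric):
    def __init__(self, device: str = "cpu", skip_unrolling: bool = False) -> None:
        super().__init__(skip_unrolling=skip_unrolling)
        self.device = device


class NotYetUpdatedMetric(ToyMetric):
    def __init__(self, device: str = "cpu") -> None:
        super().__init__()
        self.device = device


class InheritsConstructor(ToyMetric):
    # No __init__ of its own, so it inherits the base signature and needs no change.
    pass


def _all_subclasses(cls: type) -> List[type]:
    # Recursively collect direct and indirect subclasses.
    found: List[type] = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(_all_subclasses(sub))
    return found


def missing_skip_unrolling(base: type) -> List[str]:
    """Names of subclasses whose constructor does not accept skip_unrolling."""
    return [
        sub.__name__
        for sub in _all_subclasses(base)
        if "skip_unrolling" not in inspect.signature(sub.__init__).parameters
    ]


print(missing_skip_unrolling(ToyMetric))  # ['NotYetUpdatedMetric']
```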
Thanks for the updates and the TODO list. Can we do this point here?
Add test for updated Loss class
Alright @vfdev-5
I have made the changes; the new test passes locally.
The test is failing because the list[torch.Tensor, torch.Tensor] annotation is only supported on Python 3.9 and above. Let me modify it a bit.
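A sketch of one Python 3.8-compatible alternative: subscripting the builtin list (PEP 585) fails with a TypeError on Python < 3.9 wherever the annotation is evaluated, whereas typing.List and typing.Tuple work on older versions. Since list takes a single type argument, a two-element pair is more conventionally written as Tuple[torch.Tensor, torch.Tensor]. The helper make_pairs is hypothetical; torch is imported only under TYPE_CHECKING with string annotations so the sketch runs without torch installed, while a real test file would import torch and write the types directly.

```python
from typing import TYPE_CHECKING, List, Tuple

if TYPE_CHECKING:  # only the type checker needs torch in this sketch
    import torch


def make_pairs(
    y_preds: List["torch.Tensor"], ys: List["torch.Tensor"]
) -> List[Tuple["torch.Tensor", "torch.Tensor"]]:
    """Hypothetical helper: pair each prediction with its target (works on Python 3.8)."""
    return list(zip(y_preds, ys))
```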