
Unable to extract confusion matrix as a metric from trainer

Open • lathashree01 opened this issue 3 months ago • 1 comment

Bug description

Hi team,

During testing, I would like to extract the `BinaryConfusionMatrix` as a metric from the Trainer.

The value is computed successfully, but the Trainer then fails in the function below because the result contains more than one element:

https://github.com/Lightning-AI/pytorch-lightning/blob/d1949766f8cddd424e2fac3a68b275bebe13d3e4/src/lightning/fabric/utilities/apply_func.py#L123

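from typing import Any, Union

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor


def convert_tensors_to_scalars(data: Any) -> Any:
    """Recursively walk through a collection and convert single-item tensors to scalar values.

    Raises:
        ValueError:
            If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.

    """

    def to_item(value: Tensor) -> Union[int, float, bool]:
        # Only single-element tensors can be turned into Python scalars.
        if value.numel() != 1:
            raise ValueError(
                f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
            )
        return value.item()

    # Apply `to_item` to every Tensor found anywhere in the (possibly nested) collection.
    return apply_to_collection(data, Tensor, to_item)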
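The check in `to_item` can be illustrated without PyTorch. The sketch below is a hypothetical pure-Python mirror of it, not Lightning code: nested lists stand in for tensors, and the names `count_elements` and `to_item_sketch` are invented for illustration. A single value passes, while the 2x2 matrix from this issue has four elements and is rejected.

```python
from typing import Any, Union

Number = Union[int, float, bool]


def count_elements(value: Any) -> int:
    """Count leaf values in a (possibly nested) list, similar to Tensor.numel()."""
    if isinstance(value, (list, tuple)):
        return sum(count_elements(v) for v in value)
    return 1


def to_item_sketch(value: Any) -> Number:
    """Hypothetical mirror of `to_item`: accept exactly one element, else raise."""
    if count_elements(value) != 1:
        raise ValueError(
            f"The metric `{value}` does not contain a single element, "
            "thus it cannot be converted to a scalar."
        )
    # Unwrap single-element nesting such as [5] or [[5]].
    while isinstance(value, (list, tuple)):
        value = value[0]
    return value


print(to_item_sketch([[7]]))  # 7
try:
    to_item_sketch([[343, 497], [0, 4]])
except ValueError as err:
    print(err)
```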

PS: I am using the Anomalib library, which is built on PyTorch Lightning.

How can I resolve this, or is there another way to obtain the confusion matrix? Any help would be greatly appreciated.
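One possible workaround, assuming you control the LightningModule's test hooks: call `.compute()` on the metric yourself in `on_test_epoch_end` and log each cell as its own scalar with `self.log_dict(...)`. The helper below sketches only the splitting step. `confmat_to_scalars` and the `test_` prefix are made-up names, the `[[TN, FP], [FN, TP]]` layout is an assumption based on torchmetrics' convention, and whether Anomalib exposes a hook for this is not confirmed here.

```python
from typing import Dict, Sequence

# Assumed layout (torchmetrics' convention): rows are true labels,
# columns are predictions, i.e. [[TN, FP], [FN, TP]].
CELL_NAMES = (("tn", "fp"), ("fn", "tp"))


def confmat_to_scalars(
    confmat: Sequence[Sequence[float]], prefix: str = "test_"
) -> Dict[str, float]:
    """Hypothetical helper: split a 2x2 confusion matrix into four scalar entries."""
    if len(confmat) != 2 or any(len(row) != 2 for row in confmat):
        raise ValueError("expected a 2x2 confusion matrix")
    return {
        f"{prefix}{CELL_NAMES[i][j]}": float(confmat[i][j])
        for i in range(2)
        for j in range(2)
    }


# With a torch tensor you would pass `confmat.tolist()` instead of a literal list.
print(confmat_to_scalars([[343, 497], [0, 4]]))
# {'test_tn': 343.0, 'test_fp': 497.0, 'test_fn': 0.0, 'test_tp': 4.0}
```

Each value in the returned dict is a single number, so it no longer trips the single-element check.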

Thanks

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Extract `BinaryConfusionMatrix` as a metric during `trainer.test()`.
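For context, here is a plain-Python sketch of what a binary confusion matrix computes, assuming it follows torchmetrics' layout (rows are targets, columns are predictions, flat index `target * 2 + pred`). Whatever the inputs, the output is a 2x2 table of counts, so it always has four elements. `binary_confusion_matrix` is a hypothetical helper, not the torchmetrics API, and it assumes predictions are already thresholded to 0/1.

```python
from typing import List, Sequence


def binary_confusion_matrix(preds: Sequence[int], target: Sequence[int]) -> List[List[int]]:
    """Count (target, pred) pairs into [[TN, FP], [FN, TP]]."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    counts = [0, 0, 0, 0]  # flat bins: TN, FP, FN, TP
    for p, t in zip(preds, target):
        if p not in (0, 1) or t not in (0, 1):
            raise ValueError("preds and target must contain only 0 and 1")
        counts[t * 2 + p] += 1
    # Reshape the four bins into a 2x2 matrix.
    return [counts[0:2], counts[2:4]]


cm = binary_confusion_matrix(preds=[0, 1, 1, 0, 1], target=[0, 0, 1, 1, 1])
print(cm)  # [[1, 1], [1, 2]]
```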

Error messages and logs

ValueError: The metric `tensor([[343, 497],
        [  0,   4]])` does not contain a single element, thus it cannot be converted to a scalar.
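Reading the tensor in this error with the same assumed `[[TN, FP], [FN, TP]]` layout gives TN = 343, FP = 497, FN = 0, TP = 4. If only summary numbers are needed, these counts reduce to single-value rates. The sketch below computes them; `derived_rates` is a hypothetical helper, and returning 0.0 for an empty denominator is an arbitrary convention chosen here.

```python
from typing import Dict


def derived_rates(tn: int, fp: int, fn: int, tp: int) -> Dict[str, float]:
    """Hypothetical helper: reduce confusion-matrix counts to scalar rates."""

    def safe_div(num: int, den: int) -> float:
        # Convention: report 0.0 when the denominator is zero.
        return num / den if den else 0.0

    return {
        "accuracy": safe_div(tp + tn, tp + tn + fp + fn),
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "specificity": safe_div(tn, tn + fp),
    }


rates = derived_rates(tn=343, fp=497, fn=0, tp=4)
for name, value in rates.items():
    print(f"{name}: {value:.4f}")
# accuracy: 0.4111
# precision: 0.0080
# recall: 1.0000
# specificity: 0.4083
```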

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

lathashree01, May 01 '24 16:05