pytorch-semseg icon indicating copy to clipboard operation
pytorch-semseg copied to clipboard

The metrics of computing confusion matrix might be more efficient

Open xyl576807077 opened this issue 5 years ago • 2 comments

import numpy as np
from sklearn.metrics import confusion_matrix

class runningScore(object):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        mask = (label_true >= 0) & (label_true < n_class)
        hist = np.bincount(
            n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
        ).reshape(n_class, n_class)
        return hist

    def update(self, label_trues, label_preds):
        # for lt, lp in zip(label_trues, label_preds):
        #     self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
        self.confusion_matrix += confusion_matrix(label_trues.flatten(), label_preds.flatten(), list(range(self.n_classes)))

I tested the above code, it is faster 10x than original code

xyl576807077 avatar Mar 25 '19 08:03 xyl576807077

Thanks!

Can you send a PR?

meetps avatar Mar 25 '19 09:03 meetps

`Traceback (most recent call last): File "train.py", line 248, in train(cfg, writer, logger) File "train.py", line 183, in train running_metrics_val.update(gt, pred) File "F:\Segnet\semseg\ptsemseg\metrics.py", line 23, in update self.confusion_matrix1 += confusion_matrix(label_trues.flatten(), label_preds.flatten(), list(range(self.n_classes))) File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\utils\validation.py", line 74, in inner_f New var:....... lbl = ndarray<(720, 960), uint8> 17:43:29.099873 line 52 lbl = np.array(lbl, dtype=np.int8) Modified var:.. lbl = ndarray<(720, 960), int8> return f(**kwargs)17:43:29.099873 line 54 if self.augmentations is not None:

File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\metrics_classification.py", line 296, in confusion_matrix 17:43:29.100873 line 57 if self.is_transform: 17:43:29.100873 line 58 img, lbl = self.transform(img, lbl) y_type, y_true, y_pred = _check_targets(y_true, y_pred) File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\metrics_classification.py", line 83, in _check_targets check_consistent_length(y_true, y_pred) File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\utils\validation.py", line 263, in check_consistent_length " samples: %r" % [int(l) for l in lengths]) ValueError: Found input variables with inconsistent numbers of samples: [2764800, 691200]`

Have you met this error brfore? thank you!

benzi007 avatar Jun 17 '21 09:06 benzi007