pytorch-semseg
Computing the confusion matrix in the metrics code could be more efficient
```python
import numpy as np
from sklearn.metrics import confusion_matrix


class runningScore(object):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        # Original implementation: per-pair histogram via np.bincount.
        mask = (label_true >= 0) & (label_true < n_class)
        hist = np.bincount(
            n_class * label_true[mask].astype(int) + label_pred[mask],
            minlength=n_class ** 2,
        ).reshape(n_class, n_class)
        return hist

    def update(self, label_trues, label_preds):
        # for lt, lp in zip(label_trues, label_preds):
        #     self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
        self.confusion_matrix += confusion_matrix(
            label_trues.flatten(),
            label_preds.flatten(),
            labels=list(range(self.n_classes)),  # keyword form; `labels` is keyword-only in newer sklearn
        )
```
I tested the code above; it is about 10x faster than the original implementation.
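If you want to verify the speedup on your own setup, here is a minimal, self-contained timing sketch. The 19-class count, the 1024x512 label size, and the 100 repetitions are made-up values for illustration, not the benchmark used above:

```python
import time

import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical setup for illustration only: 19 classes, one 1024x512 label map.
n_classes = 19
rng = np.random.default_rng(0)
gt = rng.integers(0, n_classes, size=1024 * 512)
pred = rng.integers(0, n_classes, size=1024 * 512)

def fast_hist(label_true, label_pred, n_class):
    # Same np.bincount trick as runningScore._fast_hist above.
    mask = (label_true >= 0) & (label_true < n_class)
    return np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask],
        minlength=n_class ** 2,
    ).reshape(n_class, n_class)

# Time the np.bincount-based histogram.
start = time.perf_counter()
for _ in range(100):
    fast_hist(gt, pred, n_classes)
print("np.bincount histogram:", time.perf_counter() - start)

# Time sklearn's confusion_matrix on the same flattened inputs.
start = time.perf_counter()
for _ in range(100):
    confusion_matrix(gt, pred, labels=list(range(n_classes)))
print("sklearn confusion_matrix:", time.perf_counter() - start)
```

Which variant wins can depend on array size, class count, and the sklearn version, so it is worth measuring on the actual validation shapes.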
Thanks!
Can you send a PR?
```
Traceback (most recent call last):
  File "train.py", line 248, in <module>
  File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\metrics\_classification.py", line 296, in confusion_matrix
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\metrics\_classification.py", line 83, in _check_targets
    check_consistent_length(y_true, y_pred)
  File "D:\Anaconda3\envs\py\lib\site-packages\sklearn\utils\validation.py", line 263, in check_consistent_length
    " samples: %r" % [int(l) for l in lengths])
ValueError: Found input variables with inconsistent numbers of samples: [2764800, 691200]
```
Have you met this error before? Thank you!
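Not part of the original report, but note that 2764800 is exactly 4 × 691200, so one of the two flattened arrays has four times as many elements as the other. In segmentation code that usually means one array still carries an extra class, channel, or batch dimension, or the two are at different resolutions. Below is a hedged sketch of one hypothetical shape combination that reproduces the mismatch, plus the usual fix of reducing the class dimension before calling `update`; the real shapes behind the traceback are unknown.

```python
import torch

# Purely hypothetical shapes chosen only to produce a 4x element mismatch like
# the one in the traceback (1 image, 4 classes, 960x720 resolution).
outputs = torch.randn(1, 4, 960, 720)        # raw scores with a class dim (N, C, H, W): 2764800 elements
labels = torch.randint(0, 4, (1, 960, 720))  # single-channel ground truth (N, H, W):     691200 elements

print(outputs.numel(), labels.numel())       # 2764800 691200 -> flatten() lengths disagree

# If the extra axis is a class/score dimension, reduce it before updating the metric:
preds = outputs.argmax(dim=1)                # (N, H, W)
assert preds.numel() == labels.numel()
# runningScore.update(labels.cpu().numpy(), preds.cpu().numpy()) now sees equal-length arrays.
```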