
Plotting a confusion matrix

Open jayboxyz opened this issue 5 years ago • 2 comments

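Note 1:

The confusion-matrix code below may report the following message when a label contains characters that the current font cannot render (Chinese characters, for example):

missing from current font.

Adding the following lines, which switch the font to SimHei, resolves it: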

plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
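
SimHei ships with Windows but is often missing on Linux and macOS, in which case the two lines above have no effect. A minimal sketch that picks the first installed font from a list of candidates, using Matplotlib's font manager; the helper name `first_installed_font` and the candidate list are illustrative assumptions, not part of the original post:

```python
import matplotlib.pyplot as plt
from matplotlib import font_manager


def first_installed_font(candidates):
    """Return the first font name in `candidates` that Matplotlib can find, else None."""
    installed = {font.name for font in font_manager.fontManager.ttflist}
    for name in candidates:
        if name in installed:
            return name
    return None


# Assumed candidate list of CJK-capable fonts; adjust to what your system provides.
CANDIDATES = ["SimHei", "Microsoft YaHei", "Noto Sans CJK SC", "WenQuanYi Micro Hei"]

font = first_installed_font(CANDIDATES)
if font is not None:
    # Put the CJK font first and keep the existing fonts as fallbacks.
    plt.rcParams["font.family"] = ["sans-serif"]
    plt.rcParams["font.sans-serif"] = [font] + list(plt.rcParams["font.sans-serif"])

# Use the ASCII hyphen for negative numbers; some CJK fonts lack the Unicode minus glyph.
plt.rcParams["axes.unicode_minus"] = False
```

If `font` comes back as None, none of the candidates is installed and a CJK font has to be installed first.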

Note 2:

When a figure is saved with plt.savefig as in the code below, the saved image turns out to be completely blank.

import matplotlib.pyplot as plt

# ... some plotting code ...

plt.show()
plt.savefig("filename.png")

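Cause: plt.savefig() is called after plt.show(). Once plt.show() returns, the figure that was displayed has been closed, so the later plt.savefig() call works on a new, empty figure (blank axes) and saves a blank image.

Fix: call plt.savefig() before plt.show().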
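
One way to make the ordering harder to get wrong is to keep a reference to the Figure object and save through it. A minimal sketch; the helper `save_then_show` and the output path in the temp directory are illustrative assumptions:

```python
import os
import tempfile

import matplotlib.pyplot as plt


def save_then_show(fig, path, show=True):
    """Save `fig` to `path` first, then optionally display it."""
    fig.savefig(path, bbox_inches="tight")
    if show:
        plt.show()


fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])
ax.set_title("example")

out_path = os.path.join(tempfile.gettempdir(), "filename.png")
# show=False keeps the sketch non-blocking; pass show=True in an interactive session.
save_then_show(fig, out_path, show=False)
```

Saving through `fig` rather than `plt` also avoids depending on which figure pyplot currently considers active.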

jayboxyz avatar Nov 11 '19 12:11 jayboxyz

1. Plotting a confusion matrix

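import torch

# Accumulate one batch of predictions into the confusion matrix.
# Rows index the true class and columns the predicted class (scikit-learn's convention).
def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[t, p] += 1
    return conf_matrix

conf_matrix = torch.zeros(10, 10)  # 10 classes
with torch.no_grad():  # inference only, no gradients needed
    for data, target in test_loader:
        output = fullModel(data.to(device))
        conf_matrix = confusion_matrix(output.cpu(), target, conf_matrix)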
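
The per-sample loop can also be replaced by a single vectorized count. A minimal NumPy sketch, assuming the labels are integers in the range 0..num_classes-1 and have already been moved to the CPU; the function name `confusion_counts` is hypothetical. Rows index the true class and columns the predicted class (transpose the result if the other orientation is wanted):

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs; rows = true class, columns = predicted class."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    # Encode each (true, pred) pair as one integer, then count the occurrences.
    flat = y_true * num_classes + y_pred
    counts = np.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


cm = confusion_counts([0, 1, 2, 2], [0, 2, 2, 1], num_classes=3)
print(cm)
```

In this toy input, the sample with true class 1 predicted as class 2 ends up in `cm[1, 2]`.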

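The resulting conf_matrix holds the confusion-matrix counts.

With the counts in hand, the next step is visualization, done here with seaborn. Because conf_matrix is a tensor, it first has to be converted to a NumPy array.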

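import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

# Attack2Index is not defined in this snippet; it is presumably the author's dict
# mapping class names to class indices. Substitute your own list of class names.
class_names = list(Attack2Index.keys())
df_cm = pd.DataFrame(conf_matrix.numpy(), index=class_names, columns=class_names)
plt.figure(figsize=(10, 7))
sn.heatmap(df_cm, annot=True, cmap="BuPu")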
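
Attack2Index is never defined in the snippet, so running it as posted raises a NameError. It looks like a dict from class name to class index. A minimal sketch with a hypothetical mapping and made-up counts (both are placeholders, not the author's data) showing how the labeled DataFrame can be built:

```python
import numpy as np
import pandas as pd

# Hypothetical mapping; replace with your own class names and indices.
Attack2Index = {"normal": 0, "dos": 1, "probe": 2}

# Order the names by index so they line up with the matrix rows and columns.
class_names = sorted(Attack2Index, key=Attack2Index.get)

# Made-up counts standing in for conf_matrix.numpy().
conf = np.array([[50, 2, 1],
                 [3, 40, 0],
                 [0, 4, 30]])

df_cm = pd.DataFrame(conf, index=class_names, columns=class_names)
print(df_cm)
```

The resulting `df_cm` can be passed to `sn.heatmap` exactly as in the code above.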

Visualizing the confusion matrix (a more polished version):

import itertools

import matplotlib.pyplot as plt
import numpy as np

# Plot the confusion matrix
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    Input
    - cm : the computed confusion matrix (NumPy array, rows = true labels)
    - classes : class names for the rows and columns of the matrix
    - normalize : True: show row-wise proportions, False: show raw counts
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # '.0f' prints whole numbers and, unlike 'd', also accepts counts stored as floats
    fmt = '.2f' if normalize else '.0f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()  # called last so the axis labels are included in the layout

Test:

plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')
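
With normalize=True each row is divided by its row total, so when rows are true labels the diagonal holds the per-class recall. A plain `cm / cm.sum(axis=1)` produces NaN for a class that has no samples (0/0). A minimal sketch of a guarded version; the function name `normalize_rows` and the example counts are illustrative:

```python
import numpy as np


def normalize_rows(cm):
    """Divide each row by its sum; rows that sum to zero stay all zeros."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips the division for empty rows, leaving the zeros from `out`.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


cm = np.array([[5, 1, 0],
               [2, 2, 0],
               [0, 0, 0]])  # the third class has no samples
norm = normalize_rows(cm)
print(np.round(norm, 2))
```

Here the first two rows each sum to 1 and the empty third row stays at 0 instead of turning into NaN.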

2. In addition to the above, see the article 混淆矩阵及绘图 (Confusion Matrix and Plotting).

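from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt


# Predictions: the labels returned by predict
guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
# Ground-truth labels
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
# Classes, sorted so that they line up with the rows and columns of the matrix
classes = sorted(set(fact))
# Compare the two to get the confusion matrix; sklearn expects (y_true, y_pred),
# so rows are true labels and columns are predictions
confusion = confusion_matrix(fact, guess)
# Heat map; cmap sets the colors (gray or the reversed gray_r also work)
plt.imshow(confusion, cmap=plt.cm.Blues)
# Pay attention here:
# ticks are the positions along an axis
# labels are the text shown at those positions
indices = range(len(confusion))
# The first argument is an iterable of tick positions
# The second is the sequence of labels shown at those positions
plt.xticks(indices, classes)
plt.yticks(indices, classes)
# Color bar: the strip next to the plot that maps colors to values
plt.colorbar()
# Axis labels: imshow draws rows along y and columns along x
plt.xlabel('guess')
plt.ylabel('fact')
# Write the count into each cell; plt.text takes (x, y) = (column, row)
for first_index in range(len(confusion)):
    for second_index in range(len(confusion[first_index])):
        plt.text(second_index, first_index, confusion[first_index][second_index])

# Show the figure
plt.show()

# PS: check that the tick labels (classes) line up with the data.
# Correct data shown under the wrong labels ruins the whole plot,
# and one such mistake makes it much harder to convince anyone.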
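
An easy mistake with this kind of plot is mixing up which axis holds the true labels. For the guess/fact lists above the matrix happens to be symmetric, so swapping rows and columns would go unnoticed with that data. A minimal pure-Python sketch that makes the orientation explicit; the function name `confusion_rows_true` and the second, asymmetric example are illustrative:

```python
def confusion_rows_true(y_true, y_pred, classes):
    """Build a confusion matrix with rows = true class and columns = predicted class."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
m = confusion_rows_true(fact, guess, classes=[0, 1, 2])
print(m)  # [[0, 4, 0], [4, 0, 1], [0, 1, 0]], equal to its own transpose

# An asymmetric case: two samples of class 0 and one of class 1, all predicted as 1.
m2 = confusion_rows_true([0, 0, 1], [1, 1, 1], classes=[0, 1])
print(m2)  # [[0, 2], [0, 1]]: row 0 (true class 0) has both of its samples in column 1
```

Checking a small asymmetric case like the second one is a quick way to confirm that the axis labels match the data.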

3. How to draw a good confusion matrix in Python

'''compute confusion matrix
labels.txt: contain label name.
predict.txt: predict_label true_label
'''
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# load labels.
with open('labels.txt', 'r') as file:
    labels = [line.strip() for line in file]

y_true = []
y_pred = []
# load true and predict labels (each line: predict_label true_label).
with open('predict.txt', 'r') as file:
    for line in file:
        fields = line.split()
        y_pred.append(int(fields[0]))
        y_true.append(int(fields[1]))

tick_marks = np.array(range(len(labels))) + 0.5

def plot_confusion_matrix(cm, title='Confusion Matrix', cmap=plt.cm.binary):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(labels)))
    plt.xticks(xlocations, labels, rotation=90)
    plt.yticks(xlocations, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

cm = confusion_matrix(y_true, y_pred)
print(cm)
np.set_printoptions(precision=2)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print(cm_normalized)
plt.figure(figsize=(12, 8), dpi=120)
# set the fontsize of label.
# for label in plt.gca().xaxis.get_ticklabels():
#     label.set_fontsize(8)
# text portion
ind_array = np.arange(len(labels))
x, y = np.meshgrid(ind_array, ind_array)

for x_val, y_val in zip(x.flatten(), y.flatten()):
    c = cm_normalized[y_val][x_val]
    if c > 0.01:
        plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=7, va='center', ha='center')
# offset the tick
plt.gca().set_xticks(tick_marks, minor=True)
plt.gca().set_yticks(tick_marks, minor=True)
plt.gca().xaxis.set_ticks_position('none')
plt.gca().yaxis.set_ticks_position('none')
plt.grid(True, which='minor', linestyle='-')
plt.gcf().subplots_adjust(bottom=0.15)

plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
# show confusion matrix
plt.show()
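
The script expects predict.txt to hold one sample per line in the form "predict_label true_label". A minimal sketch of a loader that also skips blank lines, exercised on a small temporary file; the function name `load_predictions` and the sample contents are illustrative assumptions:

```python
import os
import tempfile


def load_predictions(path):
    """Read 'predict_label true_label' lines and return (y_true, y_pred)."""
    y_true, y_pred = [], []
    with open(path, "r") as f:
        for line in f:
            fields = line.split()
            if len(fields) < 2:
                continue  # skip blank or incomplete lines
            y_pred.append(int(fields[0]))
            y_true.append(int(fields[1]))
    return y_true, y_pred


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "predict.txt")
    with open(path, "w") as f:
        f.write("0 0\n2 1\n\n1 1\n")  # three samples plus one blank line
    y_true, y_pred = load_predictions(path)

print(y_true, y_pred)  # [0, 1, 1] [0, 2, 1]
```

Note the column order: the prediction comes first on each line, the true label second.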

For reference:

1. https://github.com/Tony607/Focal_Loss_Keras/blob/master/src/keras_focal_loss.ipynb

jayboxyz avatar Nov 11 '19 13:11 jayboxyz

NameError: name 'Attack2Index' is not defined. How can this be resolved?

upupbo avatar Mar 05 '21 14:03 upupbo