什么是混淆矩阵
混淆矩阵是机器学习中总结分类模型预测结果的情形分析表,以矩阵形式将数据集中的记录按照真实的类别与分类模型作出的分类判断两个标准进行汇总。这个名字来源于它可以非常容易的表明多个类别是否有混淆(也就是一个class被预测成另一个class)
如下图:
其中绿色部分是预测正确的,红色是预测错误的。
对于二分类(正误)问题来说:
参考:http://www.omegaxyz.com/2017/08/27/rocandauc/
Python混淆矩阵的使用
confusion_matrix函数的使用
官方文档中给出的用法是
sklearn.metrics.confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)
y_true: 是样本真实分类结果,y_pred: 是样本预测分类结果
labels:是所给出的类别,通过这个可对类别进行选择
sample_weight : 样本权重
实现代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from sklearn.metrics import confusion_matrix y_true = [2, 1, 0, 1, 2, 0] y_pred = [2, 0, 0, 1, 2, 1] C=confusion_matrix(y_true, y_pred) print(C, end='\n\n') y_true = ["cat", "ant", "cat", "cat", "ant", "bird"] y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"] C2 = confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"]) print(C2) |
结果: