Evaluating Classification Results

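  • Published: 2019-10-22
  • Difficulty: Moderate
  • Category: Classification and Prediction, Evaluating Classification Results
  • Tags: Python, scikit-learn, decision tree, k-fold cross-validation, confusion matrix, accuracy, precision, recall, F-score, breast cancer dataset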

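1. Problem Description

A decision tree classifier is trained on the breast cancer dataset. The data is first split into a training set and a test set at a 7:3 ratio, and the model's classification performance is then evaluated in six ways: confusion matrix, accuracy, precision, recall, F-score, and classification report.

2. Implementation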

In [1]:
# Load the dataset
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
# Split into training and test sets (70% / 30%)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, test_size=0.3, random_state=42)
# Build the decision tree model
from sklearn.tree import DecisionTreeClassifier
clf_tree = DecisionTreeClassifier(max_depth=None, min_samples_split=2, random_state=1)
clf_tree = clf_tree.fit(X_train, y_train)
# Predict on the test set with the fitted model
y_pred = clf_tree.predict(X_test)
# Confusion matrix (rows: true class, columns: predicted class)
from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, y_pred))
[[ 59   4]
 [  7 101]]
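To make the matrix easier to read, here is a minimal sketch that recomputes the basic metrics by hand from the four counts printed above. It assumes scikit-learn's convention that rows are true classes and columns are predicted classes, and that label 0 is malignant and label 1 is benign, as in this dataset. The variable names are illustrative only.

```python
# Confusion matrix copied from the output above.
# Rows = true class, columns = predicted class (0 = malignant, 1 = benign).
cm = [[59, 4],
      [7, 101]]

total = sum(sum(row) for row in cm)

# Accuracy: correctly classified samples (the diagonal) over all samples.
accuracy = (cm[0][0] + cm[1][1]) / total

# Malignant (class 0) treated as the positive class.
precision_malignant = cm[0][0] / (cm[0][0] + cm[1][0])  # column 0 total
recall_malignant = cm[0][0] / (cm[0][0] + cm[0][1])     # row 0 total

# Benign (class 1) treated as the positive class.
precision_benign = cm[1][1] / (cm[0][1] + cm[1][1])     # column 1 total
recall_benign = cm[1][1] / (cm[1][0] + cm[1][1])        # row 1 total

print("accuracy:", accuracy)
print("precision (malignant, benign):", precision_malignant, precision_benign)
print("recall (malignant, benign):", recall_malignant, recall_benign)
```

Precision divides by a column total (everything predicted as that class), while recall divides by a row total (everything that truly belongs to that class).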
In [4]:
# Accuracy
from sklearn.metrics import accuracy_score
print("accuracy of malignant and benign:%s" % (accuracy_score(y_test, y_pred)))
# Precision (label 0 = malignant; pos_label defaults to 1 = benign)
from sklearn.metrics import precision_score
print("precision of malignant:%s" % (precision_score(y_test, y_pred, pos_label=0)))
print("precision of benign:%s" % (precision_score(y_test, y_pred)))
# Recall
from sklearn.metrics import recall_score
print("recall of malignant:%s" % (recall_score(y_test, y_pred, pos_label=0)))
print("recall of benign:%s" % (recall_score(y_test, y_pred)))
accuracy of malignant and benign:0.93567251462
precision of malignant:0.893939393939
precision of benign:0.961904761905
recall of malignant:0.936507936508
recall of benign:0.935185185185
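The problem description also lists the F-score and the classification report, which the cells shown here do not compute on the test set. The following is a minimal sketch of how they might be obtained with `f1_score` and `classification_report`. To keep the example self-contained, the label vectors `y_true` and `y_hat` are hypothetical stand-ins rebuilt from the confusion matrix counts above; in the notebook itself one would presumably pass `y_test` and `y_pred` directly.

```python
from sklearn.metrics import classification_report, f1_score

# Assumption: rebuild label vectors that reproduce the confusion matrix
# [[59, 4], [7, 101]] (0 = malignant, 1 = benign) so the sketch runs alone.
y_true = [0] * 63 + [1] * 108
y_hat = [0] * 59 + [1] * 4 + [0] * 7 + [1] * 101

# F-score: harmonic mean of precision and recall for the chosen positive class.
f1_malignant = f1_score(y_true, y_hat, pos_label=0)
f1_benign = f1_score(y_true, y_hat)  # pos_label defaults to 1
print("F-score of malignant:%s" % f1_malignant)
print("F-score of benign:%s" % f1_benign)

# Classification report: precision, recall, F-score and support per class.
report = classification_report(y_true, y_hat, target_names=["malignant", "benign"])
print(report)
```

The report gathers the per-class precision, recall, and F-score in one table, so it is a convenient summary once the individual metrics are understood.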
In [5]:
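# Evaluate the model on the full dataset with 10-fold cross-validation (cv=10).
# cross_val_score returns an array with one score per fold; because
# scoring='f1', each entry is the F-score of one fold. score.mean() gives
# the average over the 10 folds.
# Note: cross_val_score is imported from sklearn.model_selection; the old
# sklearn.cross_validation module was deprecated in 0.18 and removed in 0.20.
from sklearn.model_selection import cross_val_score
score = cross_val_score(clf_tree, cancer.data, cancer.target, cv=10, scoring='f1')
print(score)
print(score.mean())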
[ 0.92957746  0.87671233  0.93150685  0.86956522  0.97142857  0.93150685
  0.91891892  0.95652174  0.93939394  0.92307692]
0.924820880153
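To show what the 10-fold evaluation does step by step, here is a sketch that writes the loop out explicitly. It assumes that, for a classifier and an integer `cv`, `cross_val_score` splits the data with an unshuffled `StratifiedKFold`, and that `scoring='f1'` is the binary F-score of class 1 (benign). Names such as `manual_scores` and `reference` are illustrative, and the exact numbers may differ from the output above depending on the scikit-learn version.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()
X, y = cancer.data, cancer.target
clf_tree = DecisionTreeClassifier(max_depth=None, min_samples_split=2, random_state=1)

# Assumption: same splitter that cross_val_score picks for cv=10 on a classifier.
folds = StratifiedKFold(n_splits=10)

manual_scores = []
for train_idx, test_idx in folds.split(X, y):
    model = clone(clf_tree)                # fresh, unfitted copy for every fold
    model.fit(X[train_idx], y[train_idx])  # train on the other 9 folds
    fold_pred = model.predict(X[test_idx]) # predict on the held-out fold
    manual_scores.append(f1_score(y[test_idx], fold_pred))

manual_scores = np.array(manual_scores)
reference = cross_val_score(clf_tree, X, y, cv=10, scoring='f1')

print(manual_scores)
print(manual_scores.mean())
# Check whether the manual loop agrees with cross_val_score.
print(np.allclose(manual_scores, reference))
```

Each sample is used for testing exactly once and for training nine times, which is why the mean of the ten scores is a more stable estimate than the single 7:3 split.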