基于make_moons数据集的Logistic回归分类器

  • 发布日期:2019-10-22
  • 难度:一般
  • 类别:分类与预测、Logistic回归
  • 标签:Python、scikit-learn、Logistic回归、make_moons

1. 问题描述

使用scikit-learn的LogisticRegression类(Logistic回归分类器)对make_moons数据集进行分类预测。

2. 程序实现

In [1]:
# Generate the make_moons dataset (100 samples, noise=0.3).
import sklearn.datasets
X, y = sklearn.datasets.make_moons(100, noise=0.3)
# Split into training and test sets (30% of the samples held out for testing).
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Build the logistic regression model.
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(random_state=3)
# Fit on the training split only. Fitting on the full (X, y) would let the
# model see the test samples, so test_score and the report below would no
# longer measure performance on held-out data.
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
# Evaluate the model on both splits.
from sklearn.metrics import classification_report
train_score = clf.score(X_train, y_train)
test_score = clf.score(X_test, y_test)
print("train_score:%s" % (train_score))
print("test_score:%s" % (test_score))
print(classification_report(y_test, y_pred))
train_score:0.871428571429
test_score:0.766666666667
             precision    recall  f1-score   support

          0       0.81      0.76      0.79        17
          1       0.71      0.77      0.74        13

avg / total       0.77      0.77      0.77        30

In [2]:
#接下来对LogisticRegression模型绘制可视化效果图

#定义效果图绘制函数
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
def plot_hyperplane(clf, X, y, h=0.02, title='hyperplan'):
    """Draw the decision regions of ``clf`` together with the samples.

    Uses pyplot's state-based interface, so the caller is responsible for
    creating the figure beforehand and calling ``plt.show()`` afterwards.

    Args:
        clf: Object exposing ``predict``; it is called on an (n, 2) array of
            mesh points.
        X: 2-D array of samples; columns 0 and 1 are used as plot axes.
        y: Per-sample values used to colour the scattered points.
        h: Step of the mesh grid along both axes.
        title: Title placed on the plot.
    """
    feature_0, feature_1 = X[:, 0], X[:, 1]

    # Build the mesh over the data range, padded by one unit on every side.
    grid_x, grid_y = np.meshgrid(
        np.arange(feature_0.min() - 1, feature_0.max() + 1, h),
        np.arange(feature_1.min() - 1, feature_1.max() + 1, h),
    )

    plt.title(title)
    plt.xlim(grid_x.min(), grid_x.max())
    plt.ylim(grid_y.min(), grid_y.max())
    # Hide tick marks on both axes.
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(())

    # Predict a class for every mesh point, then restore the grid shape.
    mesh_points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    regions = clf.predict(mesh_points).reshape(grid_x.shape)

    # Light colours for the predicted regions, dark ones for the samples.
    background = ListedColormap(['#FFAAAA', '#AFEEEE'])
    foreground = ListedColormap(['#FF0000', '#000080'])
    plt.pcolormesh(grid_x, grid_y, regions, cmap=background)
    plt.scatter(feature_0, feature_1, c=y, cmap=foreground)
# Render the classifier's decision regions and the samples on a new figure.
plt.figure()
plot_hyperplane(clf, X, y, title='logistics regression', h=0.01)
plt.show()