原始的朴素贝叶斯只能处理离散数据;当输入数据是连续变量时,可以使用高斯朴素贝叶斯,它假设每个特征在各类别下服从正态分布。多项式朴素贝叶斯则假设特征服从多项式分布,适用于词频等离散计数特征,经常用于文本分类等问题,在这类数据上通常比原始的朴素贝叶斯分类效果更好。本节使用高斯朴素贝叶斯方法,对make_moons数据进行分类预测。
# Gaussian naive Bayes on the make_moons toy dataset:
# generate data -> split -> fit -> report accuracy -> plot decision regions
# -> print a per-class classification report.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import sklearn.datasets
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Load the dataset and split it into training and test sets.
# Seed the generator as well as the split: without random_state, make_moons
# draws a different noisy sample on every run, so the scores, the plot and
# the report would change from run to run even though the split is seeded.
X, y = sklearn.datasets.make_moons(100, noise=0.3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Build the naive Bayes model (the Gaussian variant, suited to the
# continuous-valued moon coordinates).
clf = GaussianNB()
clf.fit(X_train, y_train)

# Accuracy on the training set, then on the held-out test set.
train_score = clf.score(X_train, y_train)
test_score = clf.score(X_test, y_test)
print(train_score)
print(test_score)

# 2-D visualisation: colour the plane by predicted class, overlay the samples.
cmap_light = ListedColormap(['#FFAAAA', '#AFEEEE'])  # background regions
cmap_dark = ListedColormap(['#FF0000', '#000080'])   # sample points
# Pad the data range by 1 on every side so the regions extend past the samples.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
grid_step = 0.02  # finer than 0.1 so the curved boundary is not visibly blocky
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, grid_step),
    np.arange(y_min, y_max, grid_step),
)
# Predict one class per grid point, then fold the flat result back to the grid.
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.figure()
# xx, yy and Z share one shape; shading='auto' centres a cell on each grid
# point instead of the legacy 'flat' mode, which drops Z's last row/column.
plt.pcolormesh(xx, yy, Z, cmap=cmap_light, shading='auto')
# Black edges keep the points visible on top of the light background.
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_dark, edgecolors='k')
plt.show()

# Per-class precision / recall / F1 on the test set.
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))