scikit-learn中的一个k_means聚类方法参数说明

  • 发布日期:2019-10-25
  • 难度:中等
  • 类别:聚类分析、k-means聚类方法应用案例
  • 标签:Python、sklearn.cluster.KMeans

1. 问题描述

如下程序是k-means算法基于Python第三方库sklearn的一个实例。应用的数据集为sklearn中的make_blobs的3000个数据点,最佳聚类数目设为3类。程序分别实现了将数据点聚为两类的情况、聚为三类的情况、类内数据的标准差较大和较小的情况,以及类团内部数据规模差异较大的情况。

2. 程序实现

In [4]:
#coding:utf-8
# k-means实验
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
#以下三行为图中标题中文显示问题的解决方案
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['FangSong'] 
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
plt.figure(figsize=(24, 24))
# 选取样本数量
n_samples = 3000
# 选取随机因子
random_state = 100
# 获取数据集
X, y = make_blobs(n_samples=n_samples, random_state=random_state,centers=3)
# 聚类数量不正确时的效果
y_pred = KMeans(n_clusters=2, random_state=random_state).fit_predict(X)
plt.subplot(221)
plt.scatter(X[y_pred==0][:, 0], X[y_pred==0][:, 1], marker='x',color='b')
plt.scatter(X[y_pred==1][:, 0], X[y_pred==1][:, 1], marker='+',color='r')
plt.title("错误的聚为两类的聚类图")
# 聚类数量正确时的效果
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X)
plt.subplot(222)
plt.scatter(X[y_pred==0][:, 0], X[y_pred==0][:, 1], marker='x',color='b')
plt.scatter(X[y_pred==1][:, 0], X[y_pred==1][:, 1], marker='+',color='r')
plt.scatter(X[y_pred==2][:, 0], X[y_pred==2][:, 1], marker='1',color='m')
plt.title("正确的聚为三类的聚类图")
# 类间的方差存在差异的效果
X_varied, y_varied = make_blobs(n_samples=n_samples,cluster_std=[1.0, 2.5, 0.5],
random_state=random_state)
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_varied)
plt.subplot(223)
plt.scatter(X_varied[y_pred==0][:,0],X_varied[y_pred==0][:, 1], marker='x',color='b')
plt.scatter(X_varied[y_pred==1][:,0],X_varied[y_pred==1][:, 1], marker='+',color='r')
plt.scatter(X_varied[y_pred==2][:,0],X_varied[y_pred==2][:,1], marker='1',color='m')
plt.title("各个类团标准差不均等的聚类图")
# 类的规模差异较大的效果
X_filtered = np.vstack((X[y == 0][:500], X[y == 1][:100], X[y == 2][:10]))
y_pred= KMeans(n_clusters=3, random_state=random_state).fit_predict(X_filtered)
plt.subplot(224)
plt.scatter(X_filtered[y_pred==0][:,0],X_filtered[y_pred==0][:,1], marker='x',color='b')
plt.scatter(X_filtered[y_pred==1][:,0],X_filtered[y_pred==1][:,1], marker='+',color='r')
plt.scatter(X_filtered[y_pred==2][:,0],X_filtered[y_pred==2][:,1], marker='1',color='m')
plt.title("各个类团规模不等的聚类图")
plt.show()