# 通过自编 Python 程序实现 Apriori 算法，进行频繁项集挖掘。购物篮数据来自 IBM SPSS Modeler 软件自带的 BASKETS1n 数据集。
def create_C1(data_set):
    """Build the candidate 1-itemsets.

    Every distinct item that appears in any transaction becomes a
    single-element frozenset; frozensets are immutable, so they can be
    stored in sets and used as dict keys later on.

    Args:
        data_set: iterable of transactions, each an iterable of items.

    Returns:
        set of single-element frozensets.
    """
    return {frozenset([item]) for transaction in data_set for item in transaction}
def create_Ck(Lksub1, k):  # Lksub1 is L(k-1), the frequent (k-1)-itemsets
    """Generate the candidate k-itemsets from the frequent (k-1)-itemsets.

    Uses the F(k-1) x F(k-1) method. Two (k-1)-itemsets are joined when
    their first k-2 items, in sorted order, are equal. The union is kept
    only if every one of its (k-1)-subsets is also in ``Lksub1``
    (Apriori pruning).

    Args:
        Lksub1: collection of frozensets, the frequent (k-1)-itemsets.
        k: size of the candidate itemsets to build (k >= 2).

    Returns:
        set of frozensets, the candidate k-itemsets.
    """
    Ck = set()
    list_Lksub1 = list(Lksub1)
    # Sort each itemset once up front instead of re-sorting both itemsets
    # for every pair compared.
    sorted_Lksub1 = [sorted(itemset) for itemset in list_Lksub1]
    n = len(list_Lksub1)
    for i in range(n):
        prefix_i = sorted_Lksub1[i][:k - 2]
        # Only look at j > i so each unordered pair is joined once.
        for j in range(i + 1, n):
            if prefix_i != sorted_Lksub1[j][:k - 2]:
                continue
            Ck_item = list_Lksub1[i] | list_Lksub1[j]
            # Subset test: drop the candidate if any (k-1)-subset is not frequent.
            if all(Ck_item - frozenset([item]) in Lksub1 for item in Ck_item):
                Ck.add(Ck_item)
    return Ck
def generate_Lk_by_Ck(data_set, Ck, min_support, support_data):
    """Select the frequent k-itemsets from the candidate set ``Ck``.

    Counts how many transactions contain each candidate. A candidate is
    kept when its support (count / number of transactions) is at least
    ``min_support``, and that support is recorded in ``support_data``.

    Args:
        data_set: list of transactions, each an iterable of hashable items.
        Ck: collection of frozensets, the candidate k-itemsets; it is
            re-iterated once per transaction.
        min_support: minimum support threshold, as a fraction of transactions.
        support_data: dict updated in place with ``{itemset: support}`` for
            every frequent itemset found.

    Returns:
        set of frozensets, the frequent k-itemsets.
    """
    Lk = set()
    item_count = {}
    tran_num = float(len(data_set))
    # Count how many transactions contain each candidate k-itemset.
    for t in data_set:
        # Build the transaction's set once. Calling issubset() on the raw list
        # would rebuild a temporary set for every single candidate.
        t_set = frozenset(t)
        for item in Ck:
            if item <= t_set:
                item_count[item] = item_count.get(item, 0) + 1
    # Keep the candidates that meet the minimum support: these are the frequent k-itemsets.
    # No division happens for empty data_set because item_count is then empty.
    for item, count in item_count.items():
        support = count / tran_num
        if support >= min_support:
            Lk.add(item)
            support_data[item] = support
    return Lk
def generate_L(data_set, min_support):
    """Run Apriori level by level and collect every frequent itemset.

    Starts from the 1-itemsets. Each level's frequent itemsets are used to
    build the next level's candidates, until a level yields no frequent
    itemset.

    Args:
        data_set: list of transactions, each an iterable of items.
        min_support: minimum support threshold, as a fraction of transactions.

    Returns:
        (L, support_data):
            - ``L[k-1]`` is the set of frequent k-itemsets. ``L[0]`` is
              always present, possibly empty.
            - ``support_data`` maps each frequent itemset to its support.
    """
    support_data = {}
    previous = generate_Lk_by_Ck(data_set, create_C1(data_set), min_support, support_data)
    levels = [previous]
    k = 2
    # Grow itemsets one item at a time. An empty candidate set always gives
    # an empty frequent set, so checking the frequent set alone is enough to stop.
    while True:
        frequent = generate_Lk_by_Ck(
            data_set, create_Ck(previous, k), min_support, support_data
        )
        if not frequent:
            break
        levels.append(frequent)
        previous = frequent
        k += 1
    return levels, support_data
import xlrd #excel文件读取函数
# Data set loader
def load_data_set(path='Basket1n.xls', item_cols=range(7, 18), true_value='T'):
    """Read basket transactions from the first sheet of an .xls workbook.

    Row 0 is the header. The header text of each item column is used as
    the item name. Every following row becomes one transaction: the names
    of the item columns whose cell value equals ``true_value``.

    Args:
        path: workbook file opened with ``xlrd.open_workbook``.
        item_cols: 0-based column indices holding the item flags. The
            default, columns 7-17, matches the original hard-coded layout.
        true_value: cell value that marks an item as present in the basket.

    Returns:
        list of transactions, each a list of item names (one per data row).
    """
    data_set = []
    data = xlrd.open_workbook(path)
    table = data.sheets()[0]  # first worksheet
    # Read the header once instead of once per cell.
    header = table.row_values(0)
    item_cols = list(item_cols)
    # Rows 1..nrows-1 are data rows, the same rows the original walked via col 0.
    for i in range(1, table.nrows):
        row = table.row_values(i)
        data_set.append([header[j] for j in item_cols if row[j] == true_value])
    return data_set
# Load the transaction data set from the workbook.
data_set = load_data_set()
# Mine the frequent itemsets from the transactions at the given minimum support.
L, support_data = generate_L(data_set, min_support=0.1)
# Print the mining result: each frequent itemset followed by its support.
print('频繁项集' + '\t' + '支持度')
for itemset, support in support_data.items():
    print(str(itemset) + ':' + str(support))
# The original author reports 17 frequent itemsets for this data set.