Apriori Algorithm

戴高远
2023-12-01

该算法输出频繁项集,参考自:
https://github.com/asaini/Apriori
加入了剪枝步

1. 算法实现步骤

  • 1) 读取文件,生成事务数据库transactions
  • 2) 生成一项频繁集L1_Sup

    提取item:从事务数据库中提取item.
    计数:计算每个item在数据库的事务中出现的次数(count).
    生成:判断count是否不小于最小支持计数(支持度 × 事务总数),是则把item及count放入L1_Sup中.

  • 3) 生成LkPlus_Sup,即L(k+1)

    连接步:由Lk生成CkPlus.
    剪枝步:由CkPlus到CkPlusLast.
    计数:对CkPlusLast的各itemset计算在数据库中的事务里的出现次数.
    生成:判断count是否不小于最小支持计数(支持度 × 事务总数),是则把itemset及count加入LkPlus_Sup.

  • 4) 把L1_Sup,…,Lk_Sup依次加入频繁项集的集合Lnlist

  • 5) 输出Lnlist

2. 数据结构

  • 事务.
    id = frozenset({'1', '2', ....})
  • 事务数据库.
    transactions = set()
    transactions = {frozenset({'1','2'}), frozenset({'2','3'}), id, .... }
    (注意:set会自动去重,内容完全相同的事务只保留一份.)
  • k项集
    itemset = frozenset({'1', '5', '3', '2'})
  • 频繁k项集:支持计数不小于最小支持计数的k项集,与其支持计数一起保存
    Lk_Sup = defaultdict(int)
    Lk_Sup = {itemset1:count1, itemset2: count2, ....}
  • 频繁项集的集合Lnlist
    Lnlist = []
    Lnlist = [L1_Sup, L2_Sup, L3_Sup]

3. 代码实现

# -*- coding: utf-8 -*-
# python3
# Name: Apriori Algorithm
# Author: Suc Liu
# Address: USTC
# Date: 24/11/2017
'''
usage:
#AprioriData.txt数据集文件名,id各项以空格隔开,换行符结尾
Lnlist = AprioriGo('AprioriData.txt', 0.15) 
PrintItemFreq(Lnlist)

'''
from collections import defaultdict
def PrintSet(name, set):
    """Print the elements of a collection between start/finish marker lines.

    Args:
        name: Label printed on the opening and closing marker lines.
        set: Iterable whose elements are printed one per line.  The
            parameter name shadows the builtin inside this function; it is
            kept so that keyword callers continue to work.
    """
    print(name, 'start')
    # print() would call str() on each element anyway; doing it through
    # map() keeps the iteration lazy and the output identical.
    for rendered in map(str, set):
        print(rendered)
    print(name, 'finish')
def PrintDic(name, dic):
    """Print every key/value pair of a mapping between marker lines.

    Args:
        name: Label printed on the opening and closing marker lines.
        dic: Mapping whose ``items()`` are printed one pair per line.
    """
    row_template = 'key:%s value:%s'
    print(name, 'start')
    for k, v in dic.items():
        print(row_template % (k, v))
    print(name, 'finish')

def PrintList(name, lis):
    """Print each itemset of a support table next to its count.

    Args:
        name: Unused; accepted so that existing call sites keep working.
        lis: Despite the name, a mapping: it is iterated for its keys
            (itemsets, converted with ``set`` for display) and indexed by
            those keys to obtain the count.
    """
    row_template = 'item:%s             count:%s'
    for itemset in lis:
        members = set(itemset)  # shown as a plain set rather than frozenset(...)
        print(row_template % (members, lis[itemset]))

def PrintItemFreq(Lnlist):
    """Print the mining result level by level (1-itemsets, 2-itemsets, ...).

    Args:
        Lnlist: Sequence of support tables; the table at position ``i``
            is printed under the heading ``"<i+1>-itemset"`` and handed to
            ``PrintList``.
    """
    for level, support_table in enumerate(Lnlist, start=1):
        print('%s-itemset' % level)
        PrintList('item', support_table)

def GetTransactions(filename):
    """Load the transaction database from a whitespace-separated text file.

    Each line of the file is one transaction; its items are the
    whitespace-separated tokens on that line, kept as strings.

    Args:
        filename: Path of the data file to read.

    Returns:
        A set of frozensets, one frozenset of item strings per distinct
        line.  Because the container is a set, lines holding the same
        items collapse into a single transaction, and a blank line
        contributes the empty frozenset.
    """
    with open(filename) as data_file:
        # str.split() without arguments already ignores leading/trailing
        # whitespace (including the newline) and treats any run of
        # whitespace as a single separator.  frozenset is used because
        # only hashable objects can be members of the outer set.
        return {frozenset(row.split()) for row in data_file}

def L1Gen(transactions, support):
    """Build the frequent 1-itemsets together with their support counts.

    Args:
        transactions: Sized iterable of transactions, each an iterable of
            hashable items.
        support: Minimum support expressed as a fraction of
            ``len(transactions)`` (e.g. ``0.15``).

    Returns:
        A dict mapping ``frozenset([item])`` to the number of transactions
        that contain ``item``, restricted to items whose count is at least
        ``support * len(transactions)``.
    """
    # Single pass over the database instead of rescanning every
    # transaction once per distinct item.
    occurrences = defaultdict(int)
    for transaction in transactions:
        # set() guarantees an item is counted once per transaction even
        # if a caller passes a transaction type that allows repeats.
        for item in set(transaction):
            occurrences[frozenset([item])] += 1

    min_count = support * len(transactions)
    return {
        itemset: count
        for itemset, count in occurrences.items()
        if count >= min_count
    }

def LkPlusGen(Lk_Sup, transactions, support):
    """Derive the frequent (k+1)-itemsets from the frequent k-itemsets.

    Args:
        Lk_Sup: Mapping whose keys are the frequent k-itemsets
            (frozensets); the values are not used here.
        transactions: Sized iterable of transactions, each an iterable of
            items.
        support: Minimum support expressed as a fraction of
            ``len(transactions)``.

    Returns:
        A ``defaultdict(int)`` mapping every frequent (k+1)-itemset to the
        number of transactions containing it.  It is empty when no
        candidate reaches the threshold.
    """
    # Built once so that every membership test in the prune step is O(1).
    frequent_k = set(Lk_Sup)

    # Join step: the union of two k-itemsets that differ in exactly one
    # item is a (k+1)-candidate.
    candidates = {
        left | right
        for left in frequent_k
        for right in frequent_k
        if len(left | right) == len(left) + 1
    }

    # Prune step: keep a candidate only if each of its k-subsets (the
    # candidate with one item removed) is itself a frequent k-itemset.
    survivors = {
        candidate
        for candidate in candidates
        if all(candidate - {item} in frequent_k for item in candidate)
    }

    # Count the transactions that contain each surviving candidate.
    occurrences = defaultdict(int)
    for candidate in survivors:
        for transaction in transactions:
            if candidate.issubset(transaction):
                occurrences[candidate] += 1

    # Keep the candidates whose count reaches the minimum support count.
    min_count = support * len(transactions)
    frequent_next = defaultdict(int)
    for candidate, count in occurrences.items():
        if count >= min_count:
            frequent_next[candidate] = count
    return frequent_next

def AprioriGo(datafile, support):
    """Run Apriori on a transaction file and collect every frequent level.

    Args:
        datafile: Path passed to ``GetTransactions``.
        support: Minimum support passed to ``L1Gen`` and ``LkPlusGen``.
            Those functions compare counts against
            ``support * len(transactions)``, so this is a fraction such
            as ``0.15`` rather than a percentage.

    Returns:
        A list ``[L1_Sup, L2_Sup, ...]`` of the non-empty support tables
        in increasing itemset size; empty when there is no frequent
        1-itemset.
    """
    transactions = GetTransactions(datafile)
    levels = []
    current_level = L1Gen(transactions, support)
    # An empty table means no larger frequent itemsets can exist, so the
    # level-wise search stops there.
    while current_level:
        levels.append(current_level)
        current_level = LkPlusGen(current_level, transactions, support)
    return levels

# Script entry: mine 'AprioriData.txt' with a minimum support of 0.15 and
# print the frequent itemsets grouped by size.
Lnlist = AprioriGo(datafile='AprioriData.txt', support=0.15)
PrintItemFreq(Lnlist)
 类似资料:

相关阅读

相关文章

相关问答