-
Notifications
You must be signed in to change notification settings - Fork 0
/
FP-growth.py
182 lines (152 loc) · 4.72 KB
/
FP-growth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# -*- coding: UTF-8 -*-
import time
from functools import wraps
def fn_timer(function):
    """Decorate *function* so that each call prints its wall-clock run time.

    The wrapped function's arguments and return value are passed through
    unchanged; ``functools.wraps`` keeps its name and docstring.
    """
    @wraps(function)
    def function_timer(*args, **kwargs):
        t0 = time.time()
        value = function(*args, **kwargs)
        t1 = time.time()
        # ``__name__`` works on Python 2 and 3; the old ``func_name`` attribute
        # is Python-2-only and raises AttributeError on Python 3.
        print("Total time running %s: %s seconds" %
              (function.__name__, str(t1 - t0)))
        return value
    return function_timer
class Node:
    """One node of the FP-tree.

    Attributes:
        name: the item stored in this node (the tree root built in this
            file uses the name 'Null').
        num: the count accumulated on this node.
        parent: the parent Node, or None for the root.
        children: dict mapping an item to the child Node holding it.
    """

    def __init__(self, name, num, parent):
        self.name, self.num, self.parent = name, num, parent
        self.children = {}  # item -> child Node

    def add(self, num):
        """Increase this node's count by *num*."""
        self.num += num
@fn_timer
def getFormatData(path="retail.dat"):
    """Read a whitespace-separated transaction file into lists of ints.

    Args:
        path: file to read; defaults to "retail.dat" as before.

    Returns:
        A list with one entry per line of the file, each entry being the
        line's tokens converted with ``int`` (a blank line yields ``[]``).
    """
    # ``with`` guarantees the file is closed; the original leaked the handle.
    with open(path, 'r') as data:
        return [[int(tok) for tok in line.split()] for line in data]
def getFirstData(formatData, min_support=None):
    """Prepare raw transactions for building the first FP-tree.

    Args:
        formatData: sequence of transactions, each a list of items. It is
            iterated twice, so it must not be a one-shot iterator.
        min_support: minimum number of transactions an item must occur in
            to be kept; defaults to the module-level ``minSupport``.

    Returns:
        (newData, dic, headDict) where
          newData  maps each distinct filtered transaction (a frozenset of
                   frequent items) to how many transactions reduce to it;
          dic      maps each frequent item to its support count;
          headDict maps each frequent item to an empty list (filled with
                   tree nodes later).
    """
    if min_support is None:
        min_support = minSupport
    # Support is the number of transactions containing an item, so an item
    # repeated inside one transaction must only be counted once.
    numDict = {}
    for line in formatData:
        for item in set(line):
            numDict[item] = numDict.get(item, 0) + 1
    dic = {item: count for item, count in numDict.items()
           if count >= min_support}
    headDict = {item: [] for item in dic}
    newData = {}
    for line in formatData:
        newLine = frozenset(item for item in line if item in dic)
        if newLine:
            newData[newLine] = newData.get(newLine, 0) + 1
    return newData, dic, headDict
def getNewData(formatData, min_support=None):
    """Filter a conditional pattern base down to its frequent items.

    Args:
        formatData: dict mapping a prefix path (frozenset of items) to the
            count carried by that path.
        min_support: minimum summed count an item needs to be kept; defaults
            to the module-level ``minSupport``.

    Returns:
        (newData, dic, headDict) where
          newData  maps each filtered path (frozenset of frequent items) to
                   the summed count of all input paths that reduce to it;
          dic      maps each frequent item to its summed count;
          headDict maps each frequent item to an empty list.
    """
    if min_support is None:
        min_support = minSupport
    numDict = {}
    for line, num in formatData.items():
        for item in line:
            numDict[item] = numDict.get(item, 0) + num
    dic = {item: count for item, count in numDict.items()
           if count >= min_support}
    headDict = {item: [] for item in dic}
    newData = {}
    for line, num in formatData.items():
        # Build the filtered key once; distinct input paths may collapse onto
        # the same filtered path, so counts are summed.
        newLine = frozenset(item for item in line if item in dic)
        if newLine:
            newData[newLine] = newData.get(newLine, 0) + num
    return newData, dic, headDict
def findFreq(headDict, oldNumDict, lastRes):
    """Recursively mine frequent itemsets from one FP-tree.

    Each discovered itemset is stored in the module-level ``result`` dict as
    ``frozenset(itemset) -> support``; nothing is returned.

    Args:
        headDict: item -> list of tree nodes holding that item.
        oldNumDict: item -> support count within this tree.
        lastRes: list of items this tree is conditioned on (the suffix).
    """
    ordered = sorted(oldNumDict.items(),
                     key=lambda pair: (pair[1], pair[0]), reverse=True)
    for item, _count in ordered:
        nodes = headDict[item]
        itemset = lastRes + [item]
        result[frozenset(itemset)] = oldNumDict[item]
        # Conditional pattern base: prefix paths leading to this item.
        patternBase = creatFormatData(nodes)
        if not patternBase:
            continue
        condData, condCounts, condHead = getNewData(patternBase)
        if not condData:
            continue
        _, condHead = creatTree(condData, condHead, condCounts)
        findFreq(condHead, condCounts, itemset)
def creatFormatData(nodeList):
    """Build the conditional pattern base for one item.

    Args:
        nodeList: tree nodes that all hold the same item. Each node must
            expose ``num``, ``parent`` and ``name``, and the chain of parents
            must end at a node named 'Null'.

    Returns:
        dict mapping each non-empty prefix path (frozenset of the ancestors'
        names, excluding the 'Null' root) to the summed ``num`` of the nodes
        sharing that path.
    """
    newData = {}
    for node in nodeList:
        count = node.num
        prefix = []
        ancestor = node.parent
        while ancestor.name != 'Null':
            prefix.append(ancestor.name)
            ancestor = ancestor.parent
        if prefix:
            key = frozenset(prefix)
            # Accumulate instead of overwrite, so the result stays correct
            # even if two nodes share the same set of ancestors.
            newData[key] = newData.get(key, 0) + count
    return newData
def creatTree(newData, headDict, numDict):
    """Build an FP-tree from weighted transactions.

    Args:
        newData: dict mapping a transaction (iterable of items) to its count.
        headDict: item -> list; passed to insertNode, which appends each
            newly created node for an item to that item's list.
        numDict: item -> support count, used to order every transaction.

    Returns:
        (root, headDict): the root Node (named 'Null', count 1) and the same
        headDict object that was passed in.
    """
    # To profile this step, decorate it with @fn_timer.
    root = Node('Null', 1, None)
    for itemset, weight in newData.items():
        # Highest count first; ties are broken by the item itself, so all
        # transactions share one consistent ordering.
        path = sorted(itemset, key=lambda it: (numDict[it], it), reverse=True)
        insertNode(path, root, headDict, weight)
    return root, headDict
def insertNode(line, parentNode, headDict, num):
    """Insert one ordered, weighted transaction below *parentNode*.

    For each item of *line* in order, the walk descends to the child holding
    that item. If the child exists, its count is increased by *num*.
    Otherwise a new Node is created with count *num*, attached to the
    current node, and appended to ``headDict[item]``. An empty *line* leaves
    the tree unchanged.

    Iterative rather than recursive: the recursive form copied ``line[1:]``
    at every level (quadratic in transaction length) and hit the interpreter
    recursion limit on long transactions.
    """
    node = parentNode
    for item in line:
        if item in node.children:
            child = node.children[item]
            child.add(num)
        else:
            child = Node(item, num, node)
            node.children[item] = child
            headDict[item].append(child)
        node = child
# Module-level accumulator: findFreq records every frequent itemset here as
# {frozenset(items): support count}.
result = dict()
# The dataset contains 88162 records (transactions) in total.
# Minimum support, expressed as an absolute count as the assignment requires.
# Other thresholds tried: 4410, 881, 441 (roughly 5%, 1% and 0.5% of 88162).
minSupport = 881
@fn_timer
def main():
    """Mine frequent itemsets from the data file and print counts by size.

    For each itemset size k, one line is printed: "<k>项集个数:<count>"
    ("number of k-itemsets"). It is followed by a separator line. The run
    ends with "总共<total>个" ("<total> in total").
    """
    formatData = getFormatData()
    newData, numDict, headDict = getFirstData(formatData)
    _, headDict = creatTree(newData, headDict, numDict)
    findFreq(headDict, numDict, [])

    # Group the itemsets collected in the module-level ``result`` by size.
    formatRes = {}
    for itemset, support in result.items():
        formatRes.setdefault(len(itemset), {})[itemset] = support

    total = 0  # avoid shadowing the builtin ``sum``
    # Sort by size so the report order does not depend on dict ordering.
    for length, itemsets in sorted(formatRes.items()):
        print('%d项集个数:%d' % (length, len(itemsets)))
        print('--------------------')
        total += len(itemsets)
    print('总共%d个' % total)


if __name__ == '__main__':
    main()