-
Notifications
You must be signed in to change notification settings - Fork 0
/
partition.py
97 lines (85 loc) · 3.57 KB
/
partition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from best_split import *
from node import Node
"""
Partitions the data set until the stopping criteria are met. Assumes user passed cp = 0.0 (rp.alpha = 0 in rpart)
TODO: add more docs
"""
def partition(node, nodeNum, params):
    """Recursively split ``node`` until a stopping criterion is met.

    Python port of the recursion in rpart's ``partition.c``, written for a
    complexity threshold of zero (``rp.alpha = 0``).

    Args:
        node: Node to grow. ``node.data`` and ``node.numObs`` must already be
            set. For ``nodeNum > 1`` ``node.cp`` must hold the cp estimate
            handed down by the parent; for ``nodeNum == 1`` it is initialised
            here from the node's own deviance.
        nodeNum: Number of this node. The top-level call uses 1 and the
            children of node ``k`` are numbered ``2 * k`` and ``2 * k + 1``.
        params: Fit settings. ``response``, ``maxNodes`` and ``minObs`` are
            read; ``where`` is updated in place so that every observation
            index maps to the number of the child node it was sent to.

    Returns:
        A ``(splits, risk)`` tuple. For a leaf this is ``(0, node.dev)``.
        For an internal node it is the number of splits and the summed risk
        of the subtree, where a child whose own cp is smaller than this
        node's cp estimate is counted as unsplit (0 splits, risk equal to
        the child's own deviance).
    """
    response = params.response
    y = getResponseColumn(node.data, response)
    # Presumably the sum of squares of the response around its mean --
    # TODO confirm against AnovaSS in best_split.
    node.dev = AnovaSS(y)

    # Working cp estimate: the node's deviance, capped by the value handed
    # down from the parent. The top node has no parent, so it caps itself.
    tempcp = node.dev
    if nodeNum == 1:
        node.cp = node.dev
    elif nodeNum > 1 and tempcp > node.cp:
        tempcp = node.cp

    # Stopping criteria: node budget exhausted, too few observations, or
    # nothing left to gain (threshold fixed at 0, see docstring).
    if nodeNum > params.maxNodes or node.numObs < params.minObs or tempcp <= 0:
        return _make_leaf(node, y)

    left, right = bsplit(node, response, params)
    if left is None or right is None:
        # No split worth doing. Seen on the spam data set, where all
        # explanatory values are identical but the responses differ.
        return _make_leaf(node, y)

    leftId = 2 * nodeNum
    rightId = leftId + 1

    # Record which child every observation of this node was sent to.
    for index in left.index:
        params.where[index] = leftId
    for index in right.index:
        params.where[index] = rightId

    # Grow the left subtree first; its result refines the cp estimate that
    # is handed to the right subtree.
    node.leftNode = _new_child(left, response, leftId, tempcp)
    leftSplit, leftRisk = partition(node.leftNode, leftId, params)

    tempcp = max(
        (node.dev - leftRisk) / (leftSplit + 1),
        node.dev - node.leftNode.dev,
    )
    tempcp = min(tempcp, node.cp)

    node.rightNode = _new_child(right, response, rightId, tempcp)
    rightSplit, rightRisk = partition(node.rightNode, rightId, params)

    # Actual cp of this node, computed as if it were the top node of the
    # tree (to be fixed up later by the caller).
    leftCp = node.leftNode.cp
    rightCp = node.rightNode.cp
    tempcp = _subtree_cp(node.dev, leftRisk + rightRisk, leftSplit + rightSplit)
    if rightCp > leftCp:
        if tempcp > leftCp:
            # The left child collapses before this node: count it as a leaf.
            leftSplit, leftRisk = 0, node.leftNode.dev
            tempcp = _subtree_cp(
                node.dev, leftRisk + rightRisk, leftSplit + rightSplit
            )
            if tempcp > rightCp:
                # ... and so does the right child.
                rightSplit, rightRisk = 0, node.rightNode.dev
    elif tempcp > rightCp:
        # The right child collapses before this node: count it as a leaf.
        rightSplit, rightRisk = 0, node.rightNode.dev
        tempcp = _subtree_cp(
            node.dev, leftRisk + rightRisk, leftSplit + rightSplit
        )
        if tempcp > leftCp:
            # ... and so does the left child.
            leftSplit, leftRisk = 0, node.leftNode.dev

    node.cp = _subtree_cp(node.dev, leftRisk + rightRisk, leftSplit + rightSplit)
    return leftSplit + rightSplit + 1, leftRisk + rightRisk


def _make_leaf(node, y):
    """Turn ``node`` into a terminal node and return its ``(splits, risk)``.

    The node keeps the deviance already stored in ``node.dev`` and reports
    it as its risk; it is deliberately not replaced by the capped cp
    estimate, matching rpart, which reports the node's own risk for a leaf.
    """
    node.leftNode = None
    node.rightNode = None
    # Fitted value of the leaf: mean of the first column of the response.
    node.yval = y.iloc[:, 0].mean()
    node.cp = 0
    return 0, node.dev


def _new_child(data, response, nodeId, cp):
    """Create a child node holding ``data`` with the inherited cp estimate."""
    child = Node()
    child.numObs = len(data)
    child.data = data
    child.response = response
    child.nodeId = nodeId
    child.cp = cp
    return child


def _subtree_cp(dev, risk, splits):
    """Return the drop in risk per split for a subtree with ``splits`` child splits."""
    return (dev - risk) / (splits + 1)