-
Notifications
You must be signed in to change notification settings - Fork 0
/
Decision_Tree.py
138 lines (108 loc) · 4.4 KB
/
Decision_Tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import numpy as np
from collections import Counter
def entropy(y):
    """Return the Shannon entropy, in bits, of the labels in ``y``.

    Args:
        y: 1-D array-like of non-negative integer class labels (they are
            counted with ``np.bincount``).

    Returns:
        ``-sum(p * log2(p))`` over the labels that actually occur, where
        ``p`` is each label's share of ``len(y)``.
    """
    proportions = np.bincount(y) / len(y)
    # bincount also reports labels that never occur; drop those zero
    # entries so log2 is only evaluated on positive probabilities.
    proportions = proportions[proportions > 0]
    return -np.sum(proportions * np.log2(proportions))
class Node:
    """One vertex of a binary tree.

    A node stores a split description (``feature``, ``threshold``), two
    child references (``left``, ``right``) and an optional ``value``. A node
    whose ``value`` is set is treated as a leaf.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
        # Split description.
        self.feature, self.threshold = feature, threshold
        # Child subtrees.
        self.left, self.right = left, right
        # Stored value; keyword-only so it cannot be supplied positionally.
        self.value = value

    def is_leaf_node(self):
        """Return True when ``value`` is anything other than None."""
        has_value = self.value is not None
        return has_value
class DecisionTree:
    """Binary decision-tree classifier grown greedily by information gain.

    Args:
        min_samples_split: A node holding fewer samples than this becomes a
            leaf.
        max_depth: A node at this depth (the root is depth 0) becomes a leaf.
        n_feats: How many features are sampled, without replacement, as split
            candidates at each node. A falsy value means "use every feature".
    """

    def __init__(self, min_samples_split=2, max_depth=100, n_feats=None):
        super().__init__()
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.n_feats = n_feats
        self.root = None

    def fit(self, X, y):
        """Grow the tree from training data and store it in ``self.root``.

        Args:
            X: 2-D array-like, one row per sample and one column per feature.
            y: 1-D array-like of class labels, one per row of ``X``. Labels
                are handed to the module-level ``entropy`` function.
        """
        # Accept plain sequences too: the code below needs ``.shape`` and
        # index-array selection.
        X = np.asarray(X)
        y = np.asarray(y)
        # NOTE: this replaces the constructor argument with the resolved
        # value, capped at the number of columns of X.
        self.n_feats = X.shape[1] if not self.n_feats else min(self.n_feats, X.shape[1])
        self.root = self._grow_tree(X, y)

    def predict(self, X):
        """Return an array holding the predicted label for each row of ``X``."""
        return np.array([self._traverse_tree(x, self.root) for x in X])

    def _grow_tree(self, X, y, depth=0):
        """Recursively build and return the subtree for the samples ``(X, y)``."""
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))

        # Stopping criteria: depth limit, pure node, or too few samples.
        if (depth >= self.max_depth
                or n_labels == 1
                or n_samples < self.min_samples_split):
            return Node(value=self._most_common_label(y))

        feat_idxs = np.random.choice(n_features, self.n_feats, replace=False)

        # Greedy search for the split with the highest information gain.
        best_feat, best_thresh = self._best_criteria(X, y, feat_idxs)
        if best_feat is None:
            # None of the sampled features can separate these samples (for
            # example, every sampled column is constant): stop here.
            return Node(value=self._most_common_label(y))

        left_idxs, right_idxs = self._split(X[:, best_feat], best_thresh)
        if len(left_idxs) == 0 or len(right_idxs) == 0:
            # Defensive guard: recursing into an empty child would end in
            # _most_common_label being called on an empty label array.
            return Node(value=self._most_common_label(y))

        left = self._grow_tree(X[left_idxs, :], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs, :], y[right_idxs], depth + 1)
        return Node(best_feat, best_thresh, left, right)

    def _best_criteria(self, X, y, feat_idxs):
        """Return ``(feature_index, threshold)`` of the best split.

        Returns ``(None, None)`` when no candidate threshold exists for any
        of the features in ``feat_idxs``.
        """
        best_gain = -1
        split_idx, split_thresh = None, None
        for feat_idx in feat_idxs:
            X_column = X[:, feat_idx]
            # np.unique returns sorted values. Thresholding at the largest
            # one sends nothing to the right (``>``) side, so it can never
            # be a usable split; leave it out.
            thresholds = np.unique(X_column)[:-1]
            for threshold in thresholds:
                gain = self._information_gain(y, X_column, threshold)
                # Strict comparison: the first candidate reaching the best
                # gain wins ties.
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feat_idx
                    split_thresh = threshold
        return split_idx, split_thresh

    def _information_gain(self, y, X_column, split_thresh):
        """Return the entropy reduction from splitting ``y`` at ``split_thresh``.

        Returns 0 when the split would leave one side empty.
        """
        parent_entropy = entropy(y)

        left_idx, right_idx = self._split(X_column, split_thresh)
        if len(left_idx) == 0 or len(right_idx) == 0:
            return 0

        # Entropy of the children, weighted by their share of the samples.
        n = len(y)
        n_l, n_r = len(left_idx), len(right_idx)
        e_l, e_r = entropy(y[left_idx]), entropy(y[right_idx])
        child_entropy = (n_l / n) * e_l + (n_r / n) * e_r

        return parent_entropy - child_entropy

    def _split(self, X_column, split_thresh):
        """Return the index arrays of entries ``<=`` and ``>`` the threshold."""
        left_idxs = np.flatnonzero(X_column <= split_thresh)
        right_idxs = np.flatnonzero(X_column > split_thresh)
        return left_idxs, right_idxs

    def _traverse_tree(self, x, node):
        """Follow ``x`` down from ``node`` and return the value of the leaf reached."""
        if node.is_leaf_node():
            return node.value
        if x[node.feature] <= node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)

    def _most_common_label(self, y):
        """Return the most frequent element of ``y`` (``y`` must be non-empty)."""
        return Counter(y).most_common(1)[0][0]
if __name__ == "__main__":
    # Demo: train this DecisionTree and scikit-learn's DecisionTreeClassifier
    # on the breast-cancer dataset and print both test accuracies.
    # scikit-learn is imported here so that merely importing this module
    # does not require it.
    from sklearn import datasets
    from sklearn import tree
    from sklearn.model_selection import train_test_split

    def accuracy(y_true, y_pred):
        """Return the fraction of predictions that equal the true labels."""
        return np.sum(y_true == y_pred) / len(y_true)

    data = datasets.load_breast_cancer()
    X, y = data.data, data.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=1234
    )

    # This module's implementation.
    clf1 = DecisionTree(max_depth=10)
    clf1.fit(X_train, y_train)
    y_pred1 = clf1.predict(X_test)
    MYacc = accuracy(y_test, y_pred1)

    # sklearn tree implementation
    clf2 = tree.DecisionTreeClassifier()
    clf2.fit(X_train, y_train)
    y_pred2 = clf2.predict(X_test)
    SKacc = accuracy(y_true=y_test, y_pred=y_pred2)

    print('Our model Accuracy', MYacc, 'SK model accuracy', SKacc)