-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhw1.cs4641_decisionTrees_Restaurants.py
105 lines (84 loc) · 3.82 KB
/
hw1.cs4641_decisionTrees_Restaurants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#Sheena Ganju, CS 4641 HW 1
#Decision Trees
#code uses scikit-learn decision trees, code from cstrelioff on github and
#http://scikit-learn.org/stable/modules/tree.html, skeleton code from
#http://dataaspirant.com/2017/02/01/decision-tree-algorithm-python-with-scikit-learn/
#import statements
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn import preprocessing
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection import learning_curve
from sklearn.model_selection import train_test_split
import scikitplot as skplt
import matplotlib.pyplot as plt
import csv
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import validation_curve
from datetime import date
# Read the data with pandas. header=None means no line of the file is treated
# as column names, so the first line of the CSV is loaded as an ordinary row.
trainDataSet = pd.read_csv("geoplaces2.csv", sep=',', header=None, low_memory=False)

# Encode text data as indicator (one-hot) columns.
traindata = pd.get_dummies(trainDataSet)

# Skip the first encoded row and use the first encoded column as the target;
# every remaining column is a feature.
# NOTE(review): the previous comment here mentioned major_category/month/year
# predicting "violent or not", which does not match geoplaces2.csv -- confirm
# which column is actually meant to be the label. After get_dummies, column 0
# is presumably a single indicator column rather than a raw attribute.
X = traindata.values[1:, 1:]
Y = traindata.values[1:, 0]

# Start the timer. This must use the same clock as the final
# "time.time() - t0" report; time.clock() was a different clock and no longer
# exists on Python 3.8+.
t0 = time.time()

# Hold out 33% of the rows for testing. The split is computed once and then
# unpacked; `cv` is kept as a module-level name holding the same four arrays.
cv = train_test_split(X, Y, test_size=.33, random_state=20)
X_train, X_test, Y_train, Y_test = cv

# Unpruned baseline: gini criterion, random splitter, at least 10 samples per leaf.
clf = tree.DecisionTreeClassifier(criterion="gini", splitter='random', min_samples_leaf=10)
clf = clf.fit(X_train, Y_train)

# Accuracy on the data the tree was fitted to (arguments are y_true, y_pred).
train_prediction = clf.predict(X_train)
trainaccuracy = accuracy_score(Y_train, train_prediction) * 100
print("The training accuracy for this is " + str(trainaccuracy))

# Accuracy on the held-out data.
Y_prediction = clf.predict(X_test)
accuracy = accuracy_score(Y_test, Y_prediction) * 100
print("The test classification works with " + str(accuracy) + "% accuracy without pruning")

# Precision and log loss on the held-out data.
from sklearn.metrics import precision_score
from sklearn.metrics import log_loss
precision = precision_score(Y_test, Y_prediction, average="weighted") * 100
# log_loss is defined on predicted class probabilities, not on hard labels, so
# feed it predict_proba output. `labels` pins the column order to the fitted
# classes even if the test split happens to miss one of them.
Y_probability = clf.predict_proba(X_test)
loss = log_loss(Y_test, Y_probability, labels=clf.classes_) * 100
print("Precision: " + str(precision))
print("Loss: " + str(loss))

# Write the fitted tree in Graphviz DOT format. export_graphviz writes straight
# to the open handle, so its return value is not assigned back to `f`.
with open("decisionTree1.txt", 'w') as f:
    tree.export_graphviz(clf, out_file=f)
##SciKit Pruning Code (from Piazza)
#info at http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#sphx-glr-auto-examples-tree-plot-unveil-tree-structure-py
def prune(tree, min_samples_leaf = 3):
    """Cut off the children of every node holding too few samples, in place.

    Args:
        tree: A fitted estimator exposing ``min_samples_leaf`` and ``tree_``
            (the latter with ``node_count``, ``n_node_samples``,
            ``children_left`` and ``children_right``). Inside this function
            the name refers to this argument, not to the ``sklearn.tree``
            module imported at the top of the file.
        min_samples_leaf: Nodes with this many samples or fewer lose both
            child links. Must be larger than the estimator's current
            ``min_samples_leaf``.

    Returns:
        None. ``tree.min_samples_leaf`` and the child arrays of ``tree.tree_``
        are modified in place.

    Raises:
        ValueError: If the estimator's ``min_samples_leaf`` is already greater
            than or equal to the requested value. ValueError is a subclass of
            Exception, so callers catching Exception keep working.
    """
    if tree.min_samples_leaf >= min_samples_leaf:
        raise ValueError('Tree already more pruned')

    # Child index that marks "no child" in scikit-learn's tree arrays.
    leaf_marker = -1

    tree.min_samples_leaf = min_samples_leaf
    inner = tree.tree_
    for node_id in range(inner.node_count):
        if inner.n_node_samples[node_id] <= min_samples_leaf:
            # Detach both children so this node acts as a leaf.
            inner.children_left[node_id] = leaf_marker
            inner.children_right[node_id] = leaf_marker
# Show the leaf-size setting of the current classifier. The manual prune(clf)
# call below is disabled, so both prints report the same value.
print(clf.min_samples_leaf)
#Call pruning, plotting functions
#prune(clf)
print(clf.min_samples_leaf)

# Instead of editing the fitted tree, refit with the same settings plus a
# depth limit of 3.
clf = tree.DecisionTreeClassifier(criterion="gini", splitter='random', min_samples_leaf=10, max_depth=3)
clf.fit(X_train, Y_train)

# Final score on the held-out data for the depth-limited tree.
new_Y_prediction = clf.predict(X_test)
accuracy2 = accuracy_score(Y_test, new_Y_prediction) * 100
print("This classification works with " + str(accuracy2) + "% accuracy with pruning, as this algorithm implements pruning")

# Write the depth-limited tree in Graphviz DOT format. export_graphviz writes
# straight to the open handle, so its return value is not assigned back to `f`.
with open("decisionTree2.txt", 'w') as f:
    tree.export_graphviz(clf, out_file=f)

# Time elapsed since t0 was taken earlier in the script.
print(str(time.time() - t0) + " seconds wall time.")

# Visualizations for model accuracy:
# learning curve for the depth-limited classifier over the full X, Y.
skplt.estimators.plot_learning_curve(clf, X, Y, title="Learning Curve: Decision Trees")
plt.show()