-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrainer.py
85 lines (63 loc) · 2.71 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pandas
import numpy
import warnings
import matplotlib.pylab as plt
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
from sklearn import svm
from sklearn import cross_validation as cv
# Suppress a single pandas DeprecationWarning. The `module` argument is a
# regular expression matched against the start of the warning's module name,
# and `lineno` pins the filter to line 570 of that module, so this only
# silences the warning emitted from that exact source line (it is therefore
# tied to whichever pandas release emitted it there).
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module='pandas',
    lineno=570,
)
def return_nonstring_col(data_cols, string_cols=('URL', 'host', 'path'),
                         label_cols=('malicious', 'result')):
    """Split column names into the columns to keep and the model feature columns.

    Args:
        data_cols: Iterable of column names (for example ``DataFrame.columns``).
        string_cols: Names of raw string columns to drop entirely. Defaults to
            ``URL``, ``host`` and ``path``.
        label_cols: Names of label/output columns that are kept but must not
            be used as features. Defaults to ``malicious`` and ``result``.

    Returns:
        A two-element list ``[cols_to_keep, train_cols]``. ``cols_to_keep`` is
        every column not in ``string_cols``, in input order. ``train_cols`` is
        the subset of ``cols_to_keep`` that is also not in ``label_cols``.
    """
    # Build the exclusion sets once so each membership test is O(1).
    dropped = frozenset(string_cols)
    labels = frozenset(label_cols)
    cols_to_keep = [col for col in data_cols if col not in dropped]
    train_cols = [col for col in cols_to_keep if col not in labels]
    return [cols_to_keep, train_cols]
def svm_classifier(train, query, train_cols):
    """Fit an SVM on ``train``, print its CV score and label the rows of ``query``.

    Features are standardised with the mean/std learned from the training set
    only, and that same transform is applied to ``query``. This keeps query
    rows in the feature space the model was trained in. The feature columns of
    the input frames are not modified; the predictions are written to
    ``query['result']``.

    Args:
        train: DataFrame holding the ``train_cols`` features and a
            ``malicious`` label column.
        query: DataFrame holding the ``train_cols`` features and a ``URL``
            column; a ``result`` column is added to it.
        train_cols: Names of the feature columns.
    """
    clf = svm.SVC()
    # Fit the scaler on the training data only. Scaling each frame with its
    # own statistics would shift/stretch the query features differently from
    # the training features.
    scaler = preprocessing.StandardScaler().fit(train[train_cols])
    train_x = scaler.transform(train[train_cols])
    query_x = scaler.transform(query[train_cols])
    labels = train['malicious']
    print(clf.fit(train_x, labels))
    scores = cv.cross_val_score(clf, train_x, labels, cv=30)
    # Report the mean CV score +/- two standard deviations (~95% band).
    print('Estimated score SVM: %0.5f (+/- %0.5f)'
          % (scores.mean(), scores.std() * 2))
    query['result'] = clf.predict(query_x)
    print(query[['URL', 'result']])
def forest_classifier(train, query, train_cols, n_estimators=150):
    """Fit a random forest on ``train``, print its CV score and label ``query``.

    Args:
        train: DataFrame holding the ``train_cols`` features and a
            ``malicious`` label column.
        query: DataFrame holding the ``train_cols`` features; the predictions
            are written to a new ``result`` column.
        train_cols: Names of the feature columns.
        n_estimators: Number of trees in the forest (default 150).
    """
    rf = RandomForestClassifier(n_estimators=n_estimators)
    features = train[train_cols]
    labels = train['malicious']
    print(rf.fit(features, labels))
    scores = cv.cross_val_score(rf, features, labels, cv=30)
    # Report the mean CV score +/- two standard deviations (~95% band).
    print('Estimated score RandomForestClassifier: %0.5f (+/- %0.5f)'
          % (scores.mean(), scores.std() * 2))
    query['result'] = rf.predict(query[train_cols])
    print(query['result'])
# Called from the GUI path via gui_caller().
def forest_classifier_gui(train, query, train_cols, n_estimators=150):
    """Fit a random forest on ``train`` and return its predictions for ``query``.

    Unlike ``forest_classifier`` this skips cross-validation and returns the
    predictions instead of only printing them.

    Args:
        train: DataFrame holding the ``train_cols`` features and a
            ``malicious`` label column.
        query: DataFrame holding the ``train_cols`` features and a ``URL``
            column; the predictions are also written to a ``result`` column.
        train_cols: Names of the feature columns.
        n_estimators: Number of trees in the forest (default 150).

    Returns:
        The ``query['result']`` Series of predicted labels.
    """
    rf = RandomForestClassifier(n_estimators=n_estimators)
    print(rf.fit(train[train_cols], train['malicious']))
    query['result'] = rf.predict(query[train_cols])
    # Echo the first two labelled rows as a quick sanity check.
    print(query[['URL', 'result']].head(2))
    return query['result']
def train(db, test_db):
    """Train the SVM and random-forest classifiers and label a query CSV.

    Args:
        db: Path (or buffer) of the training CSV, read with
            ``pandas.read_csv``. It must contain a ``malicious`` label column
            plus every feature column found in the query CSV.
        test_db: Path (or buffer) of the CSV whose rows are to be classified.

    The feature columns come from the query CSV's columns, filtered by
    ``return_nonstring_col``. Results are printed by the classifiers; nothing
    is returned.
    """
    train_csv = pandas.read_csv(db)
    query_csv = pandas.read_csv(test_db)
    # Only the feature list is needed; each classifier selects its own columns.
    _, train_cols = return_nonstring_col(query_csv.columns)
    svm_classifier(train_csv, query_csv, train_cols)
    forest_classifier(train_csv, query_csv, train_cols)
def gui_caller(db, test_db):
    """Train a random forest on one CSV and return its predictions for another.

    Args:
        db: Path (or buffer) of the training CSV, read with
            ``pandas.read_csv``. It must contain a ``malicious`` label column
            plus every feature column found in the query CSV.
        test_db: Path (or buffer) of the CSV whose rows are to be classified.

    Returns:
        The Series of predicted labels produced by ``forest_classifier_gui``.
    """
    train_csv = pandas.read_csv(db)
    query_csv = pandas.read_csv(test_db)
    # Feature columns are taken from the query CSV; the kept-column list is unused.
    _, train_cols = return_nonstring_col(query_csv.columns)
    return forest_classifier_gui(train_csv, query_csv, train_cols)