-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
+ 8 models trained for both tasks + fixed metric-related edge cases + model files are now ignored due to RandomForest producing 400MB of data
- Loading branch information
1 parent
9ccd967
commit e5802bd
Showing
23 changed files
with
1,248 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
*.pyc | ||
*.env | ||
mergedb.json | ||
*.joblib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import sqlite3, sys, os, getopt, pandas as pd | ||
from sklearn.ensemble import AdaBoostClassifier | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.model_selection import cross_val_score | ||
from joblib import dump | ||
from utils import NameClassifierUtils as utils | ||
|
||
|
||
# HELP is printed for the -h option and appended to the "path required" error.
# MODEL_FILE is the file name handed to utils.get_model_path() when saving.
HELP = (
    'Usage:\n'
    'python train_adaboost.py --dbpath="<database path>"\n'
)
MODEL_FILE = 'names_adaboost.joblib'
|
||
def train_adaboost(conn: sqlite3.Connection):
    """Train the function name classifier with AdaBoost (scikit-learn) and save it.

    Steps: load the FastText embedder, fetch and balance the dataset via the
    ``utils`` helpers, embed the training literals, cross-validate, fit on the
    whole training split and dump the fitted estimator with joblib.

    The saved object is a scikit-learn pipeline (``StandardScaler`` followed by
    ``AdaBoostClassifier``), so the scaling learned at training time is applied
    automatically by ``predict`` on the loaded model.

    Args:
        conn: Open SQLite connection the token and PDB data are queried from.

    Raises:
        SystemExit: With status 1 if the FastText model cannot be loaded.
    """
    # Local import keeps this block self-contained; scikit-learn is already
    # a dependency of this file.
    from sklearn.pipeline import make_pipeline

    cur = conn.cursor()

    print('Loading FastText model...')
    try:
        ft = utils.load_ft(utils.get_embedder_path())
    except Exception as ex:
        print(ex)
        # Exit with a non-zero status so callers can detect the failure.
        sys.exit(1)

    print("Fetching data...")
    tokens = utils.query_tokens(cur)
    pdb = utils.query_pdb(cur)
    df = utils.balance_dataset(tokens, pdb)

    literals = df['literal']
    labels = df['is_name']

    print("Splitting datasets...")
    # Only the training portion is used here; the other two results are discarded.
    x_train, _, y_train, _ = utils.split_dataset(literals, labels)

    print("Performing word embedding...")
    x_train = pd.DataFrame(data=x_train, columns=['literal'])
    x_train = utils.ft_embed(ft, x_train)
    x_train = utils.listify(x_train['lit_vec'].to_list())
    y_train = tuple(y_train.to_list())

    print("Scaling data...")
    # StandardScaler.transform() returns a new array; the previous code threw
    # that result away, so no scaling ever happened. The scaler is now the
    # first pipeline step: it is re-fitted inside every cross-validation fold
    # (no leakage) and is saved together with the classifier.
    scaler = StandardScaler()

    print('Initializing classifier model...')
    # defaults to 50 estimators
    ab = make_pipeline(scaler, AdaBoostClassifier(n_estimators=50, random_state=0))

    print("Cross-validation (5-fold)...")
    # cv is spelled out so the fold count always matches the message above.
    scores = cross_val_score(ab, X=x_train, y=y_train, cv=5)
    print("Accuracy: %0.3f" % (scores.mean()))
    print("Std_dev: %0.3f" % (scores.std()))

    print("Training classifier...")
    ab.fit(X=x_train, y=y_train)
    file_path = utils.get_model_path(MODEL_FILE)
    dump(ab, file_path)
    print(f'Model saved to {file_path}')
|
||
def main(argv):
    """Parse command-line options, open the database and train the AdaBoost model.

    Args:
        argv: Command-line arguments without the program name.

    Raises:
        getopt.GetoptError: Propagated from ``getopt.getopt`` for invalid options.
        Exception: If no database path was given or the file does not exist.
        SystemExit: After printing the usage text for ``-h``.
    """
    db_path = ""
    opts, _ = getopt.getopt(argv, "hd:", ["dbpath="])
    for opt, arg in opts:
        if opt == '-h':
            print(HELP)
            sys.exit()
        elif opt in ("-d", "--dbpath"):
            db_path = arg

    if db_path == "":
        raise Exception(f"SQLite database path required\n{HELP}")
    if not os.path.isfile(db_path):
        raise Exception(f"Database not found at {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        train_adaboost(conn)
    finally:
        # Close the connection even when training raises or calls sys.exit().
        conn.close()


if __name__ == "__main__":
    main(sys.argv[1:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import sqlite3, sys, os, getopt, pandas as pd | ||
from sklearn.tree import DecisionTreeClassifier | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.model_selection import cross_val_score | ||
from joblib import dump | ||
from utils import NameClassifierUtils as utils | ||
|
||
|
||
# HELP is printed for the -h option and appended to the "path required" error.
# MODEL_FILE is the file name handed to utils.get_model_path() when saving.
HELP = (
    'Usage:\n'
    'python train_dtree.py --dbpath="<database path>"\n'
)
MODEL_FILE = 'names_dtree.joblib'
|
||
def train_decision_tree(conn: sqlite3.Connection):
    """Train the function name classifier with a Decision Tree (scikit-learn) and save it.

    Steps: load the FastText embedder, fetch and balance the dataset via the
    ``utils`` helpers, embed the training literals, cross-validate, fit on the
    whole training split and dump the fitted estimator with joblib.

    The saved object is a scikit-learn pipeline (``StandardScaler`` followed by
    ``DecisionTreeClassifier``), so the scaling learned at training time is
    applied automatically by ``predict`` on the loaded model.

    Args:
        conn: Open SQLite connection the token and PDB data are queried from.

    Raises:
        SystemExit: With status 1 if the FastText model cannot be loaded.
    """
    # Local import keeps this block self-contained; scikit-learn is already
    # a dependency of this file.
    from sklearn.pipeline import make_pipeline

    cur = conn.cursor()

    print('Loading FastText model...')
    try:
        ft = utils.load_ft(utils.get_embedder_path())
    except Exception as ex:
        print(ex)
        # Exit with a non-zero status so callers can detect the failure.
        sys.exit(1)

    print("Fetching data...")
    tokens = utils.query_tokens(cur)
    pdb = utils.query_pdb(cur)
    df = utils.balance_dataset(tokens, pdb)

    literals = df['literal']
    labels = df['is_name']

    print("Splitting datasets...")
    # Only the training portion is used here; the other two results are discarded.
    x_train, _, y_train, _ = utils.split_dataset(literals, labels)

    print("Performing word embedding...")
    x_train = pd.DataFrame(data=x_train, columns=['literal'])
    x_train = utils.ft_embed(ft, x_train)
    x_train = utils.listify(x_train['lit_vec'].to_list())
    y_train = tuple(y_train.to_list())

    print("Scaling data...")
    # StandardScaler.transform() returns a new array; the previous code threw
    # that result away, so no scaling ever happened. The scaler is now the
    # first pipeline step: it is re-fitted inside every cross-validation fold
    # (no leakage) and is saved together with the classifier.
    scaler = StandardScaler()

    print('Initializing classifier model...')
    tree = make_pipeline(scaler, DecisionTreeClassifier(random_state=0))

    print("Cross-validation (5-fold)...")
    # cv is spelled out so the fold count always matches the message above.
    scores = cross_val_score(tree, X=x_train, y=y_train, cv=5)
    print("Accuracy: %0.3f" % (scores.mean()))
    print("Std_dev: %0.3f" % (scores.std()))

    print("Training classifier...")
    tree.fit(X=x_train, y=y_train)
    file_path = utils.get_model_path(MODEL_FILE)
    dump(tree, file_path)
    print(f'Model saved to {file_path}')
|
||
def main(argv):
    """Parse command-line options, open the database and train the Decision Tree model.

    Args:
        argv: Command-line arguments without the program name.

    Raises:
        getopt.GetoptError: Propagated from ``getopt.getopt`` for invalid options.
        Exception: If no database path was given or the file does not exist.
        SystemExit: After printing the usage text for ``-h``.
    """
    db_path = ""
    opts, _ = getopt.getopt(argv, "hd:", ["dbpath="])
    for opt, arg in opts:
        if opt == '-h':
            print(HELP)
            sys.exit()
        elif opt in ("-d", "--dbpath"):
            db_path = arg

    if db_path == "":
        raise Exception(f"SQLite database path required\n{HELP}")
    if not os.path.isfile(db_path):
        raise Exception(f"Database not found at {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        train_decision_tree(conn)
    finally:
        # Close the connection even when training raises or calls sys.exit().
        conn.close()


if __name__ == "__main__":
    main(sys.argv[1:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import sqlite3, sys, os, getopt, pandas as pd | ||
from sklearn.neighbors import KNeighborsClassifier | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.model_selection import cross_val_score | ||
from joblib import dump | ||
from utils import NameClassifierUtils as utils | ||
|
||
|
||
# HELP is printed for the -h option and appended to the "path required" error.
# MODEL_FILE is the file name handed to utils.get_model_path() when saving.
HELP = (
    'Usage:\n'
    'python train_knn.py --dbpath="<database path>"\n'
)
MODEL_FILE = 'names_knn.joblib'
|
||
def train_nearest_neighbours(conn: sqlite3.Connection):
    """Train the function name classifier with k-Nearest Neighbors (scikit-learn) and save it.

    Steps: load the FastText embedder, fetch and balance the dataset via the
    ``utils`` helpers, embed the training literals, cross-validate, fit on the
    whole training split and dump the fitted estimator with joblib.

    The saved object is a scikit-learn pipeline (``StandardScaler`` followed by
    ``KNeighborsClassifier``), so the scaling learned at training time is
    applied automatically by ``predict`` on the loaded model.

    Args:
        conn: Open SQLite connection the token and PDB data are queried from.

    Raises:
        SystemExit: With status 1 if the FastText model cannot be loaded.
    """
    # Local import keeps this block self-contained; scikit-learn is already
    # a dependency of this file.
    from sklearn.pipeline import make_pipeline

    cur = conn.cursor()

    print('Loading FastText model...')
    try:
        ft = utils.load_ft(utils.get_embedder_path())
    except Exception as ex:
        print(ex)
        # Exit with a non-zero status so callers can detect the failure.
        sys.exit(1)

    print("Fetching data...")
    tokens = utils.query_tokens(cur)
    pdb = utils.query_pdb(cur)
    df = utils.balance_dataset(tokens, pdb)

    literals = df['literal']
    labels = df['is_name']

    print("Splitting datasets...")
    # Only the training portion is used here; the other two results are discarded.
    x_train, _, y_train, _ = utils.split_dataset(literals, labels)

    print("Performing word embedding...")
    x_train = pd.DataFrame(data=x_train, columns=['literal'])
    x_train = utils.ft_embed(ft, x_train)
    x_train = utils.listify(x_train['lit_vec'].to_list())
    y_train = tuple(y_train.to_list())

    print("Scaling data...")
    # StandardScaler.transform() returns a new array; the previous code threw
    # that result away, so the distance-based kNN model was trained on
    # unscaled vectors. The scaler is now the first pipeline step: it is
    # re-fitted inside every cross-validation fold (no leakage) and is saved
    # together with the classifier.
    scaler = StandardScaler()

    print('Initializing classifier model...')
    # 5 neighbors is the default
    knn = make_pipeline(scaler, KNeighborsClassifier(n_neighbors=5))

    print("Cross-validation (5-fold)...")
    # cv is spelled out so the fold count always matches the message above.
    scores = cross_val_score(knn, X=x_train, y=y_train, cv=5)
    print("Accuracy: %0.3f" % (scores.mean()))
    print("Std_dev: %0.3f" % (scores.std()))

    print("Training classifier...")
    knn.fit(X=x_train, y=y_train)
    file_path = utils.get_model_path(MODEL_FILE)
    dump(knn, file_path)
    print(f'Model saved to {file_path}')
|
||
def main(argv):
    """Parse command-line options, open the database and train the k-NN model.

    Args:
        argv: Command-line arguments without the program name.

    Raises:
        getopt.GetoptError: Propagated from ``getopt.getopt`` for invalid options.
        Exception: If no database path was given or the file does not exist.
        SystemExit: After printing the usage text for ``-h``.
    """
    db_path = ""
    opts, _ = getopt.getopt(argv, "hd:", ["dbpath="])
    for opt, arg in opts:
        if opt == '-h':
            print(HELP)
            sys.exit()
        elif opt in ("-d", "--dbpath"):
            db_path = arg

    if db_path == "":
        raise Exception(f"SQLite database path required\n{HELP}")
    if not os.path.isfile(db_path):
        raise Exception(f"Database not found at {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        train_nearest_neighbours(conn)
    finally:
        # Close the connection even when training raises or calls sys.exit().
        conn.close()


if __name__ == "__main__":
    main(sys.argv[1:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import sqlite3, sys, os, getopt, pandas as pd | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.model_selection import cross_val_score | ||
from joblib import dump | ||
from utils import NameClassifierUtils as utils | ||
|
||
|
||
# HELP is printed for the -h option and appended to the "path required" error.
# MODEL_FILE is the file name handed to utils.get_model_path() when saving.
HELP = (
    'Usage:\n'
    'python train_logreg.py --dbpath="<database path>"\n'
)
MODEL_FILE = 'names_logreg.joblib'
|
||
def train_logistic_regression(conn: sqlite3.Connection):
    """Train the function name classifier with Logistic Regression (scikit-learn) and save it.

    Steps: load the FastText embedder, fetch and balance the dataset via the
    ``utils`` helpers, embed the training literals, cross-validate, fit on the
    whole training split and dump the fitted estimator with joblib.

    The saved object is a scikit-learn pipeline (``StandardScaler`` followed by
    ``LogisticRegression``), so the scaling learned at training time is applied
    automatically by ``predict`` on the loaded model.

    Args:
        conn: Open SQLite connection the token and PDB data are queried from.

    Raises:
        SystemExit: With status 1 if the FastText model cannot be loaded.
    """
    # Local import keeps this block self-contained; scikit-learn is already
    # a dependency of this file.
    from sklearn.pipeline import make_pipeline

    cur = conn.cursor()

    print('Loading FastText model...')
    try:
        ft = utils.load_ft(utils.get_embedder_path())
    except Exception as ex:
        print(ex)
        # Exit with a non-zero status so callers can detect the failure.
        sys.exit(1)

    print("Fetching data...")
    tokens = utils.query_tokens(cur)
    pdb = utils.query_pdb(cur)
    df = utils.balance_dataset(tokens, pdb)

    literals = df['literal']
    labels = df['is_name']

    print("Splitting datasets...")
    # Only the training portion is used here; the other two results are discarded.
    x_train, _, y_train, _ = utils.split_dataset(literals, labels)

    print("Performing word embedding...")
    x_train = pd.DataFrame(data=x_train, columns=['literal'])
    x_train = utils.ft_embed(ft, x_train)
    x_train = utils.listify(x_train['lit_vec'].to_list())
    y_train = tuple(y_train.to_list())

    print("Scaling data...")
    # StandardScaler.transform() returns a new array; the previous code threw
    # that result away, so no scaling ever happened. The scaler is now the
    # first pipeline step: it is re-fitted inside every cross-validation fold
    # (no leakage) and is saved together with the classifier.
    scaler = StandardScaler()

    print('Initializing classifier model...')
    lr = make_pipeline(scaler, LogisticRegression(random_state=0))

    print("Cross-validation (5-fold)...")
    # cv is spelled out so the fold count always matches the message above.
    scores = cross_val_score(lr, X=x_train, y=y_train, cv=5)
    print("Accuracy: %0.3f" % (scores.mean()))
    print("Std_dev: %0.3f" % (scores.std()))

    print("Training classifier...")
    lr.fit(X=x_train, y=y_train)
    file_path = utils.get_model_path(MODEL_FILE)
    dump(lr, file_path)
    print(f'Model saved to {file_path}')
|
||
def main(argv):
    """Parse command-line options, open the database and train the Logistic Regression model.

    Args:
        argv: Command-line arguments without the program name.

    Raises:
        getopt.GetoptError: Propagated from ``getopt.getopt`` for invalid options.
        Exception: If no database path was given or the file does not exist.
        SystemExit: After printing the usage text for ``-h``.
    """
    db_path = ""
    opts, _ = getopt.getopt(argv, "hd:", ["dbpath="])
    for opt, arg in opts:
        if opt == '-h':
            print(HELP)
            sys.exit()
        elif opt in ("-d", "--dbpath"):
            db_path = arg

    if db_path == "":
        raise Exception(f"SQLite database path required\n{HELP}")
    if not os.path.isfile(db_path):
        raise Exception(f"Database not found at {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        train_logistic_regression(conn)
    finally:
        # Close the connection even when training raises or calls sys.exit().
        conn.close()


if __name__ == "__main__":
    main(sys.argv[1:])
Oops, something went wrong.