Skip to content

Commit

Permalink
classifier evaluation framework
Browse files Browse the repository at this point in the history
+ scripts for training, evaluation, testing and persistent results for classifiers of both kinds
+ clean db (token path cleansing)
+ removed primary key on `pdb` table of merged db
  • Loading branch information
michal-kapala committed Aug 28, 2023
1 parent a88289a commit 9ccd967
Show file tree
Hide file tree
Showing 15 changed files with 698 additions and 136 deletions.
28 changes: 26 additions & 2 deletions models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,33 @@ Serialized classifier models with corresponding training and testing scripts.

* `embedder` - [`FastText`](https://fasttext.cc/) word embedder model for token text vectorization (self-trained, with source)
* `names` - training and test scripts for function name classifiers
* `paths` - training and test scripts for function name classifiers

## Training
Training scripts create a dataset on the fly, perform 5-fold cross-validation and serialize the model.

Usage:
```
cd <names/paths>
python train_<classifier>.py --dbpath="<dataset db path>"
```

## Testing
Testing scripts reconstruct training/test datasets, load a serialized model, make predictions and save the results to a separate results database (model file name is the name of the SQLite table).

Usage:
```
cd <names/paths>
python test.py --dbpath="<dataset db path>" --results="<results db path>" --model="<model file name>"
```

## Files
Naming convention for model files is `<classifier set>_<classifier name>.<serialization source>`.

### Naming
Naming convention for classifier model files:
* function names - `names_<classifier name>.joblib`
* cross-reference paths - `paths_<classifier name>.joblib`

Extensions:
* `*.joblib` - classifier models serialized with [`joblib`](https://pypi.org/project/joblib/)
* `*.ft` - FastText model serialized with [`gensim.models.FastText`](https://radimrehurek.com/gensim/models/fasttext.html)
* `*.ft` - FastText word2vec model serialized with [`gensim.models.FastText`](https://radimrehurek.com/gensim/models/fasttext.html)
Binary file modified models/embedder/embedder.ft
Binary file not shown.
2 changes: 1 addition & 1 deletion models/embedder/train_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def train_token_embedder(conn: sqlite3.Connection):

def main(argv):
db_path = ""
opts, args = getopt.getopt(argv,"hd:",["dbpath="])
opts, _ = getopt.getopt(argv,"hd:",["dbpath="])
for opt, arg in opts:
if opt == '-h':
print(HELP)
Expand Down
123 changes: 123 additions & 0 deletions models/names/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import sqlite3, sys, os, getopt, pandas as pd
from datetime import datetime
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import StandardScaler
from joblib import load
from utils import NameClassifierUtils as utils


HELP = 'Usage:\npython test.py --dbpath="<dataset db path>" --results"<results db path>" --model="<model filename>"\n'
ITERS = 10
"""Number of testing iterations (predictions made on different sample sets). Data shuffling used is reproducible."""

def test_model(conn: sqlite3.Connection, results_path: str, model: str):
"""Tests function name classifier model of choice and saves the results."""
cur = conn.cursor()
start = datetime.now()

print("Fetching data...")
tokens = utils.query_tokens(cur)
pdb = utils.query_pdb(cur)
df = utils.balance_dataset(tokens, pdb)

print('Loading FastText model...')
try:
ft = utils.load_ft(utils.get_embedder_path())
except Exception as ex:
print(ex)
sys.exit()
literals = df['literal']
labels = df['is_name']

print("Splitting datasets...")
_, x_test, _, y_test = utils.split_dataset(literals, labels)

print("Performing word embedding...")
x_test = pd.DataFrame(data=x_test, columns = ['literal'])
x_test = utils.ft_embed(ft, x_test)
X_test = utils.listify(x_test['lit_vec'].to_list())
y_test = tuple(y_test.to_list())

# scaling
scaler = StandardScaler()
scaler.fit(X_test)
scaler.transform(X_test)

print('Loading classifier model...')
file_path = utils.get_model_path(model)
try:
clf = load(file_path)
except Exception as ex:
print(ex)
sys.exit()

print("Predicting...")
y_pred = clf.predict(X=X_test)

# stats
tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
pos = tp + fn
neg = tn + fp
accuracy = (tp + tn) / (pos + neg)
precision = tp / (tp + fp)
recall = tp / pos
f1 = 2 * precision * recall / (precision + recall)
print(f"Accuracy: {accuracy * 100:.3f}%")
results = {
"pos": pos,
"neg": neg,
"tp": tp,
"tn": tn,
"fp": fp,
"fn": fn,
"accuracy": accuracy,
"precision": precision,
"recall": recall,
"f1": f1
}

table = model.replace('.joblib', '')
utils.save_results(results, table, results_path)

print(f'Start time:\t{start}')
print(f'End time:\t{datetime.now()}')

def main(argv):
db_path = ""
results_path = ""
opts, _ = getopt.getopt(argv,"hdr:",["dbpath=", "results=", "model="])
for opt, arg in opts:
if opt == '-h':
print(HELP)
sys.exit()
elif opt in ("-d", "--dbpath"):
db_path = arg
elif opt in ("-r", "--results"):
results_path = arg
elif opt in ("-m", "--model"):
model = arg

if db_path == "":
raise Exception(f"Dataset SQLite database path required\n{HELP}")
if results_path == "":
raise Exception(f"Results SQLite database path required\n{HELP}")
if model == "":
raise Exception(f"Model file name (with extension) required\n{HELP}")

if not os.path.isfile(db_path):
raise Exception(f"Dataset database not found at {db_path}")

if not os.path.isfile(results_path):
raise Exception(f"Results database not found at {results_path}")

model_path = utils.get_model_path(model)
if not os.path.isfile(model_path):
raise Exception(f"Model not found at {model_path}")

conn = sqlite3.connect(db_path)
test_model(conn, results_path, model)
conn.close()


if __name__ == "__main__":
main(sys.argv[1:])
96 changes: 0 additions & 96 deletions models/names/test_gnbayes.py

This file was deleted.

25 changes: 14 additions & 11 deletions models/names/train_gnbayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
def train_naive_bayes(conn: sqlite3.Connection):
"""Trains function name classifier using Gaussian Naive Bayes (scikit-learn) model and saves it to a file."""
cur = conn.cursor()
print("Fetching data...")
tokens = utils.query_tokens(cur)
pdb = utils.query_pdb(cur)
df = utils.balance_dataset(tokens, pdb)

print('Loading FastText model...')
try:
ft = utils.load_ft(utils.get_embedder_path())
except Exception as ex:
print(ex)
sys.exit()

print("Fetching data...")
tokens = utils.query_tokens(cur)
pdb = utils.query_pdb(cur)
df = utils.balance_dataset(tokens, pdb)

literals = df['literal']
labels = df['is_name']

print("Splitting datasets...")
x_train, x_test, y_train, y_test = utils.split_dataset(literals, labels)
x_train, _, y_train, _ = utils.split_dataset(literals, labels)

print("Performing word embedding...")
x_train = pd.DataFrame(data=x_train, columns = ['literal'])
Expand All @@ -41,9 +44,10 @@ def train_naive_bayes(conn: sqlite3.Connection):

gnb = GaussianNB()

# cross-validation
scores = cross_val_score(gnb, X=x_train, y=y_train, cv=10)
print("Accuracy: %0.3f, std_dev: %0.3f" % (scores.mean(), scores.std()))
print("Cross-validation (5-fold)...")
scores = cross_val_score(gnb, X=x_train, y=y_train)
print("Accuracy: %0.3f" % (scores.mean()))
print("Std_dev: %0.3f" % (scores.std()))

print("Training classifier...")
gnb.fit(X=x_train, y=y_train)
Expand All @@ -53,7 +57,7 @@ def train_naive_bayes(conn: sqlite3.Connection):

def main(argv):
db_path = ""
opts, args = getopt.getopt(argv,"hd:",["dbpath="])
opts, _ = getopt.getopt(argv,"hd:",["dbpath="])
for opt, arg in opts:
if opt == '-h':
print(HELP)
Expand All @@ -70,6 +74,5 @@ def main(argv):
train_naive_bayes(conn)
conn.close()


if __name__ == "__main__":
main(sys.argv[1:])
main(sys.argv[1:])
Loading

0 comments on commit 9ccd967

Please sign in to comment.