Skip to content

Commit

Permalink
naive bayes testing
Browse files Browse the repository at this point in the history
+ function name classifier testing scripts save results to separate database
  • Loading branch information
michal-kapala committed Aug 26, 2023
1 parent 376a2e6 commit a88289a
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 2 deletions.
5 changes: 3 additions & 2 deletions models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
Serialized classifier models with corresponding training and testing scripts.
## Directories

* `/embedder` - [`FastText`](https://fasttext.cc/) word embedder model for token text vectorization (self-trained, with source)
* `/names` - training and test scripts for function name classifiers
* `embedder` - [`FastText`](https://fasttext.cc/) word embedder model for token text vectorization (self-trained, with source)
* `names` - training and test scripts for function name classifiers

## Files
The naming convention for model files is `<classifier set>_<classifier name>.<serialization source>`.

Expand Down
96 changes: 96 additions & 0 deletions models/names/test_gnbayes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import sqlite3, sys, os, getopt, pandas as pd
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from joblib import load
from utils import NameClassifierUtils as utils


# Usage text printed for -h and appended to the missing-argument errors.
HELP = 'Usage:\npython test_gnbayes.py --dbpath="<dataset db path>" --results="<results db path>"\n'
# File name of the serialized classifier; the results table name is derived from it.
MODEL_FILE = 'names_gnbayes.joblib'

def test_naive_bayes(conn: sqlite3.Connection, results_path: str):
    """Test the Gaussian Naive Bayes (scikit-learn) function name classifier and save the results.

    Args:
        conn: Open connection to the dataset database.
        results_path: Path of the SQLite database that receives the results table.

    Exits with status 1 when the FastText embedder or the classifier model
    cannot be loaded.
    """
    cur = conn.cursor()
    print("Fetching data...")
    tokens = utils.query_tokens(cur)
    pdb = utils.query_pdb(cur)
    df = utils.balance_dataset(tokens, pdb)

    print('Loading FastText model...')
    try:
        ft = utils.load_ft(utils.get_embedder_path())
    except Exception as ex:
        print(ex)
        # Report the failure through the exit status instead of exiting "successfully".
        sys.exit(1)

    literals = df['literal']
    labels = df['is_name']

    print("Splitting datasets...")
    # Only the test part of the split is evaluated here.
    _, x_test, _, y_test = utils.split_dataset(literals, labels)

    print("Performing word embedding...")
    x_test = pd.DataFrame(data=x_test, columns=['literal'])
    x_test = utils.ft_embed(ft, x_test)
    X_test = utils.listify(x_test['lit_vec'].to_list())
    y_test = tuple(y_test.to_list())

    # scaling
    # NOTE(review): transform() returns the scaled data rather than modifying
    # X_test, and the return value is discarded here, so the classifier below
    # predicts on unscaled vectors. Whether to assign the result depends on how
    # the model was trained - confirm against the training script before changing it.
    scaler = StandardScaler()
    scaler.fit(X_test)
    scaler.transform(X_test)

    file_path = utils.get_model_path(MODEL_FILE)
    print('Loading classifier model...')
    try:
        gnb = load(file_path)
    except Exception as ex:
        print(ex)
        sys.exit(1)

    print("Predicting...")
    y_pred = gnb.predict(X=X_test)
    print(y_pred)
    mislabeled = sum(1 for expected, predicted in zip(y_test, y_pred) if expected != predicted)
    print("Number of mislabeled points out of a total %d points : %d" % (x_test.shape[0], mislabeled))

    # structure and save results
    table = MODEL_FILE.replace('.joblib', '')

    x_test = x_test.reset_index(drop=True)
    # Object dtype keeps each stored value exactly as it was produced
    # (labels from y_test, raw predictions from y_pred).
    x_test['label'] = pd.Series(list(y_test), index=x_test.index, dtype=object)
    x_test['prediction'] = pd.Series(list(y_pred), index=x_test.index, dtype=object)

    print(x_test)
    utils.save_results(x_test, table, results_path)

def main(argv):
    """Parse command-line options, validate both database paths and run the test.

    Args:
        argv: Command-line arguments without the program name.

    Raises:
        Exception: If a required path is missing or does not point to an
            existing file.
    """
    db_path = ""
    results_path = ""
    # Both -d and -r carry a value, so each short option needs its own colon.
    opts, _ = getopt.getopt(argv, "hd:r:", ["dbpath=", "results="])
    for opt, arg in opts:
        if opt == '-h':
            print(HELP)
            sys.exit()
        elif opt in ("-d", "--dbpath"):
            db_path = arg
        elif opt in ("-r", "--results"):
            results_path = arg

    if not db_path:
        raise Exception(f"Dataset SQLite database path required\n{HELP}")
    if not results_path:
        raise Exception(f"Results SQLite database path required\n{HELP}")
    if not os.path.isfile(db_path):
        raise Exception(f"Dataset database not found at {db_path}")
    if not os.path.isfile(results_path):
        raise Exception(f"Results database not found at {results_path}")

    conn = sqlite3.connect(db_path)
    try:
        test_naive_bayes(conn, results_path)
    finally:
        # Release the dataset connection even when the test run fails.
        conn.close()


if __name__ == "__main__":
    main(sys.argv[1:])
File renamed without changes.
41 changes: 41 additions & 0 deletions models/names/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,50 @@ def ft_embed(ft: FastText, tokens: pd.DataFrame):
tokens.at[idx, 'lit_vec'] = ft.wv[tokens.at[idx, 'literal']]
return tokens

@staticmethod
def listify(lst: list) -> list[list]:
    """Convert every element with `tolist()` and wrap it in a one-item list.

    Each element of `lst` must provide a `tolist()` method (e.g. a numpy
    array or numpy scalar). The output has one single-element list per
    input element, in the same order.
    """
    return [[item.tolist()] for item in lst]

@staticmethod
def save_results(results: pd.DataFrame, table: str, dbpath: str):
    """Save (or overwrite) test results in the results database.

    Any existing table called `table` is dropped, the table is recreated and
    one row is inserted per row of `results`.

    Args:
        results: Frame with `literal`, `lit_vec`, `label` and `prediction`
            columns.
        table: Name of the destination table.
        dbpath: Path of the results SQLite database.

    On a database error the error is printed and the process exits with
    status 1.
    """
    # Table names cannot be bound as parameters, so quote the name as an SQL
    # identifier (doubling embedded quotes) instead of interpolating it raw.
    quoted = '"' + table.replace('"', '""') + '"'

    # Convert before touching the database: sqlite3 cannot bind numpy scalar
    # types, and a conversion error must not cost the previously saved table.
    rows = [
        (literal, float(lit_vec), int(label), int(prediction))
        for literal, lit_vec, label, prediction in zip(
            results['literal'],
            results['lit_vec'],
            results['label'],
            results['prediction'],
        )
    ]

    conn = sqlite3.connect(dbpath)
    try:
        # The connection context manager commits on success and rolls back
        # any open transaction on error; it does not close the connection.
        with conn:
            conn.execute(f'DROP TABLE IF EXISTS {quoted}')
            conn.execute(f'''CREATE TABLE {quoted} (
                literal TEXT NOT NULL,
                lit_vec REAL NOT NULL,
                label INTEGER NOT NULL,
                prediction INTEGER NOT NULL)''')
            conn.executemany(f'INSERT INTO {quoted} VALUES (?,?,?,?)', rows)
    except sqlite3.Error as ex:
        print(ex)
        sys.exit(1)
    finally:
        conn.close()

0 comments on commit a88289a

Please sign in to comment.