Skip to content

Commit

Permalink
pipeline testing
Browse files Browse the repository at this point in the history
+ pipeline test with function positives (test set in large part known to both models, just a biased simulation of symbol extraction plugin work)
  • Loading branch information
michal-kapala committed Aug 29, 2023
1 parent e5802bd commit 484af2b
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 0 deletions.
200 changes: 200 additions & 0 deletions models/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import os, sys, getopt
import sqlite3
from datetime import datetime
from joblib import load
from sklearn.preprocessing import StandardScaler
from utils import PipelineUtils as utils


# Usage string printed for -h or on missing arguments.
# BUG FIX: the original string was missing the closing quote after the
# --names option value, which made the usage example misleading.
HELP = 'Usage:\npython test_pipeline.py --dbpath="<dataset path>" --results="<results db path>" --names="<names classifier model file>" --paths="<paths classifier model file>"\n'

def test_pipeline(conn: sqlite3.Connection, results_path: str, names_model_file: str, paths_model_file: str):
    """Simulates a plugin scenario and evaluates common predictions of both classifiers.

    Args:
        conn: Open connection to the dataset SQLite database.
        results_path: Path of the results SQLite database.
        names_model_file: File name (with extension) of the names classifier model.
        paths_model_file: File name (with extension) of the paths classifier model.
    """
    cur = conn.cursor()
    start = datetime.now()

    print('Loading FastText model...')
    try:
        ft = utils.load_ft(utils.get_embedder_path())
    except Exception as ex:
        print(ex)
        sys.exit()

    print("Fetching data...")
    data = utils.query_data(cur)
    data['lit_vec'] = ''

    print("Performing word embedding...")
    data = utils.ft_embed(ft, data)
    data.drop(['token_literal'], axis=1, inplace=True)

    print('Loading classifier models...')
    try:
        names_clf = load(utils.get_model_path(names_model_file))
        paths_clf = load(utils.get_model_path(paths_model_file))
    except Exception as ex:
        print(ex)
        sys.exit()

    names = utils.listify_names(data['lit_vec'].to_list())
    print("Scaling names data...")
    # BUG FIX: StandardScaler.transform returns a new array rather than
    # scaling in place; the original discarded its result, so the names
    # classifier was fed unscaled features.
    scaler = StandardScaler()
    names = scaler.fit_transform(names)

    print("Predicting names...")
    data['name_pred'] = names_clf.predict(X=names)

    # rearrange columns: drop identifiers/labels, keep features in the
    # fixed order expected by the paths classifier
    paths = data.drop(['binary', 'func_addr', 'names_func', 'name_pred'], axis=1)
    paths = paths[['ref_depth',
                   'is_upward',
                   'nb_referrers',
                   'nb_strings',
                   'nb_referees',
                   'instructions',
                   'lit_vec']]
    paths = utils.listify_paths(paths)

    print("Scaling paths...")
    # BUG FIX: same as above - keep the scaled array
    scaler = StandardScaler()
    paths = scaler.fit_transform(paths)

    print("Predicting paths...")
    data['path_pred'] = paths_clf.predict(paths)
    path_prob = paths_clf.predict_proba(paths)
    data['path_pred_prob1'] = ''

    # store the positive-class probability for each token path
    for idx in data.index:
        data.at[idx, 'path_pred_prob1'] = path_prob[idx][1]

    funcs = utils.group_in_funcs(data)

    # confusion-matrix counters; a path counts as a pipeline positive only
    # when BOTH classifiers predict 1 (AND-semantics of the two stages)
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    for func in funcs:
        for tpath in func:
            # column positions after drops: 4 = names_func (ground truth),
            # 10 = name_pred, 11 = path_pred
            truth = tpath[4]
            names_pred = tpath[10]
            paths_pred = tpath[11]
            if truth == 0 and names_pred == 0:
                tn += 1
            if truth == 1 and names_pred == 0:
                fn += 1
            if truth == 1 and names_pred == 1 and paths_pred == 1:
                tp += 1
            if truth == 1 and names_pred == 1 and paths_pred == 0:
                fn += 1
            if truth == 0 and names_pred == 1 and paths_pred == 0:
                tn += 1
            if truth == 0 and names_pred == 1 and paths_pred == 1:
                fp += 1

    if tp + tn + fp + fn == 0:
        print("Why are you testing with no data?")
        sys.exit()

    accuracy = (tp + tn) / (tp + tn + fp + fn)

    if tp + fp == 0:
        print("Precision could not be calculated (no positive predictions)")
        precision = None
    else:
        precision = tp / (tp + fp)

    if tp + fn == 0:
        print("Recall could not be calculated - why are you testing with no positive samples in the set?")
        recall = None
    else:
        recall = tp / (tp + fn)

    # F1 is undefined when either component is undefined; 0 when either is 0
    if precision is None or recall is None:
        f1 = None
    elif precision == 0 or recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    print(f"Accuracy: {accuracy * 100:.3f}%")
    results = {
        "pos": tp + fn,
        "neg": tn + fp,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    print(results)

    # results table name is derived from both model file names
    names_clf_name = names_model_file.replace('names_', '').replace('.joblib', '')
    paths_clf_name = paths_model_file.replace('paths_', '').replace('.joblib', '')
    table = f"pipe_{names_clf_name}_{paths_clf_name}"
    utils.save_results(results, table, results_path)

    print(f'Start time:\t{start}')
    print(f'End time:\t{datetime.now()}')


def main(argv):
    """Parses CLI options, validates input files and runs the pipeline test.

    Args:
        argv: Command-line arguments (without the program name).

    Raises:
        Exception: When a required option is missing or a file does not exist.
    """
    db_path = ""
    results_path = ""
    names_model = ""
    paths_model = ""
    # BUG FIX: the short options -d, -r and -n take an argument, so each
    # needs a trailing ':' in the optstring (the original "hdrnp:" treated
    # them as value-less flags, breaking e.g. `-d <path>`).
    opts, _ = getopt.getopt(argv, "hd:r:n:p:", ["dbpath=", "results=", "names=", "paths="])
    for opt, arg in opts:
        if opt == '-h':
            print(HELP)
            sys.exit()
        elif opt in ("-d", "--dbpath"):
            db_path = arg
        elif opt in ("-r", "--results"):
            results_path = arg
        elif opt in ("-n", "--names"):
            names_model = arg
        elif opt in ("-p", "--paths"):
            paths_model = arg

    if db_path == "":
        raise Exception(f"Dataset SQLite database path required\n{HELP}")
    if results_path == "":
        raise Exception(f"Results SQLite database path required\n{HELP}")
    if names_model == "":
        raise Exception(f"Function name model file name (with extension) required\n{HELP}")
    if paths_model == "":
        raise Exception(f"Xref paths model file name (with extension) required\n{HELP}")

    if not os.path.isfile(db_path):
        raise Exception(f"Dataset database not found at {db_path}")

    if not os.path.isfile(results_path):
        raise Exception(f"Results database not found at {results_path}")

    names_model_path = utils.get_model_path(names_model)
    if not os.path.isfile(names_model_path):
        raise Exception(f"Function name model not found at {names_model_path}")

    paths_model_path = utils.get_model_path(paths_model)
    if not os.path.isfile(paths_model_path):
        raise Exception(f"Xref path model not found at {paths_model_path}")

    conn = sqlite3.connect(db_path)

    test_pipeline(conn, results_path, names_model, paths_model)

    conn.commit()
    conn.close()

# Script entry point: forward CLI arguments (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
160 changes: 160 additions & 0 deletions models/pipeline/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import sqlite3, sys, os, pandas as pd
from gensim.models import FastText
from sklearn.model_selection import train_test_split

# Column order of the paths/token_paths/funcs join produced by
# PipelineUtils.query_data - must stay in sync with its SELECT list.
_TPATH_COLUMNS = [
    'binary',
    'func_addr',
    'ref_depth',
    'is_upward',
    'token_literal',
    'names_func',
    'nb_referrers',
    'nb_strings',
    'nb_referees',
    'instructions']

class PipelineUtils:
    """Utility functions for pipeline simulation."""

    @staticmethod
    def query_data(cur: sqlite3.Cursor) -> pd.DataFrame:
        """Returns paths/token paths/funcs join for the pipeline.

        Args:
            cur: Cursor of the dataset database.

        Returns:
            `pd.DataFrame` with `_TPATH_COLUMNS` columns.
            (Annotation fixed: the original declared `list[tuple]` but
            always returned a DataFrame.)
        """
        try:
            cur.execute('''SELECT p.binary, p.func_addr, ref_depth, is_upward, token_literal, names_func, nb_referrers, nb_strings, nb_referees, instructions FROM (SELECT binary, local_id, func_addr, ref_depth, is_upward FROM paths WHERE to_name = 1 GROUP BY func_addr) AS p
            JOIN token_paths AS tp ON p.binary = tp.binary AND local_id = local_path_id
            JOIN funcs ON funcs.binary = p.binary AND funcs.func_addr = p.func_addr
            WHERE names_func IS NOT NULL''')
            data = cur.fetchall()
        # 'no such table: x'
        except sqlite3.OperationalError as ex:
            print(ex)
            sys.exit()

        return pd.DataFrame(data, columns=_TPATH_COLUMNS)

    @staticmethod
    def query_paths(cur: sqlite3.Cursor, func_addr: int, binary: str) -> pd.DataFrame:
        """Returns all labeled paths of a function.

        BUG FIX: the original query contained a stray comma before FROM
        (a SQL syntax error) and the function returned None instead of the
        fetched rows, despite its DataFrame-style usage.

        Args:
            cur: Cursor of the dataset database.
            func_addr: Address of the function.
            binary: Name of the binary the function belongs to.

        Returns:
            `pd.DataFrame` with `local_id`, `ref_depth` and `is_upward` columns.
        """
        try:
            cur.execute('SELECT local_id, ref_depth, is_upward FROM paths WHERE func_addr = ? AND binary = ? AND to_name IS NOT NULL',
                        (func_addr, binary))
            paths = cur.fetchall()
        # 'no such table: x'
        except sqlite3.OperationalError as ex:
            print(ex)
            sys.exit()

        return pd.DataFrame(paths, columns=['local_id', 'ref_depth', 'is_upward'])

    @staticmethod
    def get_model_path(filename: str) -> str:
        """Returns the target path for model file (sibling of the cwd's parent dir)."""
        models_path, _ = os.path.split(os.getcwd())
        return os.path.join(models_path, filename)

    @staticmethod
    def get_embedder_path() -> str:
        """Returns the path to the FastText model file.

        Generalized: builds the path with `os.path.join` components instead
        of a hard-coded Windows backslash, so it works on any OS (identical
        result on Windows).
        """
        models_path, _ = os.path.split(os.getcwd())
        return os.path.join(models_path, 'embedder', 'embedder.ft')

    @staticmethod
    def load_ft(path: str) -> FastText:
        """Loads a pretrained FastText model from a file."""
        return FastText.load(path)

    @staticmethod
    def ft_embed(ft: FastText, df: pd.DataFrame) -> pd.DataFrame:
        """Performs vectorization on token text data.

        Mutates `df` in place (adds a `lit_vec` column) and also returns it.
        """
        df['lit_vec'] = ''
        for idx in df.index:
            df.at[idx, 'lit_vec'] = ft.wv[df.at[idx, 'token_literal']]
        return df

    @staticmethod
    def listify_paths(df: pd.DataFrame) -> list:
        """Transforms the vectorized token literal from `numpy.array` into a single-element list, then converts `pd.DataFrame` to list."""
        for idx in df.index:
            df.at[idx, 'lit_vec'] = df.at[idx, 'lit_vec'].tolist()

        return df.values.tolist()

    @staticmethod
    def listify_names(lst: list) -> list[list]:
        """Transforms `list[numpy.array]` into a `list[list[any]]` (each vector wrapped as a one-element row)."""
        result = []
        for elem in lst:
            result.append([elem.tolist()])
        return result

    @staticmethod
    def group_in_funcs(df: pd.DataFrame) -> list:
        """Returns token paths grouped into per-function lists.

        Rows are grouped by (binary, func_addr); group order follows first
        appearance in `df` (dicts preserve insertion order), matching the
        original three-pass implementation in a single pass.
        (Annotation fixed: the original declared `dict` but returned a list.)
        """
        mapping = {}
        for idx in df.index:
            binary = df.at[idx, 'binary']
            func = df.at[idx, 'func_addr']
            mapping.setdefault(binary, {}).setdefault(func, []).append(df.iloc[idx].to_list())

        return [tpaths for funcs in mapping.values() for tpaths in funcs.values()]

    @staticmethod
    def save_results(results: dict, table: str, dbpath: str):
        """Saves test results to results database (or overwrites existing).

        Args:
            results: Dict with pos/neg/tp/tn/fp/fn/accuracy/precision/recall/f1.
            table: Name of the results table (dropped and recreated).
            dbpath: Path of the results SQLite database.
        """
        conn = sqlite3.connect(dbpath)
        cur = conn.cursor()

        try:
            # NOTE(review): table names cannot be passed as SQL parameters,
            # so `table` is interpolated - callers must not pass untrusted
            # input here (SQL injection risk).
            cur.execute(f'DROP TABLE IF EXISTS {table}')
            cur.execute(f'''CREATE TABLE {table} (
                pos INTEGER NOT NULL,
                neg INTEGER NOT NULL,
                tp INTEGER NOT NULL,
                tn INTEGER NOT NULL,
                fp INTEGER NOT NULL,
                fn INTEGER NOT NULL,
                accuracy REAL NOT NULL,
                precision REAL,
                recall REAL,
                f1 REAL)''')
        except Exception as ex:
            print(ex)
            sys.exit()

        conn.commit()
        # normalize to plain Python scalars (precision/recall/f1 may be None)
        pos = int(results['pos'])
        neg = int(results['neg'])
        tp = int(results['tp'])
        tn = int(results['tn'])
        fp = int(results['fp'])
        fn = int(results['fn'])
        acc = float(results['accuracy'])
        precision = float(results['precision']) if results['precision'] is not None else None
        recall = float(results['recall']) if results['recall'] is not None else None
        f1 = float(results['f1']) if results['f1'] is not None else None
        try:
            # values are parameterized; only the table name is interpolated
            cur.execute(f'INSERT INTO {table} VALUES (?,?,?,?,?,?,?,?,?,?)',
                        (pos, neg, tp, tn, fp, fn, acc, precision, recall, f1))
        except Exception as ex:
            print(ex)
            sys.exit()

        conn.commit()
        conn.close()

0 comments on commit 484af2b

Please sign in to comment.