-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
+ training script utils for function name classifiers + gaussian naive bayes model trained with amazing 62% cross-validation accuracy score
- Loading branch information
1 parent
d143c63
commit 376a2e6
Showing
7 changed files
with
155 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Models | ||
|
||
Serialized classifier models with corresponding training and testing scripts. | ||
## Directories | ||
|
||
* `/embedder` - [`FastText`](https://fasttext.cc/) word embedder model for token text vectorization (self-trained, with source) | ||
* `/names` - training and test scripts for function name classifiers | ||
## Files | ||
Naming convention for model files is `<classifier set>_<classifier name>.<serialization source>`. | ||
|
||
* `*.joblib` - classifier models serialized with [`joblib`](https://pypi.org/project/joblib/) | ||
* `*.ft` - FastText model serialized with [`gensim.models.FastText`](https://radimrehurek.com/gensim/models/fasttext.html) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import sqlite3, sys, os, pandas as pd | ||
from gensim.models import FastText | ||
from sklearn.model_selection import train_test_split | ||
|
||
# Column names given to the DataFrame built by NameClassifierUtils.query_tokens;
# the order matches the columns of its SELECT statement.
_COLUMNS = ['literal', 'is_name']
# Passed as `test_size` to train_test_split in NameClassifierUtils.split_dataset,
# i.e. a fraction of the dataset (0.2), not a number out of 100.
_TEST_SIZE_RATIO = 0.2
"""Desired percentage of test samples in the dataset."""
|
||
class NameClassifierUtils:
    """Utility functions for function name classifiers.

    Every helper is a static method; the class is only used as a namespace.
    """

    @staticmethod
    def query_tokens(cur: sqlite3.Cursor) -> pd.DataFrame:
        """Return all labelled tokens from the dataset.

        Args:
            cur: Cursor on a database that has a ``tokens`` table with
                ``literal`` and ``is_name`` columns.

        Returns:
            A DataFrame with the columns named in ``_COLUMNS``, one row per
            token whose ``is_name`` is not NULL.

        Raises:
            SystemExit: With status 1 when the query fails with
                ``sqlite3.OperationalError`` (e.g. 'no such table: x').
        """
        try:
            cur.execute('SELECT literal, is_name FROM tokens WHERE is_name IS NOT NULL')
            tokens = cur.fetchall()
        except sqlite3.OperationalError as ex:
            # Report on stderr and exit non-zero so a failed run is not
            # mistaken for a successful one.
            print(ex, file=sys.stderr)
            sys.exit(1)

        return pd.DataFrame(data=tokens, columns=_COLUMNS)

    @staticmethod
    def query_pdb(cur: sqlite3.Cursor) -> pd.DataFrame:
        """Return all PDB function names from the dataset.

        Args:
            cur: Cursor on a database that has a ``pdb`` table with a
                ``literal`` column.

        Returns:
            A DataFrame with a ``literal`` column and an ``is_name`` column
            set to 1 for every row (all PDB entries are positives).

        Raises:
            SystemExit: With status 1 when the query fails with
                ``sqlite3.OperationalError`` (e.g. 'no such table: x').
        """
        try:
            cur.execute('SELECT literal FROM pdb')
            pdb = cur.fetchall()
        except sqlite3.OperationalError as ex:
            print(ex, file=sys.stderr)
            sys.exit(1)

        df = pd.DataFrame(data=pdb, columns=['literal'])
        # A scalar assignment labels every row at once and keeps the column
        # numeric instead of an object column seeded with ''.
        df['is_name'] = 1
        return df

    @staticmethod
    def get_embedder_path() -> str:
        """Return the path to the FastText model file.

        The path is ``embedder/embedder.ft`` under the parent of the current
        working directory, joined with the separator of the running platform.
        """
        models_path, _ = os.path.split(os.getcwd())
        return os.path.join(models_path, 'embedder', 'embedder.ft')

    @staticmethod
    def get_model_path(filename: str) -> str:
        """Return the target path for a model file.

        Args:
            filename: Name of the model file.

        Returns:
            ``filename`` joined onto the parent of the current working
            directory.
        """
        models_path, _ = os.path.split(os.getcwd())
        return os.path.join(models_path, filename)

    @staticmethod
    def load_ft(path: str) -> 'FastText':
        """Load a pretrained FastText model from the file at ``path``."""
        return FastText.load(path)

    @staticmethod
    def balance_dataset(tokens_df: pd.DataFrame, pdb_df: pd.DataFrame) -> pd.DataFrame:
        """Return a complete dataset balanced with PDB positives.

        Rows of ``tokens_df`` with ``is_name == 0`` count as negatives, all
        other rows as positives. Enough rows of ``pdb_df`` are appended to
        make up the difference.

        Args:
            tokens_df: Labelled tokens with an ``is_name`` column.
            pdb_df: PDB positives to draw the balancing rows from.

        Returns:
            A new DataFrame with a fresh 0..n-1 index. If there are no missing
            positives, it holds only the rows of ``tokens_df``. If ``pdb_df``
            has too few rows to fill the gap, all of its rows are appended.
        """
        nb_neg = int((tokens_df['is_name'] == 0).sum())
        nb_pos = len(tokens_df) - nb_neg
        nb_missing_pos = nb_neg - nb_pos

        # Nothing to add: train_test_split would reject train_size <= 0.
        if nb_missing_pos <= 0:
            return tokens_df.reset_index(drop=True)

        if nb_missing_pos >= len(pdb_df):
            # train_test_split needs at least one sample left over, so when
            # every PDB row is required just take them all.
            balancing_pos = pdb_df
        else:
            # deterministic shuffle
            balancing_pos, _ = train_test_split(pdb_df, train_size=nb_missing_pos, random_state=0)

        return pd.concat([tokens_df, balancing_pos], ignore_index=True)

    @staticmethod
    def split_dataset(features: pd.DataFrame, labels: pd.DataFrame) -> tuple:
        """Parameterized wrapper for `sklearn.model_selection.train_test_split`.

        Args:
            features: Feature rows.
            labels: Labels aligned with ``features``.

        Returns:
            ``(x_train, x_test, y_train, y_test)``, split with
            ``test_size=_TEST_SIZE_RATIO`` and a fixed ``random_state``.
        """
        # Deterministic shuffle
        x_train, x_test, y_train, y_test = train_test_split(
            features, labels, test_size=_TEST_SIZE_RATIO, random_state=0)
        return x_train, x_test, y_train, y_test

    @staticmethod
    def ft_embed(ft: 'FastText', tokens: pd.DataFrame):
        """Perform vectorization on token text data.

        Looks up ``ft.wv[literal]`` for every value of the ``literal`` column
        and stores the results in a new ``lit_vec`` column.

        Args:
            ft: Loaded FastText model.
            tokens: DataFrame with a ``literal`` column. It is modified in
                place.

        Returns:
            The same ``tokens`` DataFrame, now carrying ``lit_vec``.
        """
        vectors = [ft.wv[literal] for literal in tokens['literal']]
        # An explicit object Series keeps each vector as a single cell value
        # and lines up with the existing index.
        tokens['lit_vec'] = pd.Series(vectors, index=tokens.index, dtype=object)
        return tokens
|
||
def listify(lst: list) -> list[list]:
    """Convert a list of numpy arrays into nested plain Python lists.

    Each element is converted with its ``.tolist()`` method and then wrapped
    in its own single-item list, so the result holds one ``[converted]``
    entry per input element, in the same order.
    """
    return [[array.tolist()] for array in lst]
Binary file not shown.