Skip to content

Commit

Permalink
sim_4_cv before classif
Browse files Browse the repository at this point in the history
  • Loading branch information
nicdemon committed Jan 14, 2023
1 parent 2c74a1b commit 452f60c
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 136 deletions.
1 change: 0 additions & 1 deletion src/Caribou_kmers.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def kmers_dataset(opt):
parser.add_argument('-k','--k_length', required=True, type=int, help='Length of k-mers to extract')
parser.add_argument('-l','--kmers_list', default=None, type=pathlib.Path, help='PATH to a file containing a list of k-mers to be extracted if the dataset is not a training database')
parser.add_argument('-o','--outdir', required=True, type=pathlib.Path, help='PATH to a directory on file where outputs will be saved')
parser.add_argument('-o','--outdir', required=True, type=pathlib.Path, help='PATH to a directory on file where outputs will be saved')
parser.add_argument('-wd','--workdir', default='/tmp/spill', type=Path, help='Optional. Path to a working directory where tuning data will be spilled')
args = parser.parse_args()

Expand Down
118 changes: 87 additions & 31 deletions src/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import numpy as np
import pandas as pd

from utils import zip_X_y
from glob import glob
from shutil import rmtree
from utils import load_Xy_data
from models.sklearn.ray_sklearn import SklearnModel
from models.kerasTF.ray_keras_tf import KerasTFModel

# Simulation class
from models.reads_simulation import readsSimulation


__author__ = 'Nicolas de Montigny'

Expand Down Expand Up @@ -73,12 +77,12 @@ def __init__(
# Empty initializations
self.models = {}
self._host = False
self._X_train = None
self._y_train = None
self._host_data = None
self._database_data = None
self._classified_ids = []
self._not_classified_ids = []
self._training_datasets = None
self._merged_training_datasets = None
self._merged_database_host = None
self.previous_taxa_unclassified = None
if isinstance(database_k_mers, tuple):
Expand All @@ -87,7 +91,7 @@ def __init__(
self._host_data = database_k_mers[1]
else:
self._database_data = database_k_mers
# Remove 'id' form kmers if present
# Remove 'id' from kmers if present
if 'id' in self._database_data['kmers']:
self._database_data['kmers'].remove('id')
if self._host and 'id' in self._host_data['kmers']:
Expand Down Expand Up @@ -155,7 +159,9 @@ def _binary_training(self, taxa):
self._database_data['kmers'],
self._verbose
)
self._load_training_data(taxa)
if self._training_datasets is None:
self._load_training_data()
self.models[taxa].train(self._training_datasets, self._database_data, self._cv)
else:
self._merge_database_host(self._database_data, self._host_data)
if self._classifier_binary == 'linearsvm':
Expand Down Expand Up @@ -183,11 +189,9 @@ def _binary_training(self, taxa):
self._merged_database_host['kmers'],
self._verbose
)
self._load_training_data(taxa, merged = True)
if self._merged_database_host is None:
self.models[taxa].train(self._X_train, self._y_train, self._database_data, self._cv)
else:
self.models[taxa].train(self._X_train, self._y_train, self._merged_database_host, self._cv)
self._load_training_data_merged(taxa)
self.models[taxa].train(self._merged_training_datasets, self._merged_database_host, self._cv)

self._save_model(self._model_file, taxa)

def _multiclass_training(self, taxa):
Expand Down Expand Up @@ -218,8 +222,9 @@ def _multiclass_training(self, taxa):
self._database_data['kmers'],
self._verbose
)
self._load_training_data(taxa)
self.models[taxa].train(self._X_train, self._y_train, self._database_data, self._cv)
if self._training_datasets is None:
self._load_training_data()
self.models[taxa].train(self._training_datasets, self._database_data, self._cv)
self._save_model(self._model_file, taxa)

# Execute classification using trained model(s)
Expand Down Expand Up @@ -439,25 +444,76 @@ def _verify_classifier_multiclass(self):
else:
raise ValueError('Invalid classifier option for bacteria classification!\n\tModels implemented at this moment are :\n\tClassic algorithm : Stochastic Gradient Descent (sgd) and Multinomial Naïve Bayes (mnb)\n\tNeural networks : Deep hybrid between LSTM and Attention (lstm_attention), CNN (cnn) and Wide CNN (widecnn)')

def _load_training_data(self, taxa, merged = False):
if merged:
# Binary merged
self._X_train = ray.data.read_parquet(self._merged_database_host['profile'])
self._y_train = pd.DataFrame({
taxa: pd.DataFrame(
self._merged_database_host['classes'],
columns=self._merged_database_host['taxas']
).loc[:, taxa].astype('string')
})
def _load_training_data_merged(self, taxa):
X_train = ray.data.read_parquet(self._merged_database_host['profile'])
y_train = pd.DataFrame({
taxa: pd.DataFrame(
self._merged_database_host['classes'],
columns=self._merged_database_host['taxas']
).loc[:, taxa].str.lower()
})

y_train[y_train['domain'] == 'archaea'] = 'bacteria'

df = zip_X_y(X_train, y_train)

if self._cv:
df_train, df_test = df.train_test_split(0.2, shuffle=True)
df_test = self._sim_4_cv(df_test, self._database_data, f'{self._database}_test')
self._merged_training_datasets = {'train': df_train, 'test': df_test}
else:
# Binary not merged or multiclass
self._X_train = ray.data.read_parquet(self._database_data['profile'])
self._y_train = pd.DataFrame({
taxa: pd.DataFrame(
self._database_data['classes'],
columns=self._database_data['taxas']
).loc[:, taxa].astype('string')
})
if taxa == 'domain':
self._y_train[self._y_train['domain'] == 'archaea'] = 'bacteria'
self._merged_training_datasets = {'train': df_train}

def _load_training_data(self):
X_train = ray.data.read_parquet(self._database_data['profile'])
y_train = pd.DataFrame(
self._database_data['classes'],
columns=self._database_data['taxas']
)

for col in y_train.columns:
y_train[col] = y_train[col].str.lower()

if 'domain' in y_train.columns:
y_train[y_train['domain'] == 'archaea'] = 'bacteria'

df = zip_X_y(X_train, y_train)

if self._cv:
df_train, df_test = df.train_test_split(0.2, shuffle=True)
df_test = self._sim_4_cv(df_test, self._database_data, f'{self._database}_test')
self._training_datasets = {'train': df_train, 'test': df_test}
else:
self._training_datasets = {'train': df_train}

def _sim_4_cv(self, df, kmers_ds, name):
sim_cls_dct = {
'id':[],
}
taxa_cols = []
for row in df.iter_rows():
if len(taxa_cols) == 0:
taxa_cols = list(row.keys())
taxa_cols.remove('id')
taxa_cols.remove('__value__')
for taxa in taxa_cols:
sim_cls_dct[taxa] = []
sim_cls_dct['id'].append(row['id'])
for taxa in taxa_cols:
sim_cls_dct[taxa].append(row[taxa])
cls = pd.DataFrame(sim_cls_dct)
sim_outdir = os.path.dirname(kmers_ds['profile'])
cv_sim = readsSimulation(kmers_ds['fasta'], cls, sim_cls_dct['id'], 'miseq', sim_outdir, name)
sim_data = cv_sim.simulation(self._k, self._database_data['kmers'])
sim_ids = sim_data['ids']
sim_ids = sim_data['ids']
sim_cls = pd.DataFrame({'sim_id':sim_ids}, dtype = object)
sim_cls['id'] = sim_cls['sim_id'].str.replace('_[0-9]+_[0-9]+_[0-9]+', '', regex=True)
sim_cls = sim_cls.set_index('id').join(cls.set_index('id'))
sim_cls = sim_cls.drop(['sim_id'], axis=1)
sim_cls = sim_cls.reset_index(drop = True)
df = ray.data.read_parquet(sim_data['profile'])
df = zip_X_y(df, sim_cls)
return df


81 changes: 55 additions & 26 deletions src/models/kerasTF/ray_keras_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from glob import glob
from shutil import rmtree
from utils import zip_X_y

# Preprocessing
from models.ray_tensor_min_max import TensorMinMaxScaler
from ray.data.preprocessors import BatchMapper, Concatenator, LabelEncoder, Chain, OneHotEncoder

# Parent class / models
Expand All @@ -29,6 +29,9 @@
from ray.train.tensorflow import TensorflowPredictor
from ray.train.batch_predictor import BatchPredictor

# Simulation class
from models.reads_simulation import readsSimulation

__author__ = 'Nicolas de Montigny'

__all__ = ['KerasTFModel']
Expand Down Expand Up @@ -124,18 +127,13 @@ def __init__(
elif self.classifier == 'widecnn':
print('Training multiclass classifier based on Wide CNN Network')

def _training_preprocess(self, X, y):
print('_training_preprocess')
self._preprocessor = TensorMinMaxScaler(self.kmers)
self._preprocessor.fit(X)
self._label_encode(y)
df = self._zip_X_y(X, y)
return df

def _label_encode(self, y):
self._nb_classes = len(np.unique(y[self.taxa]))
y = ray.data.from_pandas(y)
self._label_encode_define(y)
def _label_encode(self, df):
print('_label_encode')
labels = []
for row in df.iter_rows():
labels.append(row[self.taxa])
self._nb_classes = len(np.unique(labels))
self._label_encode_define(df)
encoded = []
encoded.append(-1)
labels = ['unknown']
Expand All @@ -151,7 +149,7 @@ def _label_encode_define(self, df):
LabelEncoder(self.taxa),
OneHotEncoder([self.taxa]),
Concatenator(
output_column_name='labels',
output_column_name=self.taxa,
include=['{}_{}'.format(self.taxa, i) for i in range(self._nb_classes)]
)
)
Expand All @@ -165,28 +163,51 @@ def _label_decode(self, predict):

return decoded

def train(self, X, y, kmers_ds, cv = True):
def train(self, datasets, kmers_ds, cv = True):
print('train')
df = self._training_preprocess(X, y)

df = datasets['train']
self._training_preprocess(df)

if cv:
self._cross_validation(df, kmers_ds)
df_test = datasets['test']
self._cross_validation(df, df_test, kmers_ds)
else:
df_train, df_val = df.train_test_split(0.2, shuffle = True)
df_val = self._sim_4_cv(df_val, kmers_ds, 'validation')
df_val = self._sim_4_val(df_val, kmers_ds, 'validation')
df_train = df_train.drop_columns(['id'])
df_val = df_val.drop_columns(['id'])
datasets = {'train': df_train, 'validation': df_val}
self._fit_model(datasets)

def _cross_validation(self, df, kmers_ds):
def _sim_4_val(self, df, kmers_ds, name):
sim_genomes = []
sim_taxas = []
for row in df.iter_rows():
sim_genomes.append(row['id'])
sim_taxas.append(row[self.taxa])
cls = pd.DataFrame({'id':sim_genomes,self.taxa:sim_taxas})
sim_outdir = os.path.dirname(kmers_ds['profile'])
cv_sim = readsSimulation(kmers_ds['fasta'], cls, sim_genomes, 'miseq', sim_outdir, name)
sim_data = cv_sim.simulation(self.k, self.kmers)
sim_ids = sim_data['ids']
sim_ids = sim_data['ids']
sim_cls = pd.DataFrame({'sim_id':sim_ids}, dtype = object)
sim_cls['id'] = sim_cls['sim_id'].str.replace('_[0-9]+_[0-9]+_[0-9]+', '', regex=True)
sim_cls = sim_cls.set_index('id').join(cls.set_index('id'))
sim_cls = sim_cls.drop(['sim_id'], axis=1)
sim_cls = sim_cls.reset_index(drop = True)
df = ray.data.read_parquet(sim_data['profile'])
df = zip_X_y(df, sim_cls)
return df

def _cross_validation(self, df_train, df_test, kmers_ds):
print('_cross_validation')

df_train, df_test = df.train_test_split(0.2, shuffle = True)
df_train, df_val = df_train.train_test_split(0.2, shuffle = True)

df_val = self._sim_4_cv(df_val, kmers_ds, '{}_val'.format(self.dataset))
df_test = self._sim_4_cv(df_test, kmers_ds, '{}_test'.format(self.dataset))

df_val = self._sim_4_val(df_val, kmers_ds, '{}_val'.format(self.dataset))

df_train = df_train.drop_columns(['id'])
df_test = df_test.drop_columns(['id'])
df_val = df_val.drop_columns(['id'])
Expand All @@ -198,6 +219,11 @@ def _cross_validation(self, df, kmers_ds):
y_true = []
for row in df_test.iter_rows():
y_true.append(row[self.taxa])

y_true = np.array(y_true)
y_true[np.isnan(y_true)] = -1
y_true = list(y_true)

y_pred = self.predict(df_test.drop_columns([self.taxa]), cv = True)

for file in glob(os.path.join(os.path.dirname(kmers_ds['profile']), '*sim*')):
Expand All @@ -222,6 +248,7 @@ def _fit_model(self, datasets):
'size': self._nb_kmers,
'nb_cls':self._nb_classes,
'model': self.classifier,
'labels_col': self.taxa,
}

# Define trainer / tuner
Expand Down Expand Up @@ -271,7 +298,8 @@ def predict(self, df, threshold = 0.8, cv = False):
)
# Make predictions
predictions = self._predictor.predict(
df,
feature_columns = ['__value__'],
data = df,
batch_size = self.batch_size
)
predictions = self._prob_2_cls(predictions, threshold)
Expand Down Expand Up @@ -326,6 +354,7 @@ def train_func(config):
size = config.get('size')
nb_cls = config.get('nb_cls')
model = config.get('model')
labels_col = config.get('labels_col')

# Model setup
strategy = tf.distribute.MultiWorkerMirroredStrategy()
Expand All @@ -339,13 +368,13 @@ def train_func(config):
def to_tf_dataset(data):
ds = tf.data.Dataset.from_tensors((
tf.convert_to_tensor(list(data['__value__'])),
tf.convert_to_tensor(list(data['labels']))
tf.convert_to_tensor(list(data[labels_col]))
))
return ds

# Fit the model on streaming data
results = []
batch_val = pd.DataFrame(columns = ['__value__', 'labels'])
batch_val = pd.DataFrame(columns = ['__value__', labels_col])
for epoch in val_data.iter_epochs(1):
for batch in epoch.iter_batches():
batch_val = pd.concat([batch_val,batch])
Expand Down
Loading

0 comments on commit 452f60c

Please sign in to comment.