From f98b25e3921c592a5d665825aede8d8b416af7a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Rodr=C3=ADguez=20Flores?= Date: Thu, 4 Jan 2024 12:17:22 +0000 Subject: [PATCH] Improved logic and error handling when selecting a model in the api and improved its tests, solved an issue where the outliers model failed to load some data when reading a json and renamed some variables not in camel_case --- resources/src/ai/outliers.py | 102 ++++++++++++++------------- resources/src/ai/shallow_outliers.py | 6 +- resources/src/ai/traffic.ini | 4 +- resources/src/ai/trainer.py | 25 +++---- resources/src/server/rest.py | 70 ++++++------------ resources/tests/test_rest.py | 70 +++++++++++++----- resources/tests/test_trainer.py | 8 +-- 7 files changed, 151 insertions(+), 134 deletions(-) diff --git a/resources/src/ai/outliers.py b/resources/src/ai/outliers.py index ab73436..7e06305 100644 --- a/resources/src/ai/outliers.py +++ b/resources/src/ai/outliers.py @@ -26,12 +26,16 @@ ''' End of important OS Variables ''' +import sys import shutil import numpy as np import configparser import pandas as pd import tensorflow as tf +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from logger import logger + class Autoencoder: """ Autoencoder class for anomaly detection. @@ -47,14 +51,14 @@ def __init__(self, model_file, model_config_file): Args: model_file (str): Path to model .keras file. model_config_file (str): Path to the model config, including: - METRICS (list): Names of the metrics used by the module. - TIMESTAMP (list): Names of the timestamp columns used by the module. - AVG_LOSS (float): Average loss of the model. - STD_LOSS (float): Standard deviation of the loss of the model. - WINDOW_SIZE (int): Number of entries the model will put together in a 'window'. - NUM_WINDOWS (int): Number of windows the model will put together in each slice. - LOSS_MULT_1 (float): Extra penalty in the loss function for guessing wrong metrics. - LOSS_MULT_2 (float): Extra penalty in the loss function for guessing wrong + metrics (list): Names of the metrics used by the module. + timestamp (list): Names of the timestamp columns used by the module. + avg_loss (float): Average loss of the model. + std_loss (float): Standard deviation of the loss of the model. + window_size (int): Number of entries the model will put together in a 'window'. + num_window (int): Number of windows the model will put together in each slice. + loss_mult_metric (float): Extra penalty in the loss function for guessing wrong metrics. + loss_mult_minute (float): Extra penalty in the loss function for guessing wrong 'minute' field. """ self.check_existence(model_file, model_config_file) @@ -63,36 +67,36 @@ def __init__(self, model_file, model_config_file): model_config = configparser.ConfigParser() model_config.read(model_config_file) columns_section = model_config['Columns'] - self.METRICS = columns_section.get('METRICS', '').split(', ') - self.TIMESTAMP = columns_section.get('TIMESTAMP', '').split(', ') - self.COLUMNS = self.METRICS + self.TIMESTAMP + self.metrics = columns_section.get('METRICS', '').split(', ') + self.timestamp = columns_section.get('TIMESTAMP', '').split(', ') + self.columns = self.metrics + self.timestamp general_section = model_config['General'] - self.AVG_LOSS = float(general_section.get('AVG_LOSS', 0.0)) - self.STD_LOSS = float(general_section.get('STD_LOSS', 0.0)) - self.WINDOW_SIZE = int(general_section.get('WINDOW_SIZE', 0)) - self.NUM_WINDOWS = int(general_section.get('NUM_WINDOWS', 0)) - self.LOSS_MULT_1 = float(general_section.get('LOSS_MULT_1', 0)) - self.LOSS_MULT_2 = float(general_section.get('LOSS_MULT_2', 0)) + self.avg_loss = float(general_section.get('AVG_LOSS', 0.0)) + self.std_loss = float(general_section.get('STD_LOSS', 0.0)) + self.window_size = int(general_section.get('WINDOW_SIZE', 0)) + self.num_window = int(general_section.get('NUM_WINDOWS', 0)) + self.loss_mult_metric = float(general_section.get('LOSS_MULT_METRIC', 0)) + self.loss_mult_minute = float(general_section.get('LOSS_MULT_MINUTE', 0)) except FileNotFoundError: - print(f"Error: Model file '{model_config_file}' not found.") + logger.logger.error(f"Error: Model file '{model_config_file}' not found.") except (OSError, ValueError) as e: - print(f"Error loading model conif: {e}") + logger.logger.error(f"Error loading model conif: {e}") try: self.model = tf.keras.models.load_model( model_file, compile=False ) except FileNotFoundError: - print(f"Error: Model file '{model_file}' not found.") + logger.logger.error(f"Error: Model file '{model_file}' not found.") except (OSError, ValueError) as e: - print(f"Error loading the model: {e}") + logger.logger.error(f"Error loading the model: {e}") def check_existence(self, model_file, model_config_file): """ Check existence of model files and copy them if missing. - This function checks if the provided `model_file` and `model_config_file` exist in their - respective paths. If they don't exist, it renames and copies the corresponding default + This function checks if the provided `model_file` and `model_config_file` exist in their + respective paths. If they don't exist, it renames and copies the corresponding default files from the 'traffic.keras' and 'traffic.ini' files, which are expected to be located in the same directory as the target files. @@ -122,7 +126,7 @@ def rescale(self, data): Returns: (numpy.ndarray): Rescaled data as a numpy array. """ - num_metrics = len(self.METRICS) + num_metrics = len(self.metrics) rescaled=data.copy() rescaled[..., 0:num_metrics]=np.tanh(np.log1p(rescaled[..., 0:num_metrics])/32) rescaled[..., num_metrics]=rescaled[..., num_metrics]/1440 @@ -138,7 +142,7 @@ def descale(self, data): Returns: (numpy.ndarray): Descaled data as a numpy array. """ - num_metrics = len(self.METRICS) + num_metrics = len(self.metrics) descaled = data.copy() descaled = np.where(descaled > 1.0, 1.0, np.where(descaled < -1.0, -1.0, descaled)) descaled[..., 0:num_metrics] = np.expm1(32*np.arctanh(descaled[..., 0:num_metrics])) @@ -154,7 +158,7 @@ def model_loss(self, y_true, y_pred, single_value=True): otherwise, the value is left unchanged. Then, the difference between both tensors is evaluated and a log_cosh loss is applied. - + Args: y_true (tf.Tensor): True target values. y_pred (tf.Tensor): Predicted values. @@ -165,17 +169,17 @@ def model_loss(self, y_true, y_pred, single_value=True): """ y_true = tf.cast(y_true, tf.float16) y_pred = tf.cast(y_pred, tf.float16) - num_metrics = len(self.METRICS) - num_features = len(self.COLUMNS) + num_metrics = len(self.metrics) + num_features = len(self.columns) is_metric = (tf.range(num_features) < num_metrics) is_minute = (tf.range(num_features) == num_metrics) mult_true = tf.where( - is_metric, self.LOSS_MULT_1 * y_true, - tf.where(is_minute, self.LOSS_MULT_2 * y_true, y_true) + is_metric, self.loss_mult_metric * y_true, + tf.where(is_minute, self.loss_mult_minute * y_true, y_true) ) mult_pred = tf.where( - is_metric, self.LOSS_MULT_1 * y_pred, - tf.where(is_minute, self.LOSS_MULT_2 * y_pred, y_pred) + is_metric, self.loss_mult_metric * y_pred, + tf.where(is_minute, self.loss_mult_minute * y_pred, y_pred) ) standard_loss = tf.math.log(tf.cosh((mult_true - mult_pred))) if single_value: @@ -183,7 +187,6 @@ def model_loss(self, y_true, y_pred, single_value=True): return standard_loss def slice(self, data, index = []): - #TODO add a graph to doc to explain this """ Transform a 2D numpy array into a 3D array readable by the model. @@ -195,13 +198,13 @@ def slice(self, data, index = []): (numpy.ndarray): 3D numpy array that can be processed by the model. """ _l = len(data) - Xs = [] - slice_length = self.WINDOW_SIZE * self.NUM_WINDOWS + sliced_data = [] + slice_length = self.window_size * self.num_window if len(index) == 0: - index = np.arange(0, _l-slice_length+1 , self.WINDOW_SIZE) + index = np.arange(0, _l-slice_length+1 , self.window_size) for i in index: - Xs.append(data[i:i+slice_length]) - return np.array(Xs) + sliced_data.append(data[i:i+slice_length]) + return np.array(sliced_data) def flatten(self, data): """ @@ -213,11 +216,11 @@ def flatten(self, data): """ tsr = data.copy() num_slices, slice_len, features = tsr.shape - flattened_len = (num_slices-1)*self.WINDOW_SIZE + slice_len + flattened_len = (num_slices-1)*self.window_size + slice_len flattened_tensor = np.zeros([flattened_len, features]) scaling = np.zeros(flattened_len) for i in range(num_slices): - left_pad = i*self.WINDOW_SIZE + left_pad = i*self.window_size right_pad = left_pad+slice_len flattened_tensor[left_pad:right_pad] += tsr[i] scaling[left_pad:right_pad] +=1 @@ -256,10 +259,10 @@ def compute_json(self, metric, raw_json): (Json): Json with the anomalies and predictions for the data with RedBorder prediction Json format. """ - threshold = self.AVG_LOSS+5*self.STD_LOSS + threshold = self.std_loss+5*self.std_loss data, timestamps = self.input_json(raw_json) predicted, loss = self.calculate_predictions(data) - predicted = pd.DataFrame(predicted, columns=self.COLUMNS) + predicted = pd.DataFrame(predicted, columns=self.columns) predicted['timestamp'] = timestamps anomalies = predicted[loss>threshold] return self.output_json(metric, anomalies, predicted) @@ -271,7 +274,7 @@ def granularity_from_dataframe(self, dataframe): Args: dataframe (pandas.DataFrame): Dataframe with timestamp column - + Returns: time_diffs (pandas.Series): Series with the estimated Granularity of the dataframe. """ @@ -287,22 +290,23 @@ def input_json(self, raw_json): Args: raw_json (Json): druid Json response with the data. - + Returns: data (numpy.ndarray): transformed data. - timestamps (pandas.Series): pandas series with the timestamp of each entry. + timestamps (pandas.Series): pandas series with the timestamp of each entry. """ data = pd.json_normalize(raw_json) data["granularity"] = self.granularity_from_dataframe(data) - metrics_dict = {f"result.{metric}": metric for metric in self.METRICS} + metrics_dict = {f"result.{metric}": metric for metric in self.metrics} data.rename(columns=metrics_dict, inplace=True) timestamps = data['timestamp'].copy() data['timestamp'] = pd.to_datetime(data['timestamp']) data['minute'] = data['timestamp'].dt.minute + 60 * data['timestamp'].dt.hour - data = pd.get_dummies(data, columns=['timestamp'], prefix=['weekday'], drop_first=True) - missing_columns = set(self.COLUMNS) - set(data.columns) + data['weekday']= data['timestamp'].dt.weekday + data = pd.get_dummies(data, columns=['weekday'], prefix=['weekday'], drop_first=True) + missing_columns = set(self.columns) - set(data.columns) data[list(missing_columns)] = 0 - data = data[self.COLUMNS].dropna() + data = data[self.columns].dropna().astype('float') data_array = data.values return data_array, timestamps diff --git a/resources/src/ai/shallow_outliers.py b/resources/src/ai/shallow_outliers.py index 5f6c265..e230f07 100644 --- a/resources/src/ai/shallow_outliers.py +++ b/resources/src/ai/shallow_outliers.py @@ -38,7 +38,7 @@ def predict(self, arr): Args: arr (numpy.ndarray): 1D numpy array with the datapoints to be smoothed. - + Returns: smooth_arr (numpy.ndarray): 1D numpy array with the smoothed data. Same shape as arr. """ @@ -66,7 +66,7 @@ def get_outliers(self, arr, smoothed_arr): an outlier and False otherwise. The method used for outlier detection is an isolation forest, which will look for - the 0.3% most isolated points when taking into account the original value, the + the 0.3% most isolated points when taking into account the original value, the smoothed valued, the diference between them (error) and the squared diference between them. @@ -74,7 +74,7 @@ def get_outliers(self, arr, smoothed_arr): arr (numpy.ndarray): 1D numpy array where the outliers shall be detected. smoothed_arr (numpy.ndarray): 1D numpy array that tries to approximate arr. -Must have the same shape as arr. - + Returns: numpy.ndarray: 1D numpy array with the smoothed data. """ diff --git a/resources/src/ai/traffic.ini b/resources/src/ai/traffic.ini index b65aa73..c198749 100644 --- a/resources/src/ai/traffic.ini +++ b/resources/src/ai/traffic.ini @@ -7,6 +7,6 @@ avg_loss = 0.09741491896919627 std_loss = 0.11098885675977664 window_size = 16 num_windows = 2 -loss_mult_1 = 20.0 -loss_mult_2 = 10.0 +loss_mult_metric = 20.0 +loss_mult_minute = 10.0 diff --git a/resources/src/ai/trainer.py b/resources/src/ai/trainer.py index 45cd02c..643780f 100644 --- a/resources/src/ai/trainer.py +++ b/resources/src/ai/trainer.py @@ -81,16 +81,16 @@ def save_model(self, save_model_file, save_config_file): new_model_config = configparser.ConfigParser() new_model_config.add_section('Columns') columns_section = new_model_config['Columns'] - columns_section['METRICS'] = ', '.join(self.METRICS) - columns_section['TIMESTAMP'] = ', '.join(self.TIMESTAMP) + columns_section['METRICS'] = ', '.join(self.metrics) + columns_section['TIMESTAMP'] = ', '.join(self.timestamp) new_model_config.add_section('General') general_section = new_model_config['General'] - general_section['AVG_LOSS'] = str(self.AVG_LOSS) - general_section['STD_LOSS'] = str(self.STD_LOSS) - general_section['WINDOW_SIZE'] = str(self.WINDOW_SIZE) - general_section['NUM_WINDOWS'] = str(self.NUM_WINDOWS) - general_section['LOSS_MULT_1'] = str(self.LOSS_MULT_1) - general_section['LOSS_MULT_2'] = str(self.LOSS_MULT_2) + general_section['AVG_LOSS'] = str(self.avg_loss) + general_section['STD_LOSS'] = str(self.std_loss) + general_section['WINDOW_SIZE'] = str(self.window_size) + general_section['NUM_WINDOWS'] = str(self.num_window) + general_section['LOSS_MULT_METRIC'] = str(self.loss_mult_metric) + general_section['LOSS_MULT_MINUTE'] = str(self.loss_mult_minute) with open(save_config_file, 'w') as configfile: new_model_config.write(configfile) @@ -101,7 +101,7 @@ def data_augmentation(self, data): Args: data (numpy ndarray): original data to be fed to the model. - + Returns: augmented (numpy ndarray): augmented data. """ @@ -113,8 +113,9 @@ def prepare_data_for_training(self, data, augment = False): Args: data (numpy ndarray): data to be used for training. + augment (boolean): set to True to generate more data for training. - + Returns: prep_data (numpy ndarray): transformed data for its use in the model. """ @@ -145,6 +146,6 @@ def train(self, raw_data, epochs=20, batch_size=32, backup_path=None): prep_data = self.prepare_data_for_training(data) self.model.fit(x=prep_data, y=prep_data, epochs = epochs, batch_size = batch_size, verbose = 0) loss = self.model_loss(prep_data, self.model.predict(prep_data), single_value=False).numpy() - self.AVG_LOSS = 0.9*self.AVG_LOSS + 0.1*loss.mean() - self.STD_LOSS = 0.9*self.AVG_LOSS + 0.1*loss.std() + self.avg_loss = 0.9*self.avg_loss + 0.1*loss.mean() + self.std_loss = 0.9*self.std_loss + 0.1*loss.std() self.save_model(self.model_file ,self.model_config_file) diff --git a/resources/src/server/rest.py b/resources/src/server/rest.py index 0eb8f59..aedda1f 100644 --- a/resources/src/server/rest.py +++ b/resources/src/server/rest.py @@ -104,44 +104,19 @@ def calculate(): model = 'default' else: try: - model = base64.b64decode(model).decode('utf-8') + decoded_model = base64.b64decode(model).decode('utf-8') + model_path = os.path.join(self.ai_path, f"{decoded_model}.keras") + if not os.path.isfile(model_path): + logger.logger.error(f"Model {decoded_model} does not exist") + model = 'default' + else: + model = decoded_model except Exception as e: - logger.logger.error(f"Error decoding model -> {e}") + logger.logger.error(f"Error decoding or checking model -> {e}") model = 'default' - if not os.path.isfile(os.path.join(self.ai_path, f"{model}.keras")): - logger.logger.error(f"Model {model} does not exist") - model = 'default' - - if model != 'default': - logger.logger.info(f"Calculating predictions with keras model {model}.keras") - return self.execute_keras_model(druid_query, config.get("Outliers","metric"), model) - logger.logger.info("Calculating predictions with default model") - return self.execute_default_model(druid_query) - - def execute_default_model(self, druid_query): - """ - Execute a keras deep learning model to detect outliers. - - Args: - druid_query (dict): druid query for the data that we want to analyze. - - Returns: - (JSON): json containing the model's predictions and the outliers detected. - """ - try: - data = druid_client.execute_query(druid_query) - except Exception as e: - error_message = "Error while executing druid query" - logger.logger.error(error_message + " -> " + str(e)) - return self.return_error(error=error_message) - try: - return jsonify(shallow_outliers.ShallowOutliers.execute_prediction_model(data)) - except Exception as e: - error_message = "Error while calculating prediction model" - logger.logger.error(error_message + " -> " + str(e)) - return self.return_error(error=error_message) + return self.execute_model(druid_query, config.get("Outliers","metric"), model) - def execute_keras_model(self, druid_query, metric, model): + def execute_model(self, druid_query, metric, model='default'): """ Execute a keras deep learning model to detect outliers. @@ -154,20 +129,21 @@ def execute_keras_model(self, druid_query, metric, model): (JSON): json containing the model's predictions and the outliers detected. """ - try: + if model != 'default': + logger.logger.info(f"Calculating predictions with keras model {model}.keras") druid_query = query_modifier.modify_aggregations(druid_query) - data = druid_client.execute_query(druid_query) - except Exception as e: - error_message = "Error while executing druid query" - logger.logger.error(error_message + " -> " + str(e)) - return self.return_error(error=error_message) + else: + logger.logger.info("Calculating predictions with default model") + data = druid_client.execute_query(druid_query) try: - return jsonify(outliers.Autoencoder.execute_prediction_model( - data, - metric, - os.path.join(self.ai_path, f"{model}.keras"), - os.path.join(self.ai_path, f"{model}.ini") - )) + if model != 'default': + return jsonify(outliers.Autoencoder.execute_prediction_model( + data, + metric, + os.path.join(self.ai_path, f"{model}.keras"), + os.path.join(self.ai_path, f"{model}.ini") + )) + return jsonify(shallow_outliers.ShallowOutliers.execute_prediction_model(data)) except Exception as e: error_message = "Error while calculating prediction model" logger.logger.error(error_message + " -> " + str(e)) diff --git a/resources/tests/test_rest.py b/resources/tests/test_rest.py index c0a358c..b49ba54 100644 --- a/resources/tests/test_rest.py +++ b/resources/tests/test_rest.py @@ -26,6 +26,12 @@ from src.server.rest import APIServer class TestAPIServer(unittest.TestCase): + output_data = { + "anomalies": [{'expected': 1, 'timestamp': '2023-09-21T09:00:00.000Z'}], + "predicted": [{'forecast': 1, 'timestamp': '2023-09-21T09:00:00.000Z'}], + "status": "success" + } + def setUp(self): self.api_server = APIServer() @@ -36,50 +42,80 @@ def test_calculate_endpoint_missing_query(self): data = {'model':'YXNkZg=='} with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: self.assertEqual(response.status_code, 200) - self.assertEqual(response.get_json()['status'], 'error') + self.assertEqual( + response.get_json(), + {'msg': 'Error decoding query', 'status': 'error'} + ) def test_calculate_endpoint_invalid_query(self): data = {'model':'YXNkZg==', 'query':'YXNkZg=='} with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: self.assertEqual(response.status_code, 200) self.assertEqual( - response.get_json()['status'], - 'error' + response.get_json(), + {'msg': 'Error decoding query', 'status': 'error'} ) @patch('druid.client.DruidClient.execute_query') @patch('ai.shallow_outliers.ShallowOutliers.execute_prediction_model') @patch('os.path.isfile') def test_calculate_endpoint_invalid_model(self, mock_isfile, mock_execute_model, mock_query): - output_data = { - "anomalies": [{'expected': 1, 'timestamp': '2023-09-21T09:00:00.000Z'}], - "predicted": [{'expected': 1, 'timestamp': '2023-09-21T09:00:00.000Z'}], - "status": "success" - } - mock_execute_model.return_value = output_data + mock_execute_model.return_value = self.output_data mock_query.return_value = {} mock_isfile.return_value = False data = {'model':'YXNkZg==', 'query':'eyJhc2RmIjoiYXNkZiJ9'} with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: self.assertEqual(response.status_code, 200) - self.assertEqual(response.get_json(), output_data) + self.assertEqual(response.get_json(), self.output_data) + + @patch('druid.client.DruidClient.execute_query') + @patch('ai.shallow_outliers.ShallowOutliers.execute_prediction_model') + @patch('os.path.isfile') + def test_calculate_endpoint_none_model(self, mock_isfile, mock_execute_model, mock_query): + mock_execute_model.return_value = self.output_data + mock_query.return_value = {} + mock_isfile.return_value = False + data = {'query':'eyJhc2RmIjoiYXNkZiJ9'} + with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json(), self.output_data) + + @patch('druid.client.DruidClient.execute_query') + @patch('ai.shallow_outliers.ShallowOutliers.execute_prediction_model') + @patch('os.path.isfile') + def test_calculate_endpoint_invalid_b64_model(self, mock_isfile, mock_execute_model, mock_query): + mock_execute_model.return_value = self.output_data + mock_query.return_value = {} + mock_isfile.return_value = False + data = {'model':'model', 'query':'eyJhc2RmIjoiYXNkZiJ9'} + with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json(), self.output_data) @patch('druid.client.DruidClient.execute_query') @patch('ai.outliers.Autoencoder.execute_prediction_model') @patch('os.path.isfile') def test_calculate_endpoint_valid_model(self, mock_isfile, mock_execute_model, mock_query): - output_data = { - "anomalies": [{'expected': 1, 'timestamp': '2023-09-21T09:00:00.000Z'}], - "predicted": [{'expected': 1, 'timestamp': '2023-09-21T09:00:00.000Z'}], - "status": "success" - } - mock_execute_model.return_value = output_data + mock_execute_model.return_value = self.output_data mock_query.return_value = {} mock_isfile.return_value = True data = {'model':'YXNkZg==', 'query':'eyJhc2RmIjoiYXNkZiJ9'} with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: self.assertEqual(response.status_code, 200) - self.assertEqual(response.get_json(), output_data) + self.assertEqual(response.get_json(), self.output_data) + + @patch('druid.client.DruidClient.execute_query') + @patch('os.path.isfile') + def test_execute_default_model_invalid_query(self, mock_isfile, mock_query): + mock_query.return_value = {"test":"test"} + mock_isfile.return_value = False + data = {'model':'YXNkZg==', 'query':'eyJhc2RmIjoiYXNkZiJ9'} + with self.api_server.app.test_client().post('/api/v1/outliers', data=data) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.get_json(), + {'msg': 'Error while calculating prediction model', 'status': 'error'} + ) if __name__ == '__main__': unittest.main() diff --git a/resources/tests/test_trainer.py b/resources/tests/test_trainer.py index 228fb53..feeddfb 100644 --- a/resources/tests/test_trainer.py +++ b/resources/tests/test_trainer.py @@ -51,9 +51,9 @@ def tearDown(self): def test_save_model(self): dummy_model = os.path.join(self.test_backup_path, "dummy.keras") dummy_config = os.path.join(self.test_backup_path, "dummy_config.ini") - self.trainer.METRICS = ["metric1", "metric2"] - self.trainer.AVG_LOSS = 0.5 - self.trainer.STD_LOSS = 0.2 + self.trainer.metrics = ["metric1", "metric2"] + self.trainer.avg_loss = 0.5 + self.trainer.std_loss = 0.2 self.trainer.save_model(dummy_model, dummy_config) self.assertTrue(os.path.exists(dummy_model)) self.assertTrue(os.path.exists(dummy_config)) @@ -68,7 +68,7 @@ def test_save_model(self): def test_prepare_data_for_training(self): data = np.zeros((100,100)) prep_data = self.trainer.prepare_data_for_training(data) - self.assertEqual(prep_data.shape[1], self.trainer.NUM_WINDOWS*self.trainer.WINDOW_SIZE) + self.assertEqual(prep_data.shape[1], self.trainer.num_window*self.trainer.window_size) self.assertEqual(prep_data.shape[2], 100) def test_train(self):