Skip to content

Commit

Permalink
Fixed tests and trainer class.
Browse files Browse the repository at this point in the history
  • Loading branch information
PRodriguezFlores committed Oct 6, 2023
1 parent 327bfed commit 4f5dc57
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
5 changes: 2 additions & 3 deletions resources/src/ai/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import configparser
from datetime import datetime
from src.ai.outliers import Autoencoder

from tensorflow import keras
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

class Trainer(Autoencoder):
Expand Down Expand Up @@ -67,7 +67,6 @@ def save_model(self, save_model_file, save_config_file):
if os.path.exists(save_config_file):
if not os.access(save_config_file, os.W_OK):
raise PermissionError(f"Permission denied: Cannot overwrite '{save_config_file}'")

self.model.save(save_model_file)
new_model_config = configparser.ConfigParser()
new_model_config.add_section('Columns')
Expand Down Expand Up @@ -131,7 +130,7 @@ def train(self, raw_data, epochs=20, batch_size=32, backup_path=None):
"""
if backup_path is None:
backup_path = "./backups/"
date = datetime.now().strftime("%y-%m-%dT%H%M")
date = datetime.now().strftime("%y-%m-%dT%H:%M")
self.save_model(f"{backup_path}{date}.keras",f"{backup_path}{date}.ini")
data = self.input_json(raw_data)[0]
prep_data = self.prepare_data_for_training(data)
Expand Down
48 changes: 31 additions & 17 deletions resources/tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,46 +20,60 @@
import unittest
import os
import sys
import logging
import json
import numpy as np
import configparser

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.ai.trainer import Trainer

class TestTrainer(unittest.TestCase):

def setUp(self):
self.test_backup_path = "./test_backups/"
self.test_backup_path = "./resources/tests/test_backups/"
os.makedirs(self.test_backup_path, exist_ok=True)
self.trainer = Trainer("test_traffic.keras", "test_traffic.ini")
self.trainer.model_config_file = "dummy_config.ini"
self.trainer.model_file = "dummy.keras"
self.trainer = Trainer("./resources/tests/test_model.keras", "./resources/tests/test_model_config.ini")
self.trainer.model_config_file = "./resources/tests/dummy_config.ini"
self.trainer.model_file = "./resources/tests/dummy.keras"

def tearDown(self):
os.rmdir(self.test_backup_path)
os.remove("dummy_config.ini")
os.remove("dummy.keras")
if os.path.exists(self.test_backup_path):
for filename in os.listdir(self.test_backup_path):
if os.path.isfile(os.path.join(self.test_backup_path, filename)):
os.remove(os.path.join(self.test_backup_path, filename))
os.rmdir(self.test_backup_path)
if os.path.isfile(os.path.join(self.trainer.model_config_file)):
os.remove(self.trainer.model_config_file)
if os.path.isfile(self.trainer.model_file):
os.remove(self.trainer.model_file)

def test_save_model(self):
save_model_file = os.path.join(self.test_backup_path, "test_model.keras")
save_config_file = os.path.join(self.test_backup_path, "test_config.ini")
self.trainer.model = "dummy_model"
dummy_model = os.path.join(self.test_backup_path, "dummy.keras")
dummy_config = os.path.join(self.test_backup_path, "dummy_config.ini")
self.trainer.METRICS = ["metric1", "metric2"]
self.trainer.AVG_LOSS = 0.5
self.trainer.STD_LOSS = 0.2
self.trainer.save_model(save_model_file, save_config_file)
self.assertTrue(os.path.exists(save_model_file))
self.assertTrue(os.path.exists(save_config_file))
self.trainer.save_model(dummy_model, dummy_config)
self.assertTrue(os.path.exists(dummy_model))
self.assertTrue(os.path.exists(dummy_config))
model_config = configparser.ConfigParser()
model_config.read(dummy_config)
columns_section = model_config['Columns']
self.assertEqual(["metric1", "metric2"], columns_section.get('METRICS', '').split(', '))
general_section = model_config['General']
self.assertEqual('0.5', general_section.get('AVG_LOSS'))
self.assertEqual('0.2', general_section.get('STD_LOSS'))

def test_prepare_data_for_training(self):
data = np.array([1, 2, 3, 4, 5])
data = np.zeros((100,100))
prep_data = self.trainer.prepare_data_for_training(data)
self.assertEqual(prep_data.shape[1], self.trainer.NUM_WINDOWS*self.trainer.WINDOW_SIZE)
self.assertEqual(prep_data.shape[2], 100)

def test_train(self):
with open("outliers_test_data.json", "r") as file:
with open("./resources/tests/outliers_test_data.json", "r") as file:
raw_data = json.load(file)
self.trainer.train(raw_data, epochs=10, batch_size=32, backup_path=self.test_backup_path)

if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit 4f5dc57

Please sign in to comment.