Skip to content

Commit

Permalink
Changed s3 initialization to facilitate testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Pablo Rodríguez Flores committed Jun 13, 2024
1 parent 27a41f3 commit 03d5091
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 61 deletions.
44 changes: 22 additions & 22 deletions resources/src/redborder/async_jobs/train_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,40 @@
from datetime import datetime, timezone, timedelta
from resources.src.redborder.s3 import S3
from resources.src.ai.trainer import Trainer
from resources.src.server.rest import config
from resources.src.config.configmanager import ConfigManager
from resources.src.logger.logger import logger
from resources.src.druid.client import DruidClient
from resources.src.druid.query_builder import QueryBuilder
from resources.src.redborder.rb_ai_outliers_filters import RbAIOutliersFilters

class RbOutlierTrainJob:
def __init__(self) -> None:
def __init__(self, config: ConfigManager) -> None:
"""
Initialize the Outliers application.
This class manages the training and running of the Outliers application.
Args:
config (ConfigManager): Configuration settings including the ones for the S3 client.
"""
self.query_builder = None
self.s3_client = None
self.setup_s3(config)
self.main_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
self.druid_client = DruidClient(config.get("Druid", "druid_endpoint"))
self.training_conf = {
"epochs": int(config.get("Outliers", "epochs")),
"batch_size": int(config.get("Outliers", "batch_size")),
"backup_path": config.get("Outliers", "backup_path")
}

def setup_s3(self):
def setup_s3(self, config: ConfigManager):
"""
Set up the S3 client for handling interactions with Amazon S3.
This function initializes the S3 client with the AWS public key, private key, region, bucket, and hostname as specified in the configuration.
Args:
config (ConfigManager): Configuration settings including the ones for the S3 client.
"""
self.s3_client = S3(
config.get("AWS", "s3_public_key"),
Expand Down Expand Up @@ -131,11 +143,9 @@ def train_job(self, model_name):
This function handles the Outliers training process, fetching data, and training the model.
"""
self.setup_s3()
logger.logger.info("Getting model files from S3")
self.download_model_from_s3(model_name)
logger.info("Starting Outliers Train Job")
druid_client = self.initialize_druid_client()
traffic_query = self.load_traffic_query()
self.query_builder = QueryBuilder(
self.get_aggregation_config_path(),
Expand All @@ -146,16 +156,7 @@ def train_job(self, model_name):
os.path.join(self.main_dir, "ai", f"{model_name}.keras"),
os.path.join(self.main_dir, "ai", f"{model_name}.ini"),
)
self.process_model_data(model_name, query,druid_client)

def initialize_druid_client(self):
"""
Initialize the Druid client.
Returns:
DruidClient: The initialized Druid client.
"""
return DruidClient(config.get("Druid", "druid_endpoint"))
self.process_model_data(model_name, query)

def load_traffic_query(self):
"""
Expand Down Expand Up @@ -186,14 +187,13 @@ def get_post_aggregations_config_path(self):
"""
return os.path.join(self.main_dir, "druid", "data", "postAggregations.json")

def process_model_data(self, model_name, query, druid_client):
def process_model_data(self, model_name, query):
"""
Process data and train the model.
Args:
model_name (str): Model identifier.
query (dict): The query to be modified.
druid_client (DruidClient): The Druid client.
This function processes data, modifies the query, and trains the model.
"""
Expand All @@ -207,11 +207,11 @@ def process_model_data(self, model_name, query, druid_client):
traffic_data=[]
for gran in rb_granularities:
temp_query = self.query_builder.modify_granularity(query,gran)
traffic_data.append(druid_client.execute_query(temp_query))
traffic_data.append(self.druid_client.execute_query(temp_query))
self.trainer.train(
traffic_data,
int(config.get("Outliers", "epochs")),
int(config.get("Outliers", "batch_size")),
config.get("Outliers", "backup_path")
self.training_conf["epochs"],
self.training_conf["batch_size"],
self.training_conf["backup_path"]
)
self.upload_model_to_s3(model_name)
42 changes: 22 additions & 20 deletions resources/src/redborder/zookeeper/rb_outliers_zoo_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,21 @@ def __init__(self, config: ConfigManager) -> None:
configurations, including the ZooKeeper client and S3 client.
Args:
config (ConfigManager): Configuration settings including the ones for ZooKeeper client.
config (ConfigManager): Configuration settings including the ones for the ZooKeeper
and S3 clients.
"""
self.config = config
self.is_leader = False
self.is_running = False
self.s3_client = None
self.queue = None
self.election = None
self.leader_watcher = None
self.paths = {}
super().__init__(config)
self.sleep_time = float(config.get("ZooKeeper", "zk_sleep_time"))
self.tick_time = float(config.get("ZooKeeper", "zk_tick_time"))
self.zk_sync_path = config.get("ZooKeeper", "zk_sync_path")
self.setup_s3(config)
self.outlier_job = RbOutlierTrainJob(config)

def _tick(self) -> None:
time.sleep(self.tick_time)
Expand All @@ -65,35 +67,34 @@ def _ensure_paths(self) -> None:
"""
Ensures the required ZooKeeper paths are created.
"""
zk_sync_path = self.config.get("ZooKeeper", "zk_sync_path")

self.paths = {
"leader": os.path.join(zk_sync_path, "leader"),
"queue": os.path.join(zk_sync_path, "models", "queue"),
"taken": os.path.join(zk_sync_path, "models", "taken"),
"train": os.path.join(zk_sync_path, "models", "train"),
"election": os.path.join(zk_sync_path, "election")
"leader": os.path.join(self.zk_sync_path, "leader"),
"queue": os.path.join(self.zk_sync_path, "models", "queue"),
"taken": os.path.join(self.zk_sync_path, "models", "taken"),
"train": os.path.join(self.zk_sync_path, "models", "train"),
"election": os.path.join(self.zk_sync_path, "election")
}
for path in self.paths.values():
self.zookeeper.ensure_path(path)

def setup_s3(self) -> None:
def setup_s3(self, config) -> None:
"""
Sets up the S3 client with the necessary configurations.
"""
self.s3_client = S3(
self.config.get("AWS", "s3_public_key"),
self.config.get("AWS", "s3_private_key"),
self.config.get("AWS", "s3_region"),
self.config.get("AWS", "s3_bucket"),
self.config.get("AWS", "s3_hostname")
config.get("AWS", "s3_public_key"),
config.get("AWS", "s3_private_key"),
config.get("AWS", "s3_region"),
config.get("AWS", "s3_bucket"),
config.get("AWS", "s3_hostname")
)

def sync_nodes(self) -> None:
"""
Synchronizes the nodes and starts the election and task processes.
"""
logger.info("Synchronizing nodes")
self.setup_s3()
self._ensure_paths()
self.is_running = True
self.queue = LockingQueue(self.zookeeper, self.paths["queue"])
Expand All @@ -116,11 +117,13 @@ def cleanup(self, signum: int, frame) -> None:
logger.info(f"Cleanup called with signal {signum}")
self.is_running = False
self.election.cancel()
self.leader_watcher._stopped = True
if self.leader_watcher is not None:
self.leader_watcher._stopped = True
if self.is_leader:
self.is_leader = False
self.zookeeper.set(self.paths["leader"], b"")
self._tick()
self._tick()
super().cleanup(signum, frame)

def _run_tasks(self) -> None:
Expand All @@ -144,7 +147,7 @@ def _leader_tasks(self) -> None:
models = self._get_models()
self._queue_models_on_zoo(models)
next_task_time = time.time() + self.sleep_time
while time.time() < next_task_time:
while time.time() < next_task_time and self.is_running:
for model in models:
is_taken = self._check_node(self.paths["taken"], model)
is_training = self._check_node(self.paths["train"], model)
Expand Down Expand Up @@ -277,8 +280,7 @@ def _process_model_as_follower(self, model: str) -> None:
model (str): The model to process.
"""
try:
outlier_job = RbOutlierTrainJob()
outlier_job.train_job(model)
self.outlier_job.train_job(model)
self._delete_node(self.paths["taken"], model)
self._delete_node(self.paths["train"], model)
logger.info(f"Finished training of model {model}")
Expand Down
2 changes: 1 addition & 1 deletion resources/src/server/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def sync_with_s3_periodically(self):
"""
while True:
logger.logger.info("Sync with S3 Started")
self.sync_models_with_s3(config)
self.sync_models_with_s3()
logger.logger.info("Sync with S3 Finished")
time.sleep(self.s3_sync_interval)

Expand Down
12 changes: 5 additions & 7 deletions resources/tests/test_rb_outliers_zoo_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class TestRbOutliersZooSync(unittest.TestCase):

def setUp(self):
self.zk_sync = RbOutliersZooSync(config)
self.zk_sync.setup_s3()
try:
self.zk_sync.s3_client.s3_client.create_bucket(Bucket=self.zk_sync.s3_client.bucket_name)
except:
Expand Down Expand Up @@ -122,9 +121,8 @@ def test_leader_tasks(self, mock_train_job):
self.assertTrue(len(queue)==3) #Several loops of queueing + one requeued node
self.assertFalse(self.zk_sync._check_node(self.zk_sync.paths["taken"], "model3"))

@patch('resources.src.redborder.zookeeper.rb_outliers_zoo_sync.RbOutlierTrainJob')
def test_follower_tasks(self, mock_train_job):
mock_train_job = MagicMock()
def test_follower_tasks(self):
self.zk_sync.outlier_job = MagicMock()
with self.sync_thread:
self.zk_sync.is_leader = False
self.zk_sync.queue.put(b"model1")
Expand Down Expand Up @@ -163,9 +161,9 @@ def test_get_model_from_queue(self):
self.assertTrue(self.zk_sync._check_node(self.zk_sync.paths["train"], "model1"))
self.assertTrue(self.zk_sync._check_node(self.zk_sync.paths["taken"], "model1"))

@patch('resources.src.redborder.zookeeper.rb_outliers_zoo_sync.RbOutlierTrainJob')
def test_process_model_as_follower(self, mock_train_job):
mock_train_job = MagicMock()

def test_process_model_as_follower(self):
self.zk_sync.outlier_job = MagicMock()
with self.sync_thread:
self.zk_sync._create_node(self.zk_sync.paths["train"], "model1")
self.zk_sync._create_node(self.zk_sync.paths["taken"], "model1")
Expand Down
74 changes: 63 additions & 11 deletions resources/tests/test_train_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,80 @@
# If not, see <https://www.gnu.org/licenses/>.


import os, sys
import os
import sys
import unittest
from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, call
from datetime import datetime, timezone, timedelta

from resources.src.redborder.s3 import S3
from resources.src.config.configmanager import ConfigManager
from resources.src.redborder.async_jobs.train_job import RbOutlierTrainJob
from resources.src.druid.client import DruidClient
from resources.src.ai.trainer import Trainer
from resources.src.logger.logger import logger

config_path = "resources/tests/config_test.ini"
config = ConfigManager(config_path)

class TestRbOutlierTrainJob(unittest.TestCase):

def setUp(self):
self.mock_S3 = patch('resources.src.redborder.async_jobs.train_job.S3').start()
self.mock_config = patch('resources.src.redborder.async_jobs.train_job.config').start()

self.mock_config.get.return_value = 0
self.train_job = RbOutlierTrainJob(config)

def tearDown(self):
patch.stopall()
pass

def test_setup_s3(self):
job = RbOutlierTrainJob()
job.setup_s3()
self.mock_S3.assert_called_with(0,0,0,0,0)
self.assertIsInstance(self.train_job.s3_client, S3)

@patch('shutil.copyfile')
@patch.object(S3, 'download_file')
@patch.object(S3, 'exists', return_value=True)
def test_download_file_exists(self, mock_exists, mock_download_file, mock_copyfile):
    """When the object is present in S3, download_file fetches it and never
    falls back to copying the bundled default file."""
    remote_key = 's3_path'
    destination = 'local_path'
    fallback = 'default_local_path'

    self.train_job.download_file(remote_key, destination, fallback)

    # Existence is checked exactly once, then the remote copy is downloaded.
    mock_exists.assert_called_once_with(remote_key)
    mock_download_file.assert_called_once_with(remote_key, destination)
    mock_copyfile.assert_not_called()

@patch('shutil.copyfile')
@patch.object(S3, 'download_file')
@patch.object(S3, 'exists', return_value=False)
def test_download_file_not_exists(self, mock_exists, mock_download_file, mock_copyfile):
    """When the object is absent from S3, download_file must skip the
    download and instead copy the bundled default file into place."""
    remote_key = 's3_path'
    destination = 'local_path'
    fallback = 'default_local_path'

    self.train_job.download_file(remote_key, destination, fallback)

    # Existence is checked exactly once; on a miss only the local fallback runs.
    mock_exists.assert_called_once_with(remote_key)
    mock_download_file.assert_not_called()
    mock_copyfile.assert_called_once_with(fallback, destination)

@patch.object(S3, 'upload_file')
def test_upload_file(self, mock_upload_file):
    """upload_file must delegate straight to the S3 client with the
    local path and remote key unchanged."""
    source = 'local_path'
    remote_key = 's3_path'

    self.train_job.upload_file(source, remote_key)

    mock_upload_file.assert_called_once_with(source, remote_key)

def test_get_iso_time(self):
    """get_iso_time should return the current UTC time as an ISO-8601
    string with microseconds truncated.

    Bug fix: the original asserted `iso_time()`, calling the already-returned
    value a second time — if get_iso_time returns a string (as the expected
    isoformat comparison implies), that raises TypeError. Compare the value
    itself instead.
    """
    iso_time = self.train_job.get_iso_time()
    expected_time = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
    # NOTE(review): can flake if the wall clock crosses a second boundary
    # between the two now() calls — acceptable for this unit test.
    self.assertEqual(iso_time, expected_time)

def test_subtract_one_day(self):
    """subtract_one_day should shift an ISO-8601 timestamp back exactly
    24 hours, preserving the UTC offset formatting."""
    start = '2023-01-01T00:00:00+00:00'
    shifted = self.train_job.subtract_one_day(start)
    self.assertEqual(shifted, '2022-12-31T00:00:00+00:00')

if __name__ == '__main__':
unittest.main()
unittest.main()

0 comments on commit 03d5091

Please sign in to comment.