diff --git a/resources/src/redborder/async_jobs/train_job.py b/resources/src/redborder/async_jobs/train_job.py
index 4f232e6..7e34d43 100644
--- a/resources/src/redborder/async_jobs/train_job.py
+++ b/resources/src/redborder/async_jobs/train_job.py
@@ -15,17 +15,15 @@
 # You should have received a copy of the GNU Affero General Public License along with this program.
 # If not, see <https://www.gnu.org/licenses/>.
 
-
-import os, json
-
-from resources.src.rbntp.ntplib import NTPClient
+import os, json, shutil
+from resources.src.redborder.s3 import S3
 from resources.src.ai.trainer import Trainer
+from resources.src.server.rest import config
 from resources.src.logger.logger import logger
+from resources.src.rbntp.ntplib import NTPClient
 from resources.src.druid.client import DruidClient
-from resources.src.server.rest import config
 from resources.src.druid.query_builder import QueryBuilder
-from resources.src.redborder.s3 import S3
-from resources.src.redborder.postgresql import RbOutliersPSQL
+from resources.src.redborder.rb_ai_outliers_filters import RbAIOutliersFilters
 
 class RbOutlierTrainJob:
     def __init__(self) -> None:
@@ -52,54 +50,67 @@ def setup_s3(self):
             config.get("AWS", "s3_hostname")
         )
 
-    def download_latest_model_config_from_s3(self, model_name):
+    def download_file(self, s3_path: str, local_path, default_local_path):
         """
-        Download the latest model configuration file associated with a specific model from Amazon S3 and save it to the local AI directory.
+        Helper function to download a file from S3, falling back to copying a default file locally if necessary.
 
         Args:
-            model_name (str): The identifier of the model for which the latest model configuration file needs to be downloaded.
+            s3_path (str): The S3 path of the file to download.
+            local_path (str): The local path where the file will be saved.
+            default_local_path (str): The local path of the default file to copy if the primary file does not exist in S3.
         """
-        self.s3_client.download_file(
-            f'rbaioutliers/latest/{model_name}.ini',
-            os.path.join(self.main_dir,"ai", f"{model_name}.ini")
-        )
+        try:
+            if self.s3_client.exists(s3_path):
+                self.s3_client.download_file(s3_path, local_path)
+                logger.logger.info(f"Downloaded {s3_path} to {local_path}")
+            else:
+                shutil.copyfile(default_local_path, local_path)
+                logger.logger.info(f"File {s3_path} not found in S3. Copied default file from {default_local_path} to {local_path}")
+        except Exception as e:
+            logger.logger.error(f"Error processing file from S3 or copying default file: {e}")
 
-    def download_latest_model_from_s3(self, model_name):
+    def download_model_from_s3(self, model_name):
         """
-        Download the latest model file associated with a specific model from Amazon S3 and save it to the local AI directory.
+        Download the latest files associated with a specific model from Amazon S3 and save them to the local AI directory.
 
         Args:
-            model_name (str): The identifier of the model for which the latest model file needs to be downloaded.
+            model_name (str): The identifier of the model for which the latest file needs to be downloaded.
         """
-        self.s3_client.download_file(
-            f'rbaioutliers/latest/{model_name}.keras',
-            os.path.join(self.main_dir,"ai", f"{model_name}.keras")
-        )
+        extensions = ['keras', 'ini']
+        for ext in extensions:
+            filename = f"{model_name}.{ext}"
+            s3_path = f'rbaioutliers/latest/{filename}'
+            local_path = os.path.join(self.main_dir, "ai", filename)
+            default_local_path = os.path.join(self.main_dir, "ai", f'traffic.{ext}')
+            self.download_file(s3_path, local_path, default_local_path)
 
-    def download_latest_model_filter_from_s3(self, model_name):
+    def upload_file(self, local_path, s3_path):
         """
-        Download the latest model filter file associated with a specific model from Amazon S3 and save it to the local AI directory.
+        Helper function to upload a file to an S3 bucket.
 
         Args:
-            model_name (str): The identifier of the model for which the latest model file needs to be downloaded.
+            local_path (str): The local path of the file to upload.
+            s3_path (str): The S3 path where the file will be uploaded.
         """
-        self.s3_client.download_file(
-            f'rbaioutliers/latest/{model_name}_filter.json',
-            os.path.join(self.main_dir,"ai", f"{model_name}_filter.json")
-        )
+        try:
+            self.s3_client.upload_file(local_path, s3_path)
+            logger.logger.info(f"Uploaded {local_path} to {s3_path}")
+        except Exception as e:
+            logger.logger.error(f"Error uploading file to S3: {e}")
 
-    def get_model_filter(self, model_name):
+    def upload_model_to_s3(self, model_name):
         """
-        Given a model name, returns its filter as a python dictionary.
+        Upload the latest files associated with a specific model to Amazon S3.
 
         Args:
-            model_name (str): The identifier of the model.
-
-        Returns:
-            (dict): Dictionary with the filter of the model.
+            model_name (str): The identifier of the model for which the latest file needs to be uploaded.
         """
-        with open(os.path.join(self.main_dir,"ai", f"{model_name}_filter.json"), 'r') as json_file:
-            return json.load(json_file)
+        extensions = ['keras', 'ini']
+        for ext in extensions:
+            filename = f"{model_name}.{ext}"
+            local_path = os.path.join(self.main_dir, "ai", filename)
+            s3_path = f'rbaioutliers/latest/{filename}'
+            self.upload_file(local_path, s3_path)
 
     def train_job(self, model_name):
         """
@@ -108,20 +119,18 @@ def train_job(self, model_name):
         This function handles the Outliers training process, fetching data, and training the model.
         """
         self.setup_s3()
+        logger.logger.info("Getting model files from S3")
+        self.download_model_from_s3(model_name)
         logger.info("Starting Outliers Train Job")
         redborder_ntp = self.initialize_ntp_client()
         druid_client = self.initialize_druid_client()
-
         manager_time = redborder_ntp.get_ntp_time()
-
         traffic_query = self.load_traffic_query()
-
         self.query_builder = QueryBuilder(
             self.get_aggregation_config_path(),
             self.get_post_aggregations_config_path()
         )
         query = self.query_builder.modify_aggregations(traffic_query)
-
         self.trainer = Trainer(
             os.path.join(self.main_dir, "ai", f"{model_name}.keras"),
             os.path.join(self.main_dir, "ai", f"{model_name}.ini"),
@@ -175,41 +184,6 @@ def get_post_aggregations_config_path(self):
         """
         return os.path.join(self.main_dir, "druid", "data", "postAggregations.json")
 
-    def upload_model_results_back_to_s3(self, model_name):
-        """
-        Upload a model file associated with a specific model to an Amazon S3 bucket.
-
-        Args:
-            model_name (str): The identifier or name for which the model file needs to be uploaded to S3.
-        """
-        self.s3_client.upload_file(
-            os.path.join(self.main_dir, "ai", f"{model_name}.keras"),
-            f'rbaioutliers/latest/{model_name}.keras',
-            f'rbaioutliers/latest/{model_name}_filter.json'
-        )
-
-    def upload_model_config_results_back_to_s3(self, model_name):
-        """
-        Upload a model configuration file associated with a specific model to an Amazon S3 bucket.
-
-        Args:
-            model_name (str): The name for which the model configuration file needs to be uploaded to S3.
-        """
-        self.s3_client.upload_file(
-            os.path.join(self.main_dir, "ai", f"{model_name}.ini"),
-            f'rbaioutliers/latest/{model_name}.ini'
-        )
-
-    def upload_results_back_to_s3(self, model_name):
-        """
-        Upload results for all models to an Amazon S3 bucket.
-
-        This function iterates through a list of models and uploads both the model file and model configuration
-        file for each model to the 'rbaioutliers/latest' path in the S3 bucket.
-        """
-        self.upload_model_results_back_to_s3(model_name)
-        self.upload_model_config_results_back_to_s3(model_name)
-
     def process_model_data(self, model_name, query, redborder_ntp, manager_time, druid_client):
         """
         Process data and train the model.
@@ -226,7 +200,7 @@ def process_model_data(self, model_name, query, redborder_ntp, manager_time, dru
         rb_granularities=["pt1m", "pt2m", "pt5m", "pt15m", "pt30m", "pt1h", "pt2h", "pt8h"]
         start_time = redborder_ntp.time_to_iso8601_time(redborder_ntp.get_substracted_day_time(manager_time))
         end_time = redborder_ntp.time_to_iso8601_time(manager_time)
-        model_filter = self.get_model_filter(model_name)
+        model_filter = RbAIOutliersFilters().get_filtered_data(model_name)
         query = self.query_builder.modify_filter(query, model_filter)
         query = self.query_builder.set_time_origin(query, start_time)
         query = self.query_builder.set_time_interval(query, start_time, end_time)
@@ -240,4 +214,4 @@ def process_model_data(self, model_name, query, redborder_ntp, manager_time, dru
             int(config.get("Outliers", "batch_size")),
             config.get("Outliers", "backup_path")
         )
-        self.upload_results_back_to_s3(model_name)
\ No newline at end of file
+        self.upload_model_to_s3(model_name)