Updated train_job to address change to ZooKeeper
Pablo Rodríguez Flores committed Jun 4, 2024
1 parent 73108e3 commit 3386a9b
Showing 1 changed file with 51 additions and 77 deletions.
128 changes: 51 additions & 77 deletions resources/src/redborder/async_jobs/train_job.py
@@ -15,17 +15,15 @@
 # You should have received a copy of the GNU Affero General Public License along with this program.
 # If not, see <https://www.gnu.org/licenses/>.
-
-
-import os, json
-
-from resources.src.rbntp.ntplib import NTPClient
+import os, json, shutil
+from resources.src.redborder.s3 import S3
 from resources.src.ai.trainer import Trainer
+from resources.src.server.rest import config
 from resources.src.logger.logger import logger
+from resources.src.rbntp.ntplib import NTPClient
 from resources.src.druid.client import DruidClient
-from resources.src.server.rest import config
 from resources.src.druid.query_builder import QueryBuilder
-from resources.src.redborder.s3 import S3
 from resources.src.redborder.postgresql import RbOutliersPSQL
+from resources.src.redborder.rb_ai_outliers_filters import RbAIOutliersFilters
 
 class RbOutlierTrainJob:
     def __init__(self) -> None:
@@ -52,54 +50,67 @@ def setup_s3(self):
             config.get("AWS", "s3_hostname")
         )
 
-    def download_latest_model_config_from_s3(self, model_name):
+    def download_file(self, s3_path: str, local_path, default_local_path):
         """
-        Download the latest model configuration file associated with a specific model from Amazon S3 and save it to the local AI directory.
+        Helper function to download a file from S3, falling back to copying a default file locally if necessary.
         Args:
-            model_name (str): The identifier of the model for which the latest model configuration file needs to be downloaded.
+            s3_path (str): The S3 path of the file to download.
+            local_path (str): The local path where the file will be saved.
+            default_local_path (str): The local path of the default file to copy if the primary file does not exist in S3.
         """
-        self.s3_client.download_file(
-            f'rbaioutliers/latest/{model_name}.ini',
-            os.path.join(self.main_dir,"ai", f"{model_name}.ini")
-        )
+        try:
+            if self.s3_client.exists(s3_path):
+                self.s3_client.download_file(s3_path, local_path)
+                logger.logger.info(f"Downloaded {s3_path} to {local_path}")
+            else:
+                shutil.copyfile(default_local_path, local_path)
+                logger.logger.info(f"File {s3_path} not found in S3. Copied default file from {default_local_path} to {local_path}")
+        except Exception as e:
+            logger.logger.error(f"Error processing file from S3 or copying default file: {e}")
 
-    def download_latest_model_from_s3(self, model_name):
+    def download_model_from_s3(self, model_name):
         """
-        Download the latest model file associated with a specific model from Amazon S3 and save it to the local AI directory.
+        Download the latest files associated with a specific model from Amazon S3 and save them to the local AI directory.
         Args:
-            model_name (str): The identifier of the model for which the latest model file needs to be downloaded.
+            model_name (str): The identifier of the model for which the latest file needs to be downloaded.
         """
-        self.s3_client.download_file(
-            f'rbaioutliers/latest/{model_name}.keras',
-            os.path.join(self.main_dir,"ai", f"{model_name}.keras")
-        )
+        extensions = ['keras', 'ini']
+        for ext in extensions:
+            filename = f"{model_name}.{ext}"
+            s3_path = f'rbaioutliers/latest/{filename}'
+            local_path = os.path.join(self.main_dir, "ai", filename)
+            default_local_path = os.path.join(self.main_dir, "ai", f'traffic.{ext}')
+            self.download_file(s3_path, local_path, default_local_path)
 
-    def download_latest_model_filter_from_s3(self, model_name):
+    def upload_file(self, local_path, s3_path):
         """
-        Download the latest model filter file associated with a specific model from Amazon S3 and save it to the local AI directory.
+        Helper function to upload a file to an S3 bucket.
         Args:
-            model_name (str): The identifier of the model for which the latest model file needs to be downloaded.
+            local_path (str): The local path of the file to upload.
+            s3_path (str): The S3 path where the file will be uploaded.
         """
-        self.s3_client.download_file(
-            f'rbaioutliers/latest/{model_name}_filter.json',
-            os.path.join(self.main_dir,"ai", f"{model_name}_filter.json")
-        )
+        try:
+            self.s3_client.upload_file(local_path, s3_path)
+            logger.logger.info(f"Uploaded {local_path} to {s3_path}")
+        except Exception as e:
+            logger.logger.error(f"Error uploading file to S3: {e}")
 
-    def get_model_filter(self, model_name):
+    def upload_model_to_s3(self, model_name):
         """
-        Given a model name, returns its filter as a python dictionary.
+        Upload the latest files associated with a specific model to Amazon S3.
         Args:
-            model_name (str): The identifier of the model.
-        Returns:
-            (dict): Dictionary with the filter of the model.
+            model_name (str): The identifier of the model for which the latest file needs to be uploaded.
         """
-        with open(os.path.join(self.main_dir,"ai", f"{model_name}_filter.json"), 'r') as json_file:
-            return json.load(json_file)
+        extensions = ['keras', 'ini']
+        for ext in extensions:
+            filename = f"{model_name}.{ext}"
+            local_path = os.path.join(self.main_dir, "ai", filename)
+            s3_path = f'rbaioutliers/latest/{filename}'
+            self.upload_file(local_path, s3_path)
 
     def train_job(self, model_name):
         """
@@ -108,20 +119,18 @@ def train_job(self, model_name):
         This function handles the Outliers training process, fetching data, and training the model.
         """
         self.setup_s3()
+        logger.logger.info("Getting model files from S3")
+        self.download_model_from_s3(model_name)
         logger.info("Starting Outliers Train Job")
         redborder_ntp = self.initialize_ntp_client()
         druid_client = self.initialize_druid_client()
-
         manager_time = redborder_ntp.get_ntp_time()
-
         traffic_query = self.load_traffic_query()
-
         self.query_builder = QueryBuilder(
             self.get_aggregation_config_path(),
             self.get_post_aggregations_config_path()
         )
         query = self.query_builder.modify_aggregations(traffic_query)
-
         self.trainer = Trainer(
             os.path.join(self.main_dir, "ai", f"{model_name}.keras"),
             os.path.join(self.main_dir, "ai", f"{model_name}.ini"),
@@ -175,41 +184,6 @@ def get_post_aggregations_config_path(self):
         """
         return os.path.join(self.main_dir, "druid", "data", "postAggregations.json")
 
-    def upload_model_results_back_to_s3(self, model_name):
-        """
-        Upload a model file associated with a specific model to an Amazon S3 bucket.
-        Args:
-            model_name (str): The identifier or name for which the model file needs to be uploaded to S3.
-        """
-        self.s3_client.upload_file(
-            os.path.join(self.main_dir, "ai", f"{model_name}.keras"),
-            f'rbaioutliers/latest/{model_name}.keras',
-            f'rbaioutliers/latest/{model_name}_filter.json'
-        )
-
-    def upload_model_config_results_back_to_s3(self, model_name):
-        """
-        Upload a model configuration file associated with a specific model to an Amazon S3 bucket.
-        Args:
-            model_name (str): The name for which the model configuration file needs to be uploaded to S3.
-        """
-        self.s3_client.upload_file(
-            os.path.join(self.main_dir, "ai", f"{model_name}.ini"),
-            f'rbaioutliers/latest/{model_name}.ini'
-        )
-
-    def upload_results_back_to_s3(self, model_name):
-        """
-        Upload results for all models to an Amazon S3 bucket.
-        This function iterates through a list of models and uploads both the model file and model configuration
-        file for each model to the 'rbaioutliers/latest' path in the S3 bucket.
-        """
-        self.upload_model_results_back_to_s3(model_name)
-        self.upload_model_config_results_back_to_s3(model_name)
-
     def process_model_data(self, model_name, query, redborder_ntp, manager_time, druid_client):
         """
         Process data and train the model.
@@ -226,7 +200,7 @@ def process_model_data(self, model_name, query, redborder_ntp, manager_time, druid_client):
         rb_granularities=["pt1m", "pt2m", "pt5m", "pt15m", "pt30m", "pt1h", "pt2h", "pt8h"]
         start_time = redborder_ntp.time_to_iso8601_time(redborder_ntp.get_substracted_day_time(manager_time))
         end_time = redborder_ntp.time_to_iso8601_time(manager_time)
-        model_filter = self.get_model_filter(model_name)
+        model_filter = RbAIOutliersFilters().get_filtered_data(model_name)
         query = self.query_builder.modify_filter(query, model_filter)
         query = self.query_builder.set_time_origin(query, start_time)
         query = self.query_builder.set_time_interval(query, start_time, end_time)
@@ -240,4 +214,4 @@ def process_model_data(self, model_name, query, redborder_ntp, manager_time, druid_client):
             int(config.get("Outliers", "batch_size")),
             config.get("Outliers", "backup_path")
         )
-        self.upload_results_back_to_s3(model_name)
+        self.upload_model_to_s3(model_name)
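
For comparison, a sketch of the filter lookup this commit swaps out, assembled from the removed and added lines above; treat it as illustrative rather than a verbatim excerpt:

    # Before: the filter was a {model_name}_filter.json file mirrored from S3
    # into the local AI directory and parsed on demand.
    with open(os.path.join(self.main_dir,"ai", f"{model_name}_filter.json"), 'r') as json_file:
        model_filter = json.load(json_file)

    # After: the filter comes from the RbAIOutliersFilters module, so no
    # per-model filter JSON has to be downloaded from or uploaded to S3.
    model_filter = RbAIOutliersFilters().get_filtered_data(model_name)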
