diff --git a/.github/workflows/test_cornflow_server.yml b/.github/workflows/test_cornflow_server.yml index 6769fd2e..68ef5af8 100644 --- a/.github/workflows/test_cornflow_server.yml +++ b/.github/workflows/test_cornflow_server.yml @@ -52,6 +52,13 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Set up Quarto + uses: quarto-dev/quarto-actions/setup@v2 + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + # To install LaTeX to build PDF book + tinytex: true - name: Copy DAG files run: | cd .. diff --git a/cornflow-dags/DAG/tsp/core/experiment.py b/cornflow-dags/DAG/tsp/core/experiment.py index 45a4e618..277d6300 100644 --- a/cornflow-dags/DAG/tsp/core/experiment.py +++ b/cornflow-dags/DAG/tsp/core/experiment.py @@ -7,7 +7,7 @@ from .solution import Solution import json, tempfile -from quarto import render +import quarto class Experiment(ExperimentCore): @@ -51,8 +51,6 @@ def get_objective(self) -> float: # if solution is empty, we return 0 if len(self.solution.data["route"]) == 0: return 0 - # we get a sorted list of nodes by position - arcs = self.solution.get_used_arcs() # we sum all arc weights in the solution return sum(self.get_used_arc_weights().values()) @@ -93,13 +91,17 @@ def generate_report(self, report_path: str, report_name="report") -> None: if not os.path.exists(path_to_qmd): raise FileNotFoundError(f"Report with path {path_to_qmd} does not exist.") path_to_output = path_without_ext + ".html" + try: + quarto.quarto.find_quarto() + except FileNotFoundError: + raise ModuleNotFoundError("Quarto is not installed.") with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, "experiment.json") # write a json with instance and solution to temp file self.to_json(path) # pass the path to the report to render # it generates a report with path = path_to_output - render(input=path_to_qmd, execute_params=dict(file_name=path)) + quarto.render(input=path_to_qmd, execute_params=dict(file_name=path)) # quarto always writes the report in the .qmd directory. # thus, we need to move it where we want to: os.replace(path_to_output, report_path) diff --git a/cornflow-dags/DAG/tsp/core/instance.py b/cornflow-dags/DAG/tsp/core/instance.py index 2f82110e..ab86ad88 100644 --- a/cornflow-dags/DAG/tsp/core/instance.py +++ b/cornflow-dags/DAG/tsp/core/instance.py @@ -64,7 +64,7 @@ def to_tsplib95(self): def get_arcs(self) -> TupList: return self.data["arcs"] - def get_indexed_arcs(self) -> TupList: + def get_indexed_arcs(self) -> SuperDict: return self.data["arcs"].to_dict( result_col=None, indices=["n1", "n2"], is_list=False ) diff --git a/cornflow-dags/tests/test_dags.py b/cornflow-dags/tests/test_dags.py index 24d82f66..c9528021 100644 --- a/cornflow-dags/tests/test_dags.py +++ b/cornflow-dags/tests/test_dags.py @@ -17,7 +17,11 @@ sys.modules["airflow.secrets.environment_variables"] = mymodule from cornflow_client import SchemaManager, ApplicationCore -from cornflow_client.airflow.dag_utilities import cf_solve +from cornflow_client.airflow.dag_utilities import ( + cf_solve, + cf_report, + AirflowDagException, +) from jsonschema import Draft7Validator from pytups import SuperDict @@ -193,6 +197,16 @@ def test_report(self): pass self.assertRaises(StopIteration, parser.feed, content) + def test_report_error(self): + tests = self.app.test_cases + my_experim = self.app.solvers["cpsat"](self.app.instance(tests[0]["instance"])) + my_experim.solve(dict()) + report_path = "./my_report.html" + my_fun = my_experim.generate_report( + report_path=report_path, report_name="wrong_name" + ) + self.assertRaises(FileNotFoundError, my_fun) + def test_export(self): tests = self.app.test_cases my_file_path = "export.json" @@ -203,6 +217,86 @@ def test_export(self): except FileNotFoundError: pass + @patch("cornflow_client.airflow.dag_utilities.connect_to_cornflow") + def test_complete_report(self, connectCornflow, config=None): + config = config or self.config + config = dict(**config, report=dict(name="report")) + tests = self.app.test_cases + for test_case in tests: + instance_data = test_case.get("instance") + solution_data = test_case.get("solution", None) + if solution_data is None: + solution_data = dict(route=[]) + + mock = Mock() + mock.get_data.return_value = dict( + data=instance_data, solution_data=solution_data + ) + mock.get_results.return_value = dict(config=config, state=1) + mock.create_report.return_value = dict(id=1) + connectCornflow.return_value = mock + dag_run = Mock() + dag_run.conf = dict(exec_id="exec_id") + cf_report(app=self.app, secrets="", dag_run=dag_run) + mock.create_report.assert_called_once() + mock.put_one_report.assert_called_once() + + @patch("cornflow_client.airflow.dag_utilities.connect_to_cornflow") + def test_complete_report_wrong_data(self, connectCornflow, config=None): + config = config or self.config + config = dict(**config, report=dict(name="report")) + tests = self.app.test_cases + for test_case in tests: + instance_data = test_case.get("instance") + solution_data = None + + mock = Mock() + mock.get_data.return_value = dict( + data=instance_data, solution_data=solution_data + ) + mock.get_results.return_value = dict(config=config, state=1) + mock.create_report.return_value = dict(id=1) + connectCornflow.return_value = mock + dag_run = Mock() + dag_run.conf = dict(exec_id="exec_id") + my_report = lambda: cf_report(app=self.app, secrets="", dag_run=dag_run) + self.assertRaises(AirflowDagException, my_report) + mock.create_report.assert_called_once() + mock.raw.put_api_for_id.assert_called_once() + args, kwargs = mock.raw.put_api_for_id.call_args + self.assertEqual(kwargs["data"], {"state": -1}) + + @patch("quarto.render") + @patch("cornflow_client.airflow.dag_utilities.connect_to_cornflow") + def test_complete_report_no_quarto(self, connectCornflow, render, config=None): + config = config or self.config + config = dict(**config, report=dict(name="report")) + tests = self.app.test_cases + render.side_effect = ModuleNotFoundError() + render.return_value = dict(a=1) + for test_case in tests: + instance_data = test_case.get("instance") + solution_data = test_case.get("solution", None) + if solution_data is None: + solution_data = dict(route=[]) + + mock = Mock() + mock.get_data.return_value = dict( + data=instance_data, + solution_data=solution_data, + ) + mock.get_results.return_value = dict(config=config, state=1) + mock.create_report.return_value = dict(id=1) + connectCornflow.return_value = mock + dag_run = Mock() + dag_run.conf = dict(exec_id="exec_id") + my_report = lambda: cf_report(app=self.app, secrets="", dag_run=dag_run) + self.assertRaises(AirflowDagException, my_report) + mock.create_report.assert_called_once() + mock.raw.put_api_for_id.assert_called_once() + args, kwargs = mock.raw.put_api_for_id.call_args + self.assertEqual(kwargs["data"], {"state": -10}) + class Vrp(BaseDAGTests.SolvingTests): def setUp(self): diff --git a/cornflow-server/cornflow/config.py b/cornflow-server/cornflow/config.py index 627b19c3..102e005f 100644 --- a/cornflow-server/cornflow/config.py +++ b/cornflow-server/cornflow/config.py @@ -95,7 +95,6 @@ class Development(DefaultConfig): """ """ ENV = "development" - UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "/usr/src/app/static") class Testing(DefaultConfig): @@ -115,10 +114,6 @@ class Testing(DefaultConfig): AIRFLOW_PWD = os.getenv("AIRFLOW_PWD", "admin") OPEN_DEPLOYMENT = 1 LOG_LEVEL = int(os.getenv("LOG_LEVEL", 10)) - UPLOAD_FOLDER = os.getenv( - "UPLOAD_FOLDER", - os.path.abspath(os.path.join(os.path.dirname(__file__), "./static")), - ) class Production(DefaultConfig): diff --git a/cornflow-server/cornflow/endpoints/__init__.py b/cornflow-server/cornflow/endpoints/__init__.py index 6028d7d8..341ed81a 100644 --- a/cornflow-server/cornflow/endpoints/__init__.py +++ b/cornflow-server/cornflow/endpoints/__init__.py @@ -38,7 +38,7 @@ ExecutionRelaunchEndpoint, ) -from .reports import ReportEndpoint, ReportDetailsEndpoint +from .reports import ReportEndpoint, ReportDetailsEndpoint, ReportDetailsEditEndpoint from .health import HealthEndpoint from .instance import ( InstanceEndpoint, @@ -224,6 +224,11 @@ urls="/report//", endpoint="report-detail", ), + dict( + resource=ReportDetailsEditEndpoint, + urls="/report//edit/", + endpoint="report-detail-edit", + ), dict(resource=ReportEndpoint, urls="/report/", endpoint="report"), ] diff --git a/cornflow-server/cornflow/endpoints/reports.py b/cornflow-server/cornflow/endpoints/reports.py index b2d46421..66b8f7f4 100644 --- a/cornflow-server/cornflow/endpoints/reports.py +++ b/cornflow-server/cornflow/endpoints/reports.py @@ -7,6 +7,7 @@ from flask import current_app, request, send_from_directory from flask_apispec import marshal_with, use_kwargs, doc from werkzeug.utils import secure_filename +import uuid from cornflow.endpoints.meta_resource import BaseMetaResource from cornflow.models import ExecutionModel, ReportModel @@ -75,49 +76,24 @@ def post(self, **kwargs): if execution is None: raise ObjectDoesNotExist("The execution does not exist") - if "file" not in request.files: - return {"message": "No file part"}, 400 - - file = request.files["file"] - filename = secure_filename(file.filename) - filename_extension = filename.split(".")[-1] - - if filename_extension not in current_app.config["ALLOWED_EXTENSIONS"]: - return { - "message": f"Invalid file extension. " - f"Valid extensions are: {current_app.config['ALLOWED_EXTENSIONS']}" - }, 400 - - my_directory = f"{current_app.config['UPLOAD_FOLDER']}/{execution.id}" - - # we create a directory for the execution - if not os.path.exists(my_directory): - current_app.logger.info(f"Creating directory {my_directory}") - os.mkdir(my_directory) + # we're creating an empty report. + # which is possible + report = ReportModel(get_report_info(kwargs, execution, None)) - report_name = f"{secure_filename(kwargs['name'])}.{filename_extension}" - - save_path = os.path.normpath(os.path.join(my_directory, report_name)) + report.save() + return report, 201 - if "static" not in save_path or ".." in save_path: - raise NoPermission("Invalid file name") + file = request.files["file"] + report_name = new_file_name(file) - report = ReportModel( - { - "name": kwargs["name"], - "file_url": save_path, - "execution_id": kwargs["execution_id"], - "user_id": execution.user_id, - "description": kwargs.get("description", ""), - } - ) + report = ReportModel(get_report_info(kwargs, execution, report_name)) report.save() + # We try to save the file, if an error is raised then we delete the record on the database try: - # We try to save the file, if an error is raised then we delete the record on the database - file.save(save_path) + write_file(file, execution.id, report_name) return report, 201 except Exception as error: @@ -137,39 +113,13 @@ def __init__(self): self.foreign_data = {"execution_id": ExecutionModel} -class ReportDetailsEndpoint(ReportDetailsEndpointBase): - @doc(description="Get details of a report", tags=["Reports"], inherit=False) - @authenticate(auth_class=Auth()) - @marshal_with(ReportSchema) - @BaseMetaResource.get_data_or_404 - def get(self, idx): - """ - API method to get a report created by the user and its related info. - It requires authentication to be passed in the form of a token that has to be linked to - an existing session (login) made by a user. - - :param str idx: ID of the report. - :return: A dictionary with a message (error if authentication failed, or the report does not exist or - the data of the report) and an integer with the HTTP status code. - :rtype: Tuple(dict, integer) - """ - current_app.logger.info(f"User {self.get_user()} gets details of report {idx}") - report = self.get_detail(user=self.get_user(), idx=idx) - if report is None: - raise ObjectDoesNotExist - - directory, file = report.file_url.split(report.name) - file = f"{report.name}{file}" - directory = directory[:-1] +class ReportDetailsEditEndpoint(ReportDetailsEndpointBase): - response = send_from_directory(directory, file) - response.headers["File-Description"] = report.description - response.headers["File-Name"] = report.name - return response + ROLES_WITH_ACCESS = [SERVICE_ROLE] @doc(description="Edit a report", tags=["Reports"], inherit=False) @authenticate(auth_class=Auth()) - @use_kwargs(ReportEditRequest, location="json") + @use_kwargs(ReportEditRequest, location="form") def put(self, idx, **data): """ Edit an existing report @@ -179,34 +129,82 @@ def put(self, idx, **data): a message) and an integer with the HTTP status code. :rtype: Tuple(dict, integer) """ + # TODO: forbid non-service users from running put current_app.logger.info(f"User {self.get_user()} edits report {idx}") - report = self.get_detail(user=self.get_user(), idx=idx) + report = self.get_detail(idx=idx) + + if "file" not in request.files: + # we're creating an empty report. + # which is possible + report.update(data) + report.save() + return {"message": "Updated correctly"}, 200 + + # there's two cases, + # (1) the report already has a file + # (2) the report doesn't yet have a file + file = request.files["file"] + report_name = new_file_name(file) + old_name = report.file_url + # we update the report with the new content, including the new name + report.update(dict(**data, file_url=report_name)) + # We try to save the file, if an error is raised then we delete the record on the database try: - if report.name != data["name"]: - directory, file = report.file_url.split(report.name) + write_file(file, report.execution_id, report_name) + report.save() - new_location = ( - f"{os.path.join(directory, secure_filename(data['name']))}{file}" - ) - old_location = report.file_url + except Exception as error: + # we do not save the report + current_app.logger.error(error) + raise FileError(error=str(error)) - current_app.logger.debug(f"Old location: {old_location}") - current_app.logger.debug(f"New location: {new_location}") + # if it saves correctly, we delete the old file, if exists + # if unsuccessful, we still return 201 but log the error + if old_name is not None: + try: + os.remove(get_report_path(report)) + except OSError as error: + current_app.logger.error(error) + return {"message": "Updated correctly"}, 200 - os.rename(old_location, new_location) - data["file_url"] = new_location - except Exception as error: - current_app.logger.error(error) - return {"error": "Error moving file"}, 400 +class ReportDetailsEndpoint(ReportDetailsEndpointBase): + @doc(description="Get details of a report", tags=["Reports"], inherit=False) + @authenticate(auth_class=Auth()) + @marshal_with(ReportSchema) + @BaseMetaResource.get_data_or_404 + def get(self, idx): + """ + API method to get a report created by the user and its related info. + It requires authentication to be passed in the form of a token that has to be linked to + an existing session (login) made by a user. - report.update(data) + :param str idx: ID of the report. + :return: A dictionary with a message (error if authentication failed, or the report does not exist or + the data of the report) and an integer with the HTTP status code. + :rtype: Tuple(dict, integer) + """ + # TODO: are we able to download the name in the database and not as part of the file? + current_app.logger.info(f"User {self.get_user()} gets details of report {idx}") + report = self.get_detail(user=self.get_user(), idx=idx) - report.save() + if report is None: + print("error") + raise ObjectDoesNotExist - return {"message": "Updated correctly"}, 200 + # if there's no file, we do not return it: + if report.file_url is None: + return report, 200 + + my_dir = get_report_dir(report.execution_id) + print(my_dir) + print(report.file_url) + response = send_from_directory(my_dir, report.file_url) + response.headers["File-Description"] = report.description + response.headers["File-Name"] = report.file_url + return response @doc(description="Delete a report", tags=["Reports"], inherit=False) @authenticate(auth_class=Auth()) @@ -229,6 +227,60 @@ def delete(self, idx): raise ObjectDoesNotExist # delete file - os.remove(os.path.join(report.file_url)) + os.remove(get_report_path(report)) return self.delete_detail(user_id=self.get_user_id(), idx=idx) + + +def get_report_dir(execution_id): + return f"{current_app.config['UPLOAD_FOLDER']}/{execution_id}" + + +def get_report_path(report): + try: + return f"{get_report_dir(report['execution_id'])}/{report['file_url']}" + except: + return f"{get_report_dir(report.execution_id)}/{report.file_url}" + + +def new_file_name(file): + + filename = secure_filename(file.filename) + filename_extension = filename.split(".")[-1] + + if filename_extension not in current_app.config["ALLOWED_EXTENSIONS"]: + return { + "message": f"Invalid file extension. " + f"Valid extensions are: {current_app.config['ALLOWED_EXTENSIONS']}" + }, 400 + + report_name = f"{uuid.uuid4().hex}.{filename_extension}" + + return report_name + + +def write_file(file, execution_id, file_name): + my_directory = get_report_dir(execution_id) + + # we create a directory for the execution + if not os.path.exists(my_directory): + current_app.logger.info(f"Creating directory {my_directory}") + os.mkdir(my_directory) + + save_path = os.path.normpath(os.path.join(my_directory, file_name)) + + if "static" not in save_path or ".." in save_path: + raise NoPermission("Invalid file name") + file.save(save_path) + + +def get_report_info(data, execution, file_url=None): + return { + "name": data["name"], + "file_url": file_url, + "execution_id": execution.id, + "user_id": execution.user_id, + "description": data.get("description", ""), + "state": data.get("state"), + "state_message": data.get("state_message"), + }