diff --git a/.github/workflows/test_cornflow_server.yml b/.github/workflows/test_cornflow_server.yml index 4e8fc4576..6769fd2e2 100644 --- a/.github/workflows/test_cornflow_server.yml +++ b/.github/workflows/test_cornflow_server.yml @@ -49,7 +49,7 @@ jobs: steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Copy DAG files diff --git a/cornflow-server/Dockerfile b/cornflow-server/Dockerfile index 851f6e307..bbf9bc501 100644 --- a/cornflow-server/Dockerfile +++ b/cornflow-server/Dockerfile @@ -36,6 +36,9 @@ RUN pip install "cornflow==${CORNFLOW_VERSION}" # create folder for logs RUN mkdir -p /usr/src/app/log +# create folder for object storage +RUN mkdir -p /usr/src/app/static + # create folder for custom ssh keys RUN mkdir /usr/src/app/.ssh diff --git a/cornflow-server/MANIFEST.in b/cornflow-server/MANIFEST.in index 4f0473188..3816e709f 100644 --- a/cornflow-server/MANIFEST.in +++ b/cornflow-server/MANIFEST.in @@ -3,4 +3,5 @@ include MANIFEST.in include README.rst include setup.py include cornflow/migrations/* -include cornflow/migrations/versions/* \ No newline at end of file +include cornflow/migrations/versions/* +include cornflow/static/* \ No newline at end of file diff --git a/cornflow-server/cornflow/app.py b/cornflow-server/cornflow/app.py index d8125a931..2301af0c8 100644 --- a/cornflow-server/cornflow/app.py +++ b/cornflow-server/cornflow/app.py @@ -46,6 +46,9 @@ def create_app(env_name="development", dataconn=None): :return: the application that is going to be running :class:`Flask` :rtype: :class:`Flask` """ + if os.getenv("FLASK_ENV", None) is not None: + env_name = os.getenv("FLASK_ENV") + dictConfig(log_config(app_config[env_name].LOG_LEVEL)) app = Flask(__name__) diff --git a/cornflow-server/cornflow/config.py b/cornflow-server/cornflow/config.py index 4dba2613e..627b19c39 100644 --- a/cornflow-server/cornflow/config.py +++ b/cornflow-server/cornflow/config.py @@ -26,7 +26,7 @@ class DefaultConfig(object): FILE_BACKEND = os.getenv("FILE_BACKEND", "local") UPLOAD_FOLDER = os.getenv( "UPLOAD_FOLDER", - os.path.abspath(os.path.join(os.path.dirname(__file__), "../static")), + os.path.abspath(os.path.join(os.path.dirname(__file__), "./static")), ) ALLOWED_EXTENSIONS = os.getenv("ALLOWED_EXTENSIONS", ["pdf", "html"]) @@ -95,6 +95,7 @@ class Development(DefaultConfig): """ """ ENV = "development" + UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "/usr/src/app/static") class Testing(DefaultConfig): @@ -114,6 +115,10 @@ class Testing(DefaultConfig): AIRFLOW_PWD = os.getenv("AIRFLOW_PWD", "admin") OPEN_DEPLOYMENT = 1 LOG_LEVEL = int(os.getenv("LOG_LEVEL", 10)) + UPLOAD_FOLDER = os.getenv( + "UPLOAD_FOLDER", + os.path.abspath(os.path.join(os.path.dirname(__file__), "./static")), + ) class Production(DefaultConfig): @@ -126,6 +131,7 @@ class Production(DefaultConfig): # needs to be on to avoid getting only 500 codes: # and https://medium.com/@johanesriandy/flask-error-handler-not-working-on-production-mode-3adca4c7385c PROPAGATE_EXCEPTIONS = True + UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "/usr/src/app/static") app_config = {"development": Development, "testing": Testing, "production": Production} diff --git a/cornflow-server/cornflow/endpoints/reports.py b/cornflow-server/cornflow/endpoints/reports.py index e6c034ccf..b2d46421d 100644 --- a/cornflow-server/cornflow/endpoints/reports.py +++ b/cornflow-server/cornflow/endpoints/reports.py @@ -70,7 +70,8 @@ def post(self, **kwargs): the reference_id for the newly created report if successful) and a integer with the HTTP status code :rtype: Tuple(dict, integer) """ - execution = ExecutionModel.get_one_object(id=kwargs["execution_id"]) + + execution = ExecutionModel.get_one_object(idx=kwargs["execution_id"]) if execution is None: raise ObjectDoesNotExist("The execution does not exist") @@ -122,7 +123,7 @@ def post(self, **kwargs): except Exception as error: report.delete() current_app.logger.error(error) - raise FileError + raise FileError(error=str(error)) class ReportDetailsEndpointBase(BaseMetaResource): @@ -153,7 +154,7 @@ def get(self, idx): :rtype: Tuple(dict, integer) """ current_app.logger.info(f"User {self.get_user()} gets details of report {idx}") - report = self.get_detail(user_id=self.get_user_id(), idx=idx) + report = self.get_detail(user=self.get_user(), idx=idx) if report is None: raise ObjectDoesNotExist @@ -161,7 +162,10 @@ def get(self, idx): file = f"{report.name}{file}" directory = directory[:-1] - return send_from_directory(directory, file) + response = send_from_directory(directory, file) + response.headers["File-Description"] = report.description + response.headers["File-Name"] = report.name + return response @doc(description="Edit a report", tags=["Reports"], inherit=False) @authenticate(auth_class=Auth()) @@ -176,7 +180,33 @@ def put(self, idx, **data): :rtype: Tuple(dict, integer) """ current_app.logger.info(f"User {self.get_user()} edits report {idx}") - return self.put_detail(data, user=self.get_user(), idx=idx) + + report = self.get_detail(user=self.get_user(), idx=idx) + + try: + if report.name != data["name"]: + directory, file = report.file_url.split(report.name) + + new_location = ( + f"{os.path.join(directory, secure_filename(data['name']))}{file}" + ) + old_location = report.file_url + + current_app.logger.debug(f"Old location: {old_location}") + current_app.logger.debug(f"New location: {new_location}") + + os.rename(old_location, new_location) + data["file_url"] = new_location + + except Exception as error: + current_app.logger.error(error) + return {"error": "Error moving file"}, 400 + + report.update(data) + + report.save() + + return {"message": "Updated correctly"}, 200 @doc(description="Delete a report", tags=["Reports"], inherit=False) @authenticate(auth_class=Auth()) diff --git a/cornflow-server/cornflow/models/reports.py b/cornflow-server/cornflow/models/reports.py index 73ef15c8f..d1eaa2188 100644 --- a/cornflow-server/cornflow/models/reports.py +++ b/cornflow-server/cornflow/models/reports.py @@ -53,6 +53,10 @@ def user_id(self): """ return db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) + @declared_attr + def user(self): + return db.relationship("UserModel") + def __init__(self, data: dict): super().__init__() self.user_id = data.get("user_id") diff --git a/cornflow-server/cornflow/shared/exceptions.py b/cornflow-server/cornflow/shared/exceptions.py index c3f4da63a..605a8e753 100644 --- a/cornflow-server/cornflow/shared/exceptions.py +++ b/cornflow-server/cornflow/shared/exceptions.py @@ -21,7 +21,10 @@ class InvalidUsage(Exception): def __init__(self, error=None, status_code=None, payload=None, log_txt=None): Exception.__init__(self, error) if error is not None: - self.error = error + if isinstance(error, Exception): + self.error = str(error) + else: + self.error = error if status_code is not None: self.status_code = status_code self.payload = payload diff --git a/cornflow-server/cornflow/tests/unit/test_reports.py b/cornflow-server/cornflow/tests/unit/test_reports.py index a142f93ce..abab00834 100644 --- a/cornflow-server/cornflow/tests/unit/test_reports.py +++ b/cornflow-server/cornflow/tests/unit/test_reports.py @@ -158,7 +158,7 @@ def test_new_report_no_execution(self): ), ) - self.assertEqual(400, response.status_code) + self.assertEqual(404, response.status_code) self.assertTrue("error" in response.json) def test_get_no_reports(self): @@ -187,6 +187,27 @@ def test_get_one_report(self): self.assertEqual(200, response.status_code) self.assertEqual(content, file) + def test_modify_report(self): + item = self.test_new_report_html() + + payload = {"name": "new_name", "description": "some_description"} + + response = self.client.put( + f"{self.url}{item['id']}/", + headers=self.get_header_with_auth(self.token), + json=payload, + ) + + self.assertEqual(response.status_code, 200) + + response = self.client.get( + f"{self.url}{item['id']}/", headers=self.get_header_with_auth(self.token) + ) + + self.assertEqual(200, response.status_code) + self.assertEqual("new_name", dict(response.headers)["File-Name"]) + self.assertEqual("some_description", dict(response.headers)["File-Description"]) + def test_delete_report(self): item = self.test_new_report_html() response = self.client.delete( diff --git a/cornflow-server/static/__init__.py b/cornflow-server/static/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/client/cornflow_client/cornflow_client.py b/libs/client/cornflow_client/cornflow_client.py index 178943f68..3f2df4890 100644 --- a/libs/client/cornflow_client/cornflow_client.py +++ b/libs/client/cornflow_client/cornflow_client.py @@ -1,7 +1,6 @@ from .raw_cornflow_client import RawCornFlow, CornFlowApiError # TODO: review the standard calls for the reports. -# TODO: modify the headers on the calls that require a file. # TODO: have the download report method to receive the path to save it on the local machine. @@ -22,7 +21,11 @@ def __init__(self, url, token=None): ) self.create_report = self.expect_status(self.raw.create_report, 201) self.get_reports = self.expect_status(self.raw.get_reports, 200) - self.get_one_report = self.expect_status(self.raw.get_one_report, 200) + self.get_one_report = self.expect_status( + self.raw.get_one_report, 200, json=False + ) + self.put_one_report = self.expect_status(self.raw.put_one_report, 200) + self.delete_one_report = self.expect_status(self.raw.delete_one_report, 200) self.create_instance_data_check = self.expect_status( self.raw.create_instance_data_check, 201 @@ -92,10 +95,13 @@ def token(self, token): self.raw.token = token @staticmethod - def expect_status(func, expected_status=None): + def expect_status(func, expected_status=None, json=True): """ Gets the response of the call and raise an exception if the status of the response is not the expected + + The response of the call is the json in the body for those calls that are application/json + For the calls that are form/data the response of the call is the content and the headers """ def decorator(*args, **kwargs): @@ -104,7 +110,11 @@ def decorator(*args, **kwargs): raise CornFlowApiError( f"Expected a code {expected_status}, got a {response.status_code} error instead: {response.text}" ) - return response.json() + + if json: + return response.json() + else: + return response.content, response.headers return decorator diff --git a/libs/client/cornflow_client/raw_cornflow_client.py b/libs/client/cornflow_client/raw_cornflow_client.py index 9428fe3c6..86fa6a729 100644 --- a/libs/client/cornflow_client/raw_cornflow_client.py +++ b/libs/client/cornflow_client/raw_cornflow_client.py @@ -1,14 +1,55 @@ """ - +Code for the main class to interact to cornflow from python code. """ import logging as log import os import re from functools import wraps +from typing import Union, Dict from urllib.parse import urljoin import requests +from requests import Response + + +def ask_token(func: callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.token: + raise CornFlowApiError("Need to login first!") + return func(self, *args, **kwargs) + + return wrapper + + +def log_call(func: callable): + @wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + log.debug(result.json()) + return result + + return wrapper + + +def prepare_encoding(func: callable): + @wraps(func) + def wrapper(*args, **kwargs): + encoding = kwargs.get("encoding", "br") + if encoding not in [ + "gzip", + "compress", + "deflate", + "br", + "identity", + ]: + encoding = "br" + kwargs["encoding"] = encoding + result = func(*args, **kwargs) + return result + + return wrapper class RawCornFlow(object): @@ -16,46 +57,10 @@ class RawCornFlow(object): Base class to access cornflow-server """ - def __init__(self, url, token=None): + def __init__(self, url: str, token=None): self.url = url self.token = token - def ask_token(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if not self.token: - raise CornFlowApiError("Need to login first!") - return func(self, *args, **kwargs) - - return wrapper - - def log_call(func): - @wraps(func) - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - log.debug(result.json()) - return result - - return wrapper - - def prepare_encoding(func): - @wraps(func) - def wrapper(*args, **kwargs): - encoding = kwargs.get("encoding", "br") - if encoding not in [ - "gzip", - "compress", - "deflate", - "br", - "identity", - ]: - encoding = "br" - kwargs["encoding"] = encoding - result = func(*args, **kwargs) - return result - - return wrapper - # def expect_201(func): # return partial(expect_status, status=201) # @@ -64,25 +69,25 @@ def wrapper(*args, **kwargs): def api_for_id( self, - api, - id=None, - method="GET", - post_url=None, - query_args=None, - encoding=None, + api: str, + id: Union[str, int] = None, + method: str = "GET", + post_url: str = None, + query_args: Dict = None, + encoding: str = None, **kwargs, - ): + ) -> Response: """ - :param api: the resource in the server + :param str api: the resource in the server :param id: the id of the particular object - :param method: HTTP method to apply - :param post_url: optional action to apply - :param query_args: query arguments for the request - :param encoding: optional string with the type of encoding, if it is not specified it uses br encoding, + :param str method: HTTP method to apply + :param str post_url: optional action to apply + :param Dict query_args: query arguments for the request + :param str encoding: optional string with the type of encoding, if it is not specified it uses br encoding, options are: gzip, compress, deflate, br or identity :param kwargs: other arguments to requests.request - :return: requests.request + :return: :class:`requests.Response` """ if api[0] == "/" and self.url[-1] == "/": api = api[1:] @@ -122,7 +127,9 @@ def api_for_id( **kwargs, ) - def get_api(self, api, method="GET", encoding=None, **kwargs): + def get_api( + self, api: str, method: str = "GET", encoding: str = None, **kwargs + ) -> Response: return requests.request( method=method, url=urljoin(self.url, api) + "/", @@ -135,7 +142,14 @@ def get_api(self, api, method="GET", encoding=None, **kwargs): @ask_token @prepare_encoding - def get_api_for_id(self, api, id=None, post_url=None, encoding=None, **kwargs): + def get_api_for_id( + self, + api: str, + id: Union[str, id] = None, + post_url: str = None, + encoding: str = None, + **kwargs, + ) -> Response: """ api_for_id with a GET request """ @@ -352,7 +366,7 @@ def create_execution( :param str instance_id: id for the instance :param str name: name for the execution :param str description: description of the execution - :param dict config: execution configuration + :param dict config: configuration for the execution :param str schema: name of the problem to solve :param str encoding: the type of encoding used in the call. Defaults to 'br' :param bool run: if the execution should be run or not @@ -503,8 +517,8 @@ def write_solution(self, execution_id, encoding=None, **kwargs): Edits an execution :param str execution_id: id for the execution - :param kwargs: optional data to edit :param str encoding: the type of encoding used in the call. Defaults to 'br' + :param kwargs: optional data to edit """ return self.put_api_for_id( "dag/", id=execution_id, encoding=encoding, payload=kwargs @@ -514,7 +528,14 @@ def write_solution(self, execution_id, encoding=None, **kwargs): @log_call @prepare_encoding def get_reports(self, params=None, encoding=None): - """ """ + """ + Gets all reports for a given user + + :param dict params: optional filters + :param str encoding: the type of encoding used in the call. Defaults to 'br' + :return: the response object + :rtype: :class:`Response` + """ return self.get_api("report", params=params, encoding=encoding) @ask_token @@ -526,14 +547,13 @@ def create_report(self, name, filename, execution_id, encoding=None, **kwargs): :param str execution_id: id for the execution :param str name: the name of the report :param file filename: the file object with the report (e.g., open(REPORT_FILE_PATH, "rb")) - :param kwargs: optional data to write (description) :param str encoding: the type of encoding used in the call. Defaults to 'br' + :param kwargs: optional data to write (description) """ with open(filename, "rb") as _file: - payload = dict(name=name, execution_id=execution_id, **kwargs) result = self.create_api( "report/", - data=payload, + data=dict(name=name, execution_id=execution_id, **kwargs), files=dict(file=_file), encoding=encoding, headers={"content_type": "multipart/form-data"}, @@ -543,8 +563,18 @@ def create_report(self, name, filename, execution_id, encoding=None, **kwargs): @ask_token @prepare_encoding def get_one_report( - self, reference_id, folder_destination, file_name, encoding=None - ): + self, reference_id, folder_destination, file_name=None, encoding=None + ) -> Response: + """ + Gets one specific report and downloads it to disk + + :param int reference_id: id of the report to download + :param str folder_destination: Path on the local system where to save the downloaded report + :param str file_name: optional name for the report file to write + :param str encoding: the type of encoding used in the call. Defaults to 'br' + :return: the response object + :rtype: :class:`Response` + """ result = self.get_api_for_id(api="report", id=reference_id, encoding=encoding) content = result.content @@ -558,6 +588,36 @@ def get_one_report( return result + @ask_token + @log_call + @prepare_encoding + def delete_one_report(self, reference_id, encoding=None): + """ + Deletes a report + + :param int reference_id: id of the report to download + :param str encoding: the type of encoding used in the call. Defaults to 'br' + :return: the response object + :rtype: :class:`Response` + """ + return self.delete_api_for_id(api="report", id=reference_id, encoding=encoding) + + @ask_token + @log_call + @prepare_encoding + def put_one_report(self, reference_id, payload, encoding=None) -> Response: + """ + Edits one specific report and downloads it to disk + + :param int reference_id: id of the report to download + :param str encoding: the type of encoding used in the call. Defaults to 'br' + :return: the response object + :rtype: :class:`Response` + """ + return self.put_api_for_id( + api="report", id=reference_id, payload=payload, encoding=encoding + ) + @ask_token @prepare_encoding def write_instance_checks(self, instance_id, encoding=None, **kwargs): @@ -937,7 +997,7 @@ def create_deployed_dag( encoding=None, ): if name is None: - return {"error": "No dag name was given"} + raise CornFlowApiError("No dag name was given") payload = dict( id=name, description=description, @@ -1020,7 +1080,7 @@ def group_variables_by_name(_vars, names_list, **kwargs): # 2. key can be a tuple or a single string. # 3. if a tuple, they can be an integer or a string. # - # it dos not permit the nested dictionary format of variables + # it does not permit the nested dictionary format of variables # we copy it because we will be taking out already seen variables _vars = dict(_vars) __vars = {k: {} for k in names_list} diff --git a/libs/client/cornflow_client/tests/const.py b/libs/client/cornflow_client/tests/const.py index ba21f4a1c..e6f39412a 100644 --- a/libs/client/cornflow_client/tests/const.py +++ b/libs/client/cornflow_client/tests/const.py @@ -195,7 +195,8 @@ def _get_file(relative_path): ) PULP_EXAMPLE = _get_file("./data/pulp_example_data.json") -HTML_REPORT = "../data/new_report.html" +HTML_REPORT = _get_file("./data/new_report.html") +TEST_FOLDER = "./" PUBLIC_DAGS = [ "solve_model_dag", diff --git a/libs/client/cornflow_client/tests/integration/test_cornflow_integration.py b/libs/client/cornflow_client/tests/integration/test_cornflow_integration.py index 38bf1be2a..d78e6e4d9 100644 --- a/libs/client/cornflow_client/tests/integration/test_cornflow_integration.py +++ b/libs/client/cornflow_client/tests/integration/test_cornflow_integration.py @@ -11,10 +11,15 @@ import pulp as pl -from cornflow_client import CornFlow +from cornflow_client import CornFlow, CornFlowApiError from cornflow_client.constants import STATUS_OPTIMAL, STATUS_NOT_SOLVED, STATUS_QUEUED from cornflow_client.schema.tools import get_pulp_jsonschema -from cornflow_client.tests.const import PUBLIC_DAGS, PULP_EXAMPLE +from cornflow_client.tests.const import ( + PUBLIC_DAGS, + PULP_EXAMPLE, + HTML_REPORT, + TEST_FOLDER, +) # Constants path_to_tests_dir = os.path.dirname(os.path.abspath(__file__)) @@ -551,9 +556,14 @@ def setUp(self): login_result = self.client.login("admin", "Adminpassword1!") self.assertIn("id", login_result.keys()) self.assertIn("token", login_result.keys()) - self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").login( - "user", "UserPassword1!" - )["id"] + try: + self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").login( + "user", "UserPassword1!" + )["id"] + except CornFlowApiError: + self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").sign_up( + username="user", pwd="UserPassword1!", email="user@cornflow.org" + )["id"] def tearDown(self): pass @@ -660,3 +670,104 @@ def test_post_deployed_dag(self): self.assertIn(item, response.keys()) self.assertEqual("test_dag", response["id"]) self.assertEqual("test_dag_description", response["description"]) + + def test_post_report_html(self): + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + data = _load_file(PULP_EXAMPLE) + + instance = client.create_instance(data, "test_example", "test_description") + + execution = client.create_execution( + instance_id=instance["id"], + config={"solver": "PULP_CBC_CMD", "timeLimit": 60}, + name="test_execution", + description="execution_description", + schema="solve_model_dag", + run=False, + ) + + response = self.client.create_report("new_report", HTML_REPORT, execution["id"]) + + self.assertEqual(response["execution_id"], execution["id"]) + + return response + + def test_get_one_report(self): + response = self.test_post_report_html() + report_id = response["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + content, headers = client.get_one_report( + reference_id=report_id, folder_destination=TEST_FOLDER + ) + + self.assertEqual(headers["File-Name"], response["name"]) + self.assertEqual(headers["File-Description"], response["description"]) + + # read from TEST FOLDER + with open(os.path.join(TEST_FOLDER, "new_report.html"), "r") as f: + file = f.read() + + # read from test/data folder + with open(HTML_REPORT, "r") as f: + file_2 = f.read() + + self.assertEqual(file, file_2) + + # remove file from TEST_FOLDER + os.remove(os.path.join(TEST_FOLDER, "new_report.html")) + + def test_get_all_reports(self): + report_1 = self.test_post_report_html()["id"] + report_2 = self.test_post_report_html()["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + response = client.get_reports() + + self.assertGreaterEqual(len(response), 2) + + client.delete_one_report(reference_id=report_1) + client.delete_one_report(reference_id=report_2) + + def test_put_one_report(self): + response = self.test_post_report_html() + report_id = response["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + payload = {"name": "new_name", "description": "some_description"} + + _ = client.put_one_report(reference_id=report_id, payload=payload) + + content, headers = client.get_one_report( + reference_id=report_id, folder_destination=TEST_FOLDER + ) + + self.assertEqual(headers["File-Name"], payload["name"]) + self.assertEqual(headers["File-Description"], payload["description"]) + self.assertNotEqual(headers["File-Name"], "new_report") + self.assertNotEqual(headers["File-Description"], "") + + _ = client.delete_one_report(reference_id=report_id) + + def test_delete_one_report(self): + response = self.test_post_report_html() + report_id = response["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + reports_before = client.get_reports() + + _ = client.delete_one_report(reference_id=report_id) + + reports_after = client.get_reports() + + self.assertLess(len(reports_after), len(reports_before)) diff --git a/libs/client/cornflow_client/tests/integration/test_raw_cornflow_integration.py b/libs/client/cornflow_client/tests/integration/test_raw_cornflow_integration.py index 2c1e7626f..f0a32fae7 100644 --- a/libs/client/cornflow_client/tests/integration/test_raw_cornflow_integration.py +++ b/libs/client/cornflow_client/tests/integration/test_raw_cornflow_integration.py @@ -11,10 +11,15 @@ import pulp as pl -from cornflow_client import CornFlow +from cornflow_client import CornFlow, CornFlowApiError from cornflow_client.constants import STATUS_OPTIMAL, STATUS_NOT_SOLVED, STATUS_QUEUED from cornflow_client.schema.tools import get_pulp_jsonschema -from cornflow_client.tests.const import PUBLIC_DAGS, PULP_EXAMPLE, HTML_REPORT +from cornflow_client.tests.const import ( + PUBLIC_DAGS, + PULP_EXAMPLE, + HTML_REPORT, + TEST_FOLDER, +) # Constants path_to_tests_dir = os.path.dirname(os.path.abspath(__file__)) @@ -35,7 +40,12 @@ def _get_file(relative_path): class TestRawCornflowClientUser(TestCase): def setUp(self): self.client = CornFlow(url="http://127.0.0.1:5050/") - login_result = self.client.raw.login("user", "UserPassword1!") + try: + login_result = self.client.raw.login("user", "UserPassword1!") + except CornFlowApiError: + login_result = self.client.raw.sign_up( + username="user", pwd="UserPassword1!", email="user@cornflow.org" + ) data = login_result.json() self.assertEqual(login_result.status_code, 200) self.assertIn("id", data.keys()) @@ -46,9 +56,9 @@ def tearDown(self): pass def check_execution_statuses( - self, execution_id, end_state=STATUS_OPTIMAL, initial_state=None + self, execution_id, end_state=STATUS_OPTIMAL, initial_state=STATUS_QUEUED ): - if initial_state is None: + if initial_state is not None: statuses = [initial_state] else: statuses = [] @@ -332,7 +342,9 @@ def test_get_execution_log(self): def test_get_execution_solution(self): execution = self.test_create_execution() - statuses = self.check_execution_statuses(execution["id"]) + statuses = self.check_execution_statuses( + execution["id"], initial_state=STATUS_QUEUED + ) response = self.client.raw.get_solution(execution["id"]) self.assertEqual(response.status_code, 200) @@ -605,6 +617,11 @@ def test_get_all_schemas(self): for schema in PUBLIC_DAGS: self.assertIn(schema, read_schemas) + def test_log_in_first(self): + client = CornFlow(url="http://127.0.0.1:5050/") + + self.assertRaises(CornFlowApiError, client.raw.get_all_instances) + class TestRawCornflowClientAdmin(TestCase): def setUp(self): @@ -612,9 +629,14 @@ def setUp(self): login_result = self.client.login("admin", "Adminpassword1!") self.assertIn("id", login_result.keys()) self.assertIn("token", login_result.keys()) - self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").login( - "user", "UserPassword1!" - )["id"] + try: + self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").login( + "user", "UserPassword1!" + )["id"] + except CornFlowApiError: + self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").sign_up( + username="user", pwd="UserPassword1!", email="user@cornflow.org" + )["id"] def tearDown(self): pass @@ -642,6 +664,14 @@ def setUp(self): login_result = self.client.login("airflow", "Airflow_test_password1") self.assertIn("id", login_result.keys()) self.assertIn("token", login_result.keys()) + try: + self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").login( + "user", "UserPassword1!" + )["id"] + except CornFlowApiError: + self.base_user_id = CornFlow(url="http://127.0.0.1:5050/").sign_up( + username="user", pwd="UserPassword1!", email="user@cornflow.org" + )["id"] def tearDown(self): pass @@ -733,6 +763,19 @@ def test_post_deployed_dag(self): self.assertEqual("test_dag_2", response["id"]) self.assertEqual("test_dag_2_description", response["description"]) + def test_raises_post_deployed_dag(self): + self.assertRaises( + CornFlowApiError, + self.client.raw.create_deployed_dag, + name=None, + description="test_dag_2_description", + instance_schema=dict(), + instance_checks_schema=dict(), + solution_schema=dict(), + solution_checks_schema=dict(), + config_schema=dict(), + ) + def test_post_report_html(self): client = CornFlow(url="http://127.0.0.1:5050/") _ = client.login("user", "UserPassword1!") @@ -752,4 +795,93 @@ def test_post_report_html(self): run=False, ).json() - client.raw.create_report("new_report", HTML_REPORT, execution["id"]) + response = self.client.raw.create_report( + "new_report", HTML_REPORT, execution["id"] + ) + + self.assertEqual(response.status_code, 201) + + return response + + def test_get_one_report(self): + response = self.test_post_report_html() + report_id = response.json()["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + response = client.raw.get_one_report( + reference_id=report_id, folder_destination=TEST_FOLDER + ) + self.assertEqual(response.status_code, 200) + + # read from TEST FOLDER + with open(os.path.join(TEST_FOLDER, "new_report.html"), "r") as f: + file = f.read() + + # read from test/data folder + with open(HTML_REPORT, "r") as f: + file_2 = f.read() + + self.assertEqual(file, file_2) + + # remove file from TEST_FOLDER + os.remove(os.path.join(TEST_FOLDER, "new_report.html")) + + def test_get_all_reports(self): + report_1 = self.test_post_report_html().json()["id"] + report_2 = self.test_post_report_html().json()["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + response = client.raw.get_reports() + + self.assertEqual(response.status_code, 200) + self.assertGreaterEqual(len(response.json()), 2) + + client.raw.delete_one_report(reference_id=report_1) + client.raw.delete_one_report(reference_id=report_2) + + def test_put_one_report(self): + response = self.test_post_report_html() + report_id = response.json()["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + payload = {"name": "new_name", "description": "some_description"} + + response = client.raw.put_one_report(reference_id=report_id, payload=payload) + + self.assertEqual(response.status_code, 200) + + new_report = client.raw.get_one_report( + reference_id=report_id, folder_destination=TEST_FOLDER + ) + + self.assertEqual(new_report.headers["File-Name"], payload["name"]) + self.assertEqual(new_report.headers["File-Description"], payload["description"]) + self.assertNotEqual(new_report.headers["File-Name"], "new_report") + self.assertNotEqual(new_report.headers["File-Description"], "") + + delete = client.raw.delete_one_report(reference_id=report_id) + self.assertEqual(delete.status_code, 200) + + def test_delete_one_report(self): + response = self.test_post_report_html() + report_id = response.json()["id"] + + client = CornFlow(url="http://127.0.0.1:5050/") + _ = client.login("user", "UserPassword1!") + + reports_before = client.raw.get_reports() + + self.assertEqual(reports_before.status_code, 200) + + response = client.raw.delete_one_report(reference_id=report_id) + self.assertEqual(response.status_code, 200) + + reports_after = client.raw.get_reports() + + self.assertLess(len(reports_after.json()), len(reports_before.json()))