diff --git a/setup.py b/setup.py index ab7a5bc6..9ce03d58 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ 'jmespath==1.0.1', 'python-hcl2==4.3.2', 'requests==2.32.3', - 'fastapi==0.109.2', + 'fastapi>=0.115.2,<0.116.0', 'python-multipart==0.0.7', 'click==8.1.7', 'uvicorn==0.23.2', diff --git a/startleft/startleft/api/controllers/diagram/diag_create_otm_controller.py b/startleft/startleft/api/controllers/diagram/diag_create_otm_controller.py index 695b3776..bc3df1bc 100644 --- a/startleft/startleft/api/controllers/diagram/diag_create_otm_controller.py +++ b/startleft/startleft/api/controllers/diagram/diag_create_otm_controller.py @@ -25,8 +25,8 @@ @check_mime_type('diag_file', 'diag_type', DiagramFileNotValidError) def diagram(diag_file: UploadFile = File(...), diag_type: DiagramType = Form(...), - id: str = Form(...), - name: str = Form(...), + id: str = Form(..., min_length=1, max_length=999), + name: str = Form(..., min_length=1, max_length=999), default_mapping_file: UploadFile = File(...), custom_mapping_file: UploadFile = File(None)): logger.info( diff --git a/startleft/startleft/api/controllers/etm/etm_create_otm_controller.py b/startleft/startleft/api/controllers/etm/etm_create_otm_controller.py index c7c22cff..6aa458ee 100644 --- a/startleft/startleft/api/controllers/etm/etm_create_otm_controller.py +++ b/startleft/startleft/api/controllers/etm/etm_create_otm_controller.py @@ -25,8 +25,8 @@ @check_mime_type('source_file', 'source_type') def etm(source_file: UploadFile = File(...), source_type: EtmType = Form(...), - id: str = Form(...), - name: str = Form(...), + id: str = Form(..., min_length=1, max_length=999), + name: str = Form(..., min_length=1, max_length=999), default_mapping_file: UploadFile = File(...), custom_mapping_file: UploadFile = File(None)): logger.info( diff --git a/startleft/startleft/api/controllers/iac/iac_create_otm_controller.py b/startleft/startleft/api/controllers/iac/iac_create_otm_controller.py index 175c24a8..87692472 100644 --- a/startleft/startleft/api/controllers/iac/iac_create_otm_controller.py +++ b/startleft/startleft/api/controllers/iac/iac_create_otm_controller.py @@ -29,8 +29,8 @@ def iac(iac_file: List[UploadFile] = File(...), iac_type: IacType = Form(...), - id: str = Form(...), - name: str = Form(...), + id: str = Form(..., min_length=1, max_length=999), + name: str = Form(..., min_length=1, max_length=999), mapping_file: UploadFile = File(None), default_mapping_file: UploadFile = File(None), custom_mapping_file: UploadFile = File(None)): diff --git a/tests/integration/api/controllers/diagram/test_otm_controller_diagram.py b/tests/integration/api/controllers/diagram/test_otm_controller_diagram.py index b4a1e077..db28bff3 100644 --- a/tests/integration/api/controllers/diagram/test_otm_controller_diagram.py +++ b/tests/integration/api/controllers/diagram/test_otm_controller_diagram.py @@ -7,8 +7,6 @@ from startleft.startleft.api.controllers.diagram import diag_create_otm_controller from tests.resources import test_resource_paths -IRIUSRISK_URL = '' - webapp = fastapi_server.webapp client = TestClient(webapp) diff --git a/tests/integration/api/controllers/etm/test_etm_controller_diagram.py b/tests/integration/api/controllers/etm/test_etm_controller_diagram.py new file mode 100644 index 00000000..d0456b1e --- /dev/null +++ b/tests/integration/api/controllers/etm/test_etm_controller_diagram.py @@ -0,0 +1,47 @@ +import json + +import pytest +from fastapi.testclient import TestClient + +from startleft.startleft.api import fastapi_server +from startleft.startleft.api.controllers.etm import etm_create_otm_controller + +webapp = fastapi_server.webapp + +client = TestClient(webapp) + + +def get_url(): + return etm_create_otm_controller.PREFIX + etm_create_otm_controller.URL + + +octet_stream = 'application/octet-stream' + + +class TestOTMControllerEtm: + + @pytest.mark.parametrize("project_id,project_name,source_file,errors_expected, error_type", [ + (None, 'name', None, 4, 'RequestValidationError'), + ('id', None, None, 4, 'RequestValidationError'), + ('id', 'name', None, 3, 'RequestValidationError'), + (None, None, None, 5, 'RequestValidationError'), + ('', None, None, 5, 'RequestValidationError') + ]) + def test_create_project_validation_error(self, project_id: str, project_name: str, source_file, + errors_expected: int, + error_type: str): + # Given a body + body = {'id': project_id, 'name': project_name} + + # When I do post to the endpoint + files = {'source_file': source_file} if source_file else None + response = client.post(get_url(), files=files, data=body) + + # Then + assert response.status_code == 400 + res_body = json.loads(response.text) + assert res_body['status'] == '400' + assert res_body['error_type'] == error_type + assert len(res_body['errors']) == errors_expected + for e in res_body['errors']: + assert len(e['errorMessage']) > 0 diff --git a/tests/integration/api/controllers/iac/cloudformation/test_otm_controller_iac_cloudformation.py b/tests/integration/api/controllers/iac/cloudformation/test_otm_controller_iac_cloudformation.py index 16b5fccb..a5901fef 100644 --- a/tests/integration/api/controllers/iac/cloudformation/test_otm_controller_iac_cloudformation.py +++ b/tests/integration/api/controllers/iac/cloudformation/test_otm_controller_iac_cloudformation.py @@ -12,7 +12,7 @@ from startleft.startleft.api.controllers.iac import iac_create_otm_controller from tests.resources.test_resource_paths import default_cloudformation_mapping, example_json, \ cloudformation_malformed_mapping_wrong_id, invalid_yaml, cloudformation_all_functions, \ - cloudformation_mapping_all_functions, cloudformation_gz, visio_aws_shapes, cloudformation_multiple_files_networks, \ + cloudformation_mapping_all_functions, cloudformation_gz, cloudformation_multiple_files_networks, \ cloudformation_multiple_files_resources, cloudformation_ref_full_syntax, cloudformation_ref_short_syntax TESTING_IAC_TYPE = IacType.CLOUDFORMATION.value @@ -98,6 +98,9 @@ def test_create_project_validation_error(self, project_id: str, project_name: st res_body = json.loads(response.content.decode('utf-8')) assert res_body['status'] == '400' assert res_body['error_type'] == error_type + assert len(res_body['errors']) == 1 + for e in res_body['errors']: + assert len(e['errorMessage']) > 0 @responses.activate @patch('slp_cft.slp_cft.validate.cft_validator.CloudformationValidator.validate')