diff --git a/src/backend/app/organisations/organisation_schemas.py b/src/backend/app/organisations/organisation_schemas.py index 9df5d7b30..8fb46b102 100644 --- a/src/backend/app/organisations/organisation_schemas.py +++ b/src/backend/app/organisations/organisation_schemas.py @@ -20,7 +20,7 @@ from re import sub from typing import Annotated, Optional -from pydantic import BaseModel, Field, FieldValidationInfo +from pydantic import BaseModel, Field, ValidationInfo from pydantic.functional_validators import field_validator from app.central.central_schemas import ODKCentralIn @@ -41,7 +41,7 @@ class OrganisationInBase(ODKCentralIn, DbOrganisation): @field_validator("slug", mode="after") @classmethod - def set_slug(cls, value: Optional[str], info: FieldValidationInfo) -> str: + def set_slug(cls, value: Optional[str], info: ValidationInfo) -> str: """Set the slug attribute from the name. NOTE this is a bit of a hack. diff --git a/src/backend/app/projects/project_schemas.py b/src/backend/app/projects/project_schemas.py index e5d284567..15a55ac64 100644 --- a/src/backend/app/projects/project_schemas.py +++ b/src/backend/app/projects/project_schemas.py @@ -26,7 +26,7 @@ from pydantic import ( BaseModel, Field, - FieldValidationInfo, + ValidationInfo, computed_field, ) from pydantic.functional_serializers import field_serializer @@ -80,7 +80,7 @@ class ProjectInBase(DbProject): def set_project_slug( cls, value: Optional[str], - info: FieldValidationInfo, + info: ValidationInfo, ) -> str: """Set the slug attribute from the name. diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index 6160eee59..5b208b4d6 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -127,6 +127,7 @@ testpaths = [ "tests", ] asyncio_mode="auto" +asyncio_default_fixture_loop_scope="session" [tool.commitizen] name = "cz_conventional_commits" diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index c3f70ca2e..13e70b682 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -21,16 +21,16 @@ import os from io import BytesIO from pathlib import Path -from typing import Any, Generator +from typing import Any, AsyncGenerator from uuid import uuid4 import pytest +import pytest_asyncio import requests from fastapi import FastAPI -from fastapi.testclient import TestClient -from geojson_pydantic import Polygon +from httpx import ASGITransport, AsyncClient from loguru import logger as log -from psycopg import connect +from psycopg import AsyncConnection from app.auth.auth_routes import get_or_create_user from app.auth.auth_schemas import AuthUser, FMTMUser @@ -39,8 +39,9 @@ from app.config import encrypt_value, settings from app.db.database import db_conn from app.db.enums import TaskStatus, UserRole -from app.db.models import DbOrganisation, DbTaskHistory +from app.db.models import DbProject, DbTask, DbTaskHistory from app.main import get_application +from app.organisations.organisation_deps import get_organisation from app.projects import project_crud from app.projects.project_schemas import ProjectIn from app.users.user_deps import get_user @@ -58,23 +59,25 @@ def pytest_configure(config): # sqlalchemy_log.propagate = False -@pytest.fixture(autouse=True) -def app() -> Generator[FastAPI, Any, None]: +@pytest_asyncio.fixture(autouse=True) +async def app() -> AsyncGenerator[FastAPI, Any]: """Get the FastAPI test server.""" yield get_application() -@pytest.fixture(scope="function") -def db(): - """The psycopg database connection using psycopg3.""" - db_conn = connect(settings.FMTM_DB_URL.unicode_string()) +@pytest_asyncio.fixture(scope="function") +async def db(): + """The psycopg async database connection using psycopg3.""" + db_conn = await AsyncConnection.connect( + settings.FMTM_DB_URL.unicode_string(), + ) try: yield db_conn finally: - db_conn.close() + await db_conn.close() -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def admin_user(db): """A test user.""" db_user = await get_or_create_user( @@ -94,13 +97,13 @@ async def admin_user(db): ) -@pytest.fixture(scope="function") -def organisation(db): +@pytest_asyncio.fixture(scope="function") +async def organisation(db): """A test organisation.""" - return db.query(DbOrganisation).filter(DbOrganisation.name == "HOTOSM").first() + return await get_organisation(db, "HOTOSM") -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def project(db, admin_user, organisation): """A test project, using the test user and org.""" project_metadata = ProjectIn( @@ -112,9 +115,9 @@ async def project(db, admin_user, organisation): odk_central_user=os.getenv("ODK_CENTRAL_USER"), odk_central_password=os.getenv("ODK_CENTRAL_PASSWD"), hashtags="hashtag1 hashtag2", - outline=Polygon( - type="Polygon", - coordinates=[ + outline={ + "type": "Polygon", + "coordinates": [ [ [85.299989110, 27.7140080437], [85.299989110, 27.7108923499], @@ -123,7 +126,8 @@ async def project(db, admin_user, organisation): [85.299989110, 27.7140080437], ] ], - ), + }, + author_id=admin_user.id, organisation_id=organisation.id, ) @@ -147,13 +151,8 @@ async def project(db, admin_user, organisation): # Create FMTM Project try: - new_project = await project_crud.create_project_with_project_info( - db, - project_metadata, - odkproject["id"], - admin_user, - ) - log.debug(f"Project returned: {new_project.__dict__}") + new_project = await DbProject.create(db, project_metadata) + log.debug(f"Project returned: {new_project}") assert new_project is not None except Exception as e: log.exception(e) @@ -162,7 +161,7 @@ async def project(db, admin_user, organisation): return new_project -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def tasks(project, db): """Test tasks, using the test project.""" boundaries = { @@ -188,22 +187,18 @@ async def tasks(project, db): "timestamp": "2021-02-11T17:21:06", }, } - try: - tasks = await project_crud.create_tasks_from_geojson( - db=db, project_id=project.id, boundaries=boundaries - ) + try: + tasks = await DbTask.create(db, project.id, boundaries) assert tasks is True - - # Refresh the project to include the tasks - db.refresh(project) except Exception as e: log.exception(e) pytest.fail(f"Test failed with exception: {str(e)}") - return project.tasks + + return tasks -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def task_history(db, project, tasks, admin_user): """A test task history using the test user, project and task.""" user = await get_user(admin_user.id, db) @@ -217,13 +212,11 @@ async def task_history(db, project, tasks, admin_user): actioned_by=user, user_id=user.id, ) - db.add(task_history_entry) - db.commit() - db.refresh(task_history_entry) - return task_history_entry + db_task_history = await DbTaskHistory.create(db, task_history_entry) + return db_task_history -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def odk_project(db, client, project, tasks): """Create ODK Central resources for a project and generate the necessary files.""" with open(f"{test_data_path}/data_extract_kathmandu.geojson", "rb") as f: @@ -253,7 +246,7 @@ async def odk_project(db, client, project, tasks): ) } try: - response = client.post( + response = await client.post( f"/projects/{project.id}/generate-project-data", files=xform_file, ) @@ -265,7 +258,7 @@ async def odk_project(db, client, project, tasks): yield project -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def entities(odk_project): """Get entities data.""" odk_credentials = { @@ -282,8 +275,8 @@ async def entities(odk_project): yield entities -@pytest.fixture(scope="function") -def project_data(): +@pytest_asyncio.fixture(scope="function") +async def project_data(): """Sample data for creating a project.""" project_name = f"Test Project {uuid4()}" data = { @@ -317,7 +310,7 @@ def project_data(): return data -# @pytest.fixture(scope="function") +# @pytest_asyncio.fixture(scope="function") # def get_ids(db, project): # user_id_query = text(f"SELECT id FROM {DbUser.__table__.name} LIMIT 1") # organisation_id_query = text( @@ -338,10 +331,13 @@ def project_data(): # return data -@pytest.fixture(scope="function") -def client(app, db): +@pytest_asyncio.fixture(scope="function") +async def client(app, db): """The FastAPI test server.""" app.dependency_overrides[db_conn] = lambda: db - with TestClient(app) as c: - yield c + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://testserver", + ) as ac: + yield ac diff --git a/src/backend/tests/test_central_routes.py b/src/backend/tests/test_central_routes.py index f3bc0128e..6949da4ac 100644 --- a/src/backend/tests/test_central_routes.py +++ b/src/backend/tests/test_central_routes.py @@ -24,7 +24,7 @@ async def test_list_forms(client): """Test get a list of all XLSForms available in FMTM.""" - response = client.get("/central/list-forms") + response = await client.get("/central/list-forms") assert response.status_code == 200 forms_json = response.json() diff --git a/src/backend/tests/test_organisation_routes.py b/src/backend/tests/test_organisation_routes.py index 27e394208..e8199cb9c 100644 --- a/src/backend/tests/test_organisation_routes.py +++ b/src/backend/tests/test_organisation_routes.py @@ -22,7 +22,7 @@ async def test_get_organisation(client, organisation): """Test get list of organisations.""" - response = client.get("/organisation/") + response = await client.get("/organisation/") assert response.status_code == 200 data = response.json()[-1] diff --git a/src/backend/tests/test_projects_routes.py b/src/backend/tests/test_projects_routes.py index ec48f32ac..9af160f63 100644 --- a/src/backend/tests/test_projects_routes.py +++ b/src/backend/tests/test_projects_routes.py @@ -28,13 +28,11 @@ import requests from fastapi import HTTPException from loguru import logger as log -from shapely import Polygon from app.central import central_schemas from app.central.central_crud import create_odk_project -from app.central.central_schemas import TaskStatus from app.config import encrypt_value, settings -from app.db.models import DbProject +from app.db.enums import TaskStatus from app.db.postgis_utils import check_crs from app.projects import project_crud from tests.test_data import test_data_path @@ -44,21 +42,23 @@ odk_central_password = encrypt_value(os.getenv("ODK_CENTRAL_PASSWD", "")) -def create_project(client, organisation_id, project_data): +async def create_project(client, organisation_id, project_data): """Create a new project.""" - response = client.post(f"/projects?org_id={organisation_id}", json=project_data) + response = await client.post( + f"/projects?org_id={organisation_id}", json=project_data + ) assert response.status_code == 200 return response.json() -def test_create_project(client, organisation, project_data): +async def test_create_project(client, organisation, project_data): """Test project creation endpoint.""" - response_data = create_project(client, organisation.id, project_data) + response_data = await create_project(client, organisation.id, project_data) project_name = project_data["name"] assert "id" in response_data # Duplicate response to test error condition: project name already exists - response_duplicate = client.post( + response_duplicate = await client.post( f"/projects?org_id={organisation.id}", json=project_data ) assert response_duplicate.status_code == 400 @@ -136,7 +136,7 @@ def test_create_project(client, organisation, project_data): }, ], ) -def test_valid_geojson_types(client, organisation, project_data, geojson_type): +async def test_valid_geojson_types(client, organisation, project_data, geojson_type): """Test valid geojson types.""" project_data["outline"] = geojson_type response_data = create_project(client, organisation.id, project_data) @@ -169,10 +169,12 @@ def test_valid_geojson_types(client, organisation, project_data, geojson_type): }, ], ) -def test_invalid_geojson_types(client, organisation, project_data, geojson_type): +async def test_invalid_geojson_types(client, organisation, project_data, geojson_type): """Test invalid geojson types.""" project_data["outline"] = geojson_type - response = client.post(f"/projects?org_id={organisation.id}", json=project_data) + response = await client.post( + f"/projects?org_id={organisation.id}", json=project_data + ) assert response.status_code == 422 @@ -213,7 +215,7 @@ def test_hashtags(client, organisation, project_data, hashtag_input, expected_ou async def test_delete_project(client, admin_user, project): """Test deleting a FMTM project, plus ODK Central project.""" - response = client.delete(f"/projects/{project.id}") + response = await client.delete(f"/projects/{project.id}") assert response.status_code == 204 @@ -237,34 +239,6 @@ async def test_create_odk_project(): mock_project.createProject.assert_called_once_with("FMTM Test Project") -async def test_convert_to_app_project(): - """Test conversion to app project.""" - polygon = Polygon( - [ - (85.317028828, 27.7052522097), - (85.317028828, 27.7041424888), - (85.318844411, 27.7041424888), - (85.318844411, 27.7052522097), - (85.317028828, 27.7052522097), - ] - ) - - mock_db_project = DbProject( - id=1, - outline=polygon.__geo_interface__, - ) - - result = await project_crud.convert_to_app_project(mock_db_project) - - assert result is not None - assert isinstance(result, DbProject) - - assert result.outline is not None - - assert result.tasks is not None - assert isinstance(result.tasks, list) - - async def test_create_project_with_project_info(db, project): """Test creating a project with all project info.""" assert isinstance(project.id, int) @@ -280,14 +254,14 @@ async def test_upload_data_extracts(client, project): open(f"{test_data_path}/data_extract_kathmandu.fgb", "rb"), ) } - response = client.post( + response = await client.post( f"/projects/upload-custom-extract/?project_id={project.id}", files=fgb_file, ) assert response.status_code == 200 - response = client.get( + response = await client.get( f"/projects/data-extract-url/?project_id={project.id}", ) assert "url" in response.json() @@ -299,14 +273,14 @@ async def test_upload_data_extracts(client, project): open(f"{test_data_path}/data_extract_kathmandu.geojson", "rb"), ) } - response = client.post( + response = await client.post( f"/projects/upload-custom-extract/?project_id={project.id}", files=geojson_file, ) assert response.status_code == 200 - response = client.get( + response = await client.get( f"/projects/data-extract-url/?project_id={project.id}", ) assert "url" in response.json() @@ -356,7 +330,7 @@ async def test_generate_project_files(db, client, project): BytesIO(task_geojson).read(), ) } - response = client.post( + response = await client.post( f"/projects/{project_id}/upload-task-boundaries", files=task_geojson_file, ) @@ -391,7 +365,7 @@ async def test_generate_project_files(db, client, project): xlsform_obj, ) } - response = client.post( + response = await client.post( f"/projects/{project_id}/generate-project-data", files=xform_file, ) @@ -420,7 +394,7 @@ async def test_update_project(client, admin_user, project): }, } - response = client.put(f"/projects/{project.id}", json=updated_project_data) + response = await client.put(f"/projects/{project.id}", json=updated_project_data) if response.status_code != 200: log.error(response.json()) @@ -439,7 +413,7 @@ async def test_update_project(client, admin_user, project): async def test_project_summaries(client, project): """Test read project summaries.""" - response = client.get("/projects/summaries") + response = await client.get("/projects/summaries") assert response.status_code == 200 assert "results" in response.json() @@ -455,7 +429,7 @@ async def test_project_summaries(client, project): async def test_project_by_id(client, project): """Test read project by id.""" - response = client.get(f"projects/{project.id}") + response = await client.get(f"projects/{project.id}?project_id={project.id}") assert response.status_code == 200 data = response.json() @@ -480,7 +454,7 @@ async def test_set_entity_mapping_status(client, odk_project, entities): entity = entities[0] expected_status = TaskStatus.LOCKED_FOR_MAPPING - response = client.post( + response = await client.post( f"/projects/{odk_project.id}/entity/status", json={ "entity_id": entity["id"], @@ -499,7 +473,7 @@ async def test_set_entity_mapping_status(client, odk_project, entities): async def test_get_entity_mapping_status(client, odk_project, entities): """Test get the ODK entity mapping status.""" entity = entities[0] - response = client.get( + response = await client.get( f"/projects/{odk_project.id}/entity/status", params={"entity_id": entity["id"]} ) response_entity = response.json() @@ -511,7 +485,7 @@ async def test_get_entity_mapping_status(client, odk_project, entities): async def test_get_entities_mapping_statuses(client, odk_project, entities): """Test get the ODK entities mapping statuses.""" odk_project_id = odk_project.id - response = client.get(f"projects/{odk_project_id}/entities/statuses") + response = await client.get(f"projects/{odk_project_id}/entities/statuses") response_entities = response.json() assert len(response_entities) == len(entities) @@ -529,7 +503,7 @@ def compare_entities(response_entity, expected_entity): assert str(response_entity["status"]) == str(expected_entity["status"]) -def test_project_task_split(client): +async def test_project_task_split(client): """Test project AOI splitting into tasks.""" aoi_geojson = json.dumps( { @@ -552,7 +526,7 @@ def test_project_task_split(client): ) } - response = client.post( + response = await client.post( "/projects/task-split", files=aoi_geojson_file, data={"no_of_buildings": 40}, @@ -563,7 +537,7 @@ def test_project_task_split(client): assert "features" in response.json() # Test without required value should cause validation error - response = client.post("/projects/task-split") + response = await client.post("/projects/task-split") assert response.status_code == 422 diff --git a/src/backend/tests/test_task_routes.py b/src/backend/tests/test_task_routes.py index abb7272d1..464e206eb 100644 --- a/src/backend/tests/test_task_routes.py +++ b/src/backend/tests/test_task_routes.py @@ -22,13 +22,13 @@ from app.db.enums import TaskStatus -def test_read_task_history(client, task_history): +async def test_read_task_history(client, task_history): """Test task history for a project.""" task_id = task_history.task_id assert task_id is not None - response = client.get(f"/tasks/{task_id}/history/") + response = await client.get(f"/tasks/{task_id}/history/") data = response.json()[0] assert response.status_code == 200 @@ -36,13 +36,13 @@ def test_read_task_history(client, task_history): assert data["username"] == task_history.actioned_by.username -def test_update_task_status(client, tasks): +async def test_update_task_status(client, tasks): """Test update the task status.""" task_id = tasks[0].id project_id = tasks[0].project_id new_status = TaskStatus.LOCKED_FOR_MAPPING - response = client.post( + response = await client.post( f"tasks/{task_id}/new-status/{new_status.value}?project_id={project_id}" ) diff --git a/src/backend/tests/test_users.py b/src/backend/tests/test_users.py index ee5b4ac4d..ac547882a 100644 --- a/src/backend/tests/test_users.py +++ b/src/backend/tests/test_users.py @@ -31,28 +31,31 @@ def test_nothing(): # from app.users.user_crud import create_user -# @pytest.fixture +# @pytest_asyncio.fixture # def users(db): # create_user(db, user_schemas.UserIn(username="admin", password="admin")) # create_user(db, user_schemas.UserIn(username="niraj", password="niraj")) # create_user(db, user_schemas.UserIn(username="test", password="test")) -# def test_list_users(users, client): -# response = client.get("/users") +# async def test_list_users(users, client): +# response = await client.get("/users") # assert len(response.json()) == 3 -# def test_create_users(client): -# response = client.post("/users/", json={"username": "test3", "password": "test1"}) +# async def test_create_users(client): +# response = await client.post("/users/", json={ +# "username": "test3", "password": "test1"}) # assert response.status_code == status.HTTP_200_OK -# response = client.post("/users/", json={"username": "niraj", "password": "niraj"}) +# response = await client.post("/users/", json={ +# "username": "niraj", "password": "niraj"}) # assert response.status_code == status.HTTP_200_OK -# response = client.post("/users/", json={"username": "niraj"}) +# response = await client.post("/users/", json={"username": "niraj"}) # assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY -# response = client.post("/users/", json={"username": "niraj", "password": "niraj"}) +# response = await client.post("/users/", json={ +# "username": "niraj", "password": "niraj"}) # assert response.status_code == status.HTTP_400_BAD_REQUEST # assert response.json() == {"detail": "Username already registered"}