diff --git a/docs/changelog/next_release/39.bugfix.rst b/docs/changelog/next_release/39.bugfix.rst new file mode 100644 index 00000000..2f621f6f --- /dev/null +++ b/docs/changelog/next_release/39.bugfix.rst @@ -0,0 +1 @@ +Fix 500 error in case of ``PATCH v1/connections/:id`` request with passed ``auth_data.password`` field value diff --git a/syncmaster/backend/api/v1/connections.py b/syncmaster/backend/api/v1/connections.py index b6965e70..6f40651e 100644 --- a/syncmaster/backend/api/v1/connections.py +++ b/syncmaster/backend/api/v1/connections.py @@ -3,7 +3,6 @@ from typing import get_args from fastapi import APIRouter, Depends, Query, status -from pydantic import SecretStr from syncmaster.backend.api.deps import UnitOfWorkMarker from syncmaster.backend.services import UnitOfWork, get_user @@ -103,33 +102,27 @@ async def create_connection( if group_permission < Permission.WRITE: raise ActionNotAllowedError - data = connection_data.data.dict() - auth_data = connection_data.auth_data.dict() - - # Trick to serialize SecretStr to JSON - for k, v in auth_data.items(): - if isinstance(v, SecretStr): - auth_data[k] = v.get_secret_value() async with unit_of_work: connection = await unit_of_work.connection.create( name=connection_data.name, description=connection_data.description, group_id=connection_data.group_id, - data=data, + data=connection_data.data.dict(), ) await unit_of_work.credentials.create( connection_id=connection.id, - data=auth_data, + data=connection_data.auth_data.dict(), ) + credentials = await unit_of_work.credentials.read(connection.id) return ReadConnectionSchema( id=connection.id, group_id=connection.group_id, name=connection.name, description=connection.description, data=connection.data, - auth_data=auth_data, + auth_data=credentials, ) @@ -171,7 +164,7 @@ async def read_connection( @router.patch("/connections/{connection_id}") async def update_connection( connection_id: int, - connection_data: UpdateConnectionSchema, + changes: UpdateConnectionSchema, current_user: User = Depends(get_user(is_active=True)), unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker), ) -> ReadConnectionSchema: @@ -189,25 +182,25 @@ async def update_connection( async with unit_of_work: connection = await unit_of_work.connection.update( connection_id=connection_id, - name=connection_data.name, - description=connection_data.description, - connection_data=connection_data.data.dict(exclude={"auth_data"}) if connection_data.data else {}, + name=changes.name, + description=changes.description, + data=changes.data.dict(exclude={"auth_data"}) if changes.data else {}, ) - if connection_data.auth_data: + if changes.auth_data: await unit_of_work.credentials.update( connection_id=connection_id, - credential_data=connection_data.auth_data.dict(), + data=changes.auth_data.dict(), ) - auth_data = await unit_of_work.credentials.read(connection_id) + credentials = await unit_of_work.credentials.read(connection_id) return ReadConnectionSchema( id=connection.id, group_id=connection.group_id, name=connection.name, description=connection.description, data=connection.data, - auth_data=auth_data, + auth_data=credentials, ) diff --git a/syncmaster/db/repositories/connection.py b/syncmaster/db/repositories/connection.py index 301a3180..d6b03ac0 100644 --- a/syncmaster/db/repositories/connection.py +++ b/syncmaster/db/repositories/connection.py @@ -81,19 +81,19 @@ async def update( connection_id: int, name: str | None, description: str | None, - connection_data: dict[str, Any], + data: dict[str, Any], ) -> Connection: try: connection = await self.read_by_id(connection_id=connection_id) for key in connection.data: - if key not in connection_data or connection_data[key] is None: - connection_data[key] = connection.data[key] + data[key] = data.get(key, None) or connection.data[key] + return await self._update( Connection.id == connection_id, Connection.is_deleted.is_(False), name=name or connection.name, description=description or connection.description, - data=connection_data, + data=data, ) except IntegrityError as e: self._raise_error(e) diff --git a/syncmaster/db/repositories/credentials_repository.py b/syncmaster/db/repositories/credentials_repository.py index e96b651a..03e5d298 100644 --- a/syncmaster/db/repositories/credentials_repository.py +++ b/syncmaster/db/repositories/credentials_repository.py @@ -65,17 +65,15 @@ async def create(self, connection_id: int, data: dict) -> AuthData: async def update( self, connection_id: int, - credential_data: dict, + data: dict, ) -> AuthData: creds = await self.read(connection_id) try: for key in creds: - if key not in credential_data or credential_data[key] is None: - credential_data[key] = creds[key] - + data[key] = data.get(key, None) or creds[key] return await self._update( AuthData.connection_id == connection_id, - value=encrypt_auth_data(value=credential_data, settings=self._settings), + value=encrypt_auth_data(value=data, settings=self._settings), ) except IntegrityError as e: self._raise_error(e) diff --git a/syncmaster/db/repositories/utils.py b/syncmaster/db/repositories/utils.py index c039acfa..2b2acdf3 100644 --- a/syncmaster/db/repositories/utils.py +++ b/syncmaster/db/repositories/utils.py @@ -3,6 +3,7 @@ import json from cryptography.fernet import Fernet +from pydantic import SecretStr from syncmaster.config import Settings @@ -11,15 +12,26 @@ def decrypt_auth_data( value: str, settings: Settings, ) -> dict: - f = Fernet(settings.CRYPTO_KEY) - return json.loads(f.decrypt(value)) + decryptor = Fernet(settings.CRYPTO_KEY) + decrypted = decryptor.decrypt(value) + return json.loads(decrypted) + + +def _json_default(value): + if isinstance(value, SecretStr): + return value.get_secret_value() def encrypt_auth_data( value: dict, settings: Settings, ) -> str: - key = str.encode(settings.CRYPTO_KEY) - f = Fernet(key) - token = f.encrypt(str.encode(json.dumps(value))) - return token.decode(encoding="utf-8") + encryptor = Fernet(settings.CRYPTO_KEY) + serialized = json.dumps( + value, + ensure_ascii=False, + sort_keys=True, + default=_json_default, + ) + encrypted = encryptor.encrypt(serialized.encode("utf-8")) + return encrypted.decode("utf-8") diff --git a/tests/test_unit/test_connections/test_update_connection.py b/tests/test_unit/test_connections/test_update_connection.py index 00bcc2f7..1135b62f 100644 --- a/tests/test_unit/test_connections/test_update_connection.py +++ b/tests/test_unit/test_connections/test_update_connection.py @@ -209,7 +209,42 @@ async def test_update_connection_data_fields( assert result.status_code == 200 -async def test_update_connection_auth_data_fields( +async def test_update_connection_auth_data_all_felds( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + # Arrange + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + # Act + result = await client.patch( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={"auth_data": {"type": "postgres", "user": "new_user", "password": "new_password"}}, + ) + + # Assert + assert result.json() == { + "id": group_connection.id, + "name": group_connection.name, + "description": group_connection.description, + "group_id": group_connection.group_id, + "connection_data": { + "type": group_connection.data["type"], + "host": "127.0.0.1", + "port": group_connection.data["port"], + "additional_params": group_connection.data["additional_params"], + "database_name": group_connection.data["database_name"], + }, + "auth_data": { + "type": group_connection.credentials.value["type"], + "user": "new_user", + }, + } + assert result.status_code == 200 + + +async def test_update_connection_auth_data_partial( client: AsyncClient, group_connection: MockConnection, role_developer_plus: UserTestRoles,