From efce57589f06e4f50ecb3e0a6051dbeff0676d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D1=80=D1=82=D1=8B=D0=BD=D0=BE=D0=B2=20=D0=9C?= =?UTF-8?q?=D0=B0=D0=BA=D1=81=D0=B8=D0=BC=20=D0=A1=D0=B5=D1=80=D0=B3=D0=B5?= =?UTF-8?q?=D0=B5=D0=B2=D0=B8=D1=87?= Date: Fri, 19 Apr 2024 14:11:03 +0300 Subject: [PATCH] [DOP-14025] Fix patching connection with new auth_data.password value --- docs/changelog/next_release/39.bugfix.rst | 1 + syncmaster/backend/api/v1/connections.py | 35 +++++++----------- syncmaster/db/repositories/connection.py | 8 ++-- .../db/repositories/credentials_repository.py | 8 ++-- syncmaster/db/repositories/utils.py | 24 +++++++++--- .../test_update_connection.py | 37 ++++++++++++++++++- 6 files changed, 75 insertions(+), 38 deletions(-) create mode 100644 docs/changelog/next_release/39.bugfix.rst diff --git a/docs/changelog/next_release/39.bugfix.rst b/docs/changelog/next_release/39.bugfix.rst new file mode 100644 index 00000000..2f621f6f --- /dev/null +++ b/docs/changelog/next_release/39.bugfix.rst @@ -0,0 +1 @@ +Fix 500 error in case of ``PATCH v1/connections/:id`` request with passed ``auth_data.password`` field value diff --git a/syncmaster/backend/api/v1/connections.py b/syncmaster/backend/api/v1/connections.py index acfe9a2a..5c61a7a4 100644 --- a/syncmaster/backend/api/v1/connections.py +++ b/syncmaster/backend/api/v1/connections.py @@ -4,7 +4,6 @@ from typing import get_args from fastapi import APIRouter, Depends, Query, status -from pydantic import SecretStr from syncmaster.backend.api.deps import UnitOfWorkMarker from syncmaster.backend.services import UnitOfWork, get_user @@ -106,33 +105,27 @@ async def create_connection( if group_permission < Permission.WRITE: raise ActionNotAllowedError - data = connection_data.data.dict() - auth_data = connection_data.auth_data.dict() - - # Trick to serialize SecretStr to JSON - for k, v in auth_data.items(): - if isinstance(v, SecretStr): - auth_data[k] = v.get_secret_value() async with unit_of_work: connection = await unit_of_work.connection.create( name=connection_data.name, description=connection_data.description, group_id=connection_data.group_id, - data=data, + data=connection_data.data.dict(), ) await unit_of_work.credentials.add_to_connection( connection_id=connection.id, - data=auth_data, + data=connection_data.auth_data.dict(), ) + credentials = await unit_of_work.credentials.get_for_connection(connection.id) return ReadConnectionSchema( id=connection.id, group_id=connection.group_id, name=connection.name, description=connection.description, data=connection.data, - auth_data=auth_data, + auth_data=credentials, ) @@ -158,9 +151,7 @@ async def read_connection( connection = await unit_of_work.connection.read_by_id(connection_id=connection_id) try: - credentials = await unit_of_work.credentials.get_for_connection( - connection_id=connection.id, - ) + credentials = await unit_of_work.credentials.get_for_connection(connection.id) except AuthDataNotFoundError: credentials = None @@ -177,7 +168,7 @@ async def read_connection( @router.patch("/connections/{connection_id}") async def update_connection( connection_id: int, - connection_data: UpdateConnectionSchema, + changes: UpdateConnectionSchema, current_user: User = Depends(get_user(is_active=True)), unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker), ) -> ReadConnectionSchema: @@ -195,25 +186,25 @@ async def update_connection( async with unit_of_work: connection = await unit_of_work.connection.update( connection_id=connection_id, - name=connection_data.name, - description=connection_data.description, - connection_data=connection_data.data.dict(exclude={"auth_data"}) if connection_data.data else {}, + name=changes.name, + description=changes.description, + data=changes.data.dict(exclude={"auth_data"}) if changes.data else {}, ) - if connection_data.auth_data: + if changes.auth_data: await unit_of_work.credentials.update( connection_id=connection_id, - credential_data=connection_data.auth_data.dict(), + data=changes.auth_data.dict(), ) - auth_data = await unit_of_work.credentials.get_for_connection(connection_id) + credentials = await unit_of_work.credentials.get_for_connection(connection_id) return ReadConnectionSchema( id=connection.id, group_id=connection.group_id, name=connection.name, description=connection.description, data=connection.data, - auth_data=auth_data, + auth_data=credentials, ) diff --git a/syncmaster/db/repositories/connection.py b/syncmaster/db/repositories/connection.py index 301a3180..d6b03ac0 100644 --- a/syncmaster/db/repositories/connection.py +++ b/syncmaster/db/repositories/connection.py @@ -81,19 +81,19 @@ async def update( connection_id: int, name: str | None, description: str | None, - connection_data: dict[str, Any], + data: dict[str, Any], ) -> Connection: try: connection = await self.read_by_id(connection_id=connection_id) for key in connection.data: - if key not in connection_data or connection_data[key] is None: - connection_data[key] = connection.data[key] + data[key] = data.get(key, None) or connection.data[key] + return await self._update( Connection.id == connection_id, Connection.is_deleted.is_(False), name=name or connection.name, description=description or connection.description, - data=connection_data, + data=data, ) except IntegrityError as e: self._raise_error(e) diff --git a/syncmaster/db/repositories/credentials_repository.py b/syncmaster/db/repositories/credentials_repository.py index b16e94d5..fe51e20c 100644 --- a/syncmaster/db/repositories/credentials_repository.py +++ b/syncmaster/db/repositories/credentials_repository.py @@ -68,17 +68,15 @@ async def delete_from_connection(self, connection_id: int) -> AuthData: async def update( self, connection_id: int, - credential_data: dict, + data: dict, ) -> AuthData: creds = await self.get_for_connection(connection_id) try: for key in creds: - if key not in credential_data or credential_data[key] is None: - credential_data[key] = creds[key] - + data[key] = data.get(key, None) or creds[key] return await self._update( AuthData.connection_id == connection_id, - value=encrypt_auth_data(value=credential_data, settings=self._settings), + value=encrypt_auth_data(value=data, settings=self._settings), ) except IntegrityError as e: self._raise_error(e) diff --git a/syncmaster/db/repositories/utils.py b/syncmaster/db/repositories/utils.py index c039acfa..2b2acdf3 100644 --- a/syncmaster/db/repositories/utils.py +++ b/syncmaster/db/repositories/utils.py @@ -3,6 +3,7 @@ import json from cryptography.fernet import Fernet +from pydantic import SecretStr from syncmaster.config import Settings @@ -11,15 +12,26 @@ def decrypt_auth_data( value: str, settings: Settings, ) -> dict: - f = Fernet(settings.CRYPTO_KEY) - return json.loads(f.decrypt(value)) + decryptor = Fernet(settings.CRYPTO_KEY) + decrypted = decryptor.decrypt(value) + return json.loads(decrypted) + + +def _json_default(value): + if isinstance(value, SecretStr): + return value.get_secret_value() def encrypt_auth_data( value: dict, settings: Settings, ) -> str: - key = str.encode(settings.CRYPTO_KEY) - f = Fernet(key) - token = f.encrypt(str.encode(json.dumps(value))) - return token.decode(encoding="utf-8") + encryptor = Fernet(settings.CRYPTO_KEY) + serialized = json.dumps( + value, + ensure_ascii=False, + sort_keys=True, + default=_json_default, + ) + encrypted = encryptor.encrypt(serialized.encode("utf-8")) + return encrypted.decode("utf-8") diff --git a/tests/test_unit/test_connections/test_update_connection.py b/tests/test_unit/test_connections/test_update_connection.py index 00bcc2f7..1135b62f 100644 --- a/tests/test_unit/test_connections/test_update_connection.py +++ b/tests/test_unit/test_connections/test_update_connection.py @@ -209,7 +209,42 @@ async def test_update_connection_data_fields( assert result.status_code == 200 -async def test_update_connection_auth_data_fields( +async def test_update_connection_auth_data_all_felds( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + # Arrange + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + # Act + result = await client.patch( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={"auth_data": {"type": "postgres", "user": "new_user", "password": "new_password"}}, + ) + + # Assert + assert result.json() == { + "id": group_connection.id, + "name": group_connection.name, + "description": group_connection.description, + "group_id": group_connection.group_id, + "connection_data": { + "type": group_connection.data["type"], + "host": "127.0.0.1", + "port": group_connection.data["port"], + "additional_params": group_connection.data["additional_params"], + "database_name": group_connection.data["database_name"], + }, + "auth_data": { + "type": group_connection.credentials.value["type"], + "user": "new_user", + }, + } + assert result.status_code == 200 + + +async def test_update_connection_auth_data_partial( client: AsyncClient, group_connection: MockConnection, role_developer_plus: UserTestRoles,