Skip to content

Commit

Permalink
[DOP-14025] Fix patching connection with new auth_data.password value
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Apr 22, 2024
1 parent 2a5d95f commit b229b92
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 35 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/39.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix 500 error in case of ``PATCH v1/connections/:id`` request with passed ``auth_data.password`` field value
31 changes: 12 additions & 19 deletions syncmaster/backend/api/v1/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import get_args

from fastapi import APIRouter, Depends, Query, status
from pydantic import SecretStr

from syncmaster.backend.api.deps import UnitOfWorkMarker
from syncmaster.backend.services import UnitOfWork, get_user
Expand Down Expand Up @@ -103,33 +102,27 @@ async def create_connection(
if group_permission < Permission.WRITE:
raise ActionNotAllowedError

data = connection_data.data.dict()
auth_data = connection_data.auth_data.dict()

# Trick to serialize SecretStr to JSON
for k, v in auth_data.items():
if isinstance(v, SecretStr):
auth_data[k] = v.get_secret_value()
async with unit_of_work:
connection = await unit_of_work.connection.create(
name=connection_data.name,
description=connection_data.description,
group_id=connection_data.group_id,
data=data,
data=connection_data.data.dict(),
)

await unit_of_work.credentials.create(
connection_id=connection.id,
data=auth_data,
data=connection_data.auth_data.dict(),
)

credentials = await unit_of_work.credentials.read(connection.id)
return ReadConnectionSchema(
id=connection.id,
group_id=connection.group_id,
name=connection.name,
description=connection.description,
data=connection.data,
auth_data=auth_data,
auth_data=credentials,
)


Expand Down Expand Up @@ -171,7 +164,7 @@ async def read_connection(
@router.patch("/connections/{connection_id}")
async def update_connection(
connection_id: int,
connection_data: UpdateConnectionSchema,
changes: UpdateConnectionSchema,
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> ReadConnectionSchema:
Expand All @@ -189,25 +182,25 @@ async def update_connection(
async with unit_of_work:
connection = await unit_of_work.connection.update(
connection_id=connection_id,
name=connection_data.name,
description=connection_data.description,
connection_data=connection_data.data.dict(exclude={"auth_data"}) if connection_data.data else {},
name=changes.name,
description=changes.description,
data=changes.data.dict(exclude={"auth_data"}) if changes.data else {},
)

if connection_data.auth_data:
if changes.auth_data:
await unit_of_work.credentials.update(
connection_id=connection_id,
credential_data=connection_data.auth_data.dict(),
data=changes.auth_data.dict(),
)

auth_data = await unit_of_work.credentials.read(connection_id)
credentials = await unit_of_work.credentials.read(connection_id)
return ReadConnectionSchema(
id=connection.id,
group_id=connection.group_id,
name=connection.name,
description=connection.description,
data=connection.data,
auth_data=auth_data,
auth_data=credentials,
)


Expand Down
8 changes: 4 additions & 4 deletions syncmaster/db/repositories/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ async def update(
connection_id: int,
name: str | None,
description: str | None,
connection_data: dict[str, Any],
data: dict[str, Any],
) -> Connection:
try:
connection = await self.read_by_id(connection_id=connection_id)
for key in connection.data:
if key not in connection_data or connection_data[key] is None:
connection_data[key] = connection.data[key]
data[key] = data.get(key, None) or connection.data[key]

return await self._update(
Connection.id == connection_id,
Connection.is_deleted.is_(False),
name=name or connection.name,
description=description or connection.description,
data=connection_data,
data=data,
)
except IntegrityError as e:
self._raise_error(e)
Expand Down
8 changes: 3 additions & 5 deletions syncmaster/db/repositories/credentials_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,15 @@ async def create(self, connection_id: int, data: dict) -> AuthData:
async def update(
self,
connection_id: int,
credential_data: dict,
data: dict,
) -> AuthData:
creds = await self.read(connection_id)
try:
for key in creds:
if key not in credential_data or credential_data[key] is None:
credential_data[key] = creds[key]

data[key] = data.get(key, None) or creds[key]
return await self._update(
AuthData.connection_id == connection_id,
value=encrypt_auth_data(value=credential_data, settings=self._settings),
value=encrypt_auth_data(value=data, settings=self._settings),
)
except IntegrityError as e:
self._raise_error(e)
Expand Down
24 changes: 18 additions & 6 deletions syncmaster/db/repositories/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json

from cryptography.fernet import Fernet
from pydantic import SecretStr

from syncmaster.config import Settings

Expand All @@ -11,15 +12,26 @@ def decrypt_auth_data(
value: str,
settings: Settings,
) -> dict:
f = Fernet(settings.CRYPTO_KEY)
return json.loads(f.decrypt(value))
decryptor = Fernet(settings.CRYPTO_KEY)
decrypted = decryptor.decrypt(value)
return json.loads(decrypted)


def _json_default(value):
if isinstance(value, SecretStr):
return value.get_secret_value()


def encrypt_auth_data(
value: dict,
settings: Settings,
) -> str:
key = str.encode(settings.CRYPTO_KEY)
f = Fernet(key)
token = f.encrypt(str.encode(json.dumps(value)))
return token.decode(encoding="utf-8")
encryptor = Fernet(settings.CRYPTO_KEY)
serialized = json.dumps(
value,
ensure_ascii=False,
sort_keys=True,
default=_json_default,
)
encrypted = encryptor.encrypt(serialized.encode("utf-8"))
return encrypted.decode("utf-8")
37 changes: 36 additions & 1 deletion tests/test_unit/test_connections/test_update_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,42 @@ async def test_update_connection_data_fields(
assert result.status_code == 200


async def test_update_connection_auth_data_fields(
async def test_update_connection_auth_data_all_felds(
client: AsyncClient,
group_connection: MockConnection,
role_developer_plus: UserTestRoles,
):
# Arrange
user = group_connection.owner_group.get_member_of_role(role_developer_plus)
# Act
result = await client.patch(
f"v1/connections/{group_connection.id}",
headers={"Authorization": f"Bearer {user.token}"},
json={"auth_data": {"type": "postgres", "user": "new_user", "password": "new_password"}},
)

# Assert
assert result.json() == {
"id": group_connection.id,
"name": group_connection.name,
"description": group_connection.description,
"group_id": group_connection.group_id,
"connection_data": {
"type": group_connection.data["type"],
"host": "127.0.0.1",
"port": group_connection.data["port"],
"additional_params": group_connection.data["additional_params"],
"database_name": group_connection.data["database_name"],
},
"auth_data": {
"type": group_connection.credentials.value["type"],
"user": "new_user",
},
}
assert result.status_code == 200


async def test_update_connection_auth_data_partial(
client: AsyncClient,
group_connection: MockConnection,
role_developer_plus: UserTestRoles,
Expand Down

0 comments on commit b229b92

Please sign in to comment.