Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOP-14025] Fix patching connection with new auth_data.password value #39

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog/next_release/39.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix 500 error in case of ``PATCH v1/connections/:id`` request with passed ``auth_data.password`` field value
31 changes: 12 additions & 19 deletions syncmaster/backend/api/v1/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import get_args

from fastapi import APIRouter, Depends, Query, status
from pydantic import SecretStr

from syncmaster.backend.api.deps import UnitOfWorkMarker
from syncmaster.backend.services import UnitOfWork, get_user
Expand Down Expand Up @@ -103,33 +102,27 @@ async def create_connection(
if group_permission < Permission.WRITE:
raise ActionNotAllowedError

data = connection_data.data.dict()
auth_data = connection_data.auth_data.dict()

# Trick to serialize SecretStr to JSON
for k, v in auth_data.items():
if isinstance(v, SecretStr):
auth_data[k] = v.get_secret_value()
async with unit_of_work:
connection = await unit_of_work.connection.create(
name=connection_data.name,
description=connection_data.description,
group_id=connection_data.group_id,
data=data,
data=connection_data.data.dict(),
)

await unit_of_work.credentials.create(
connection_id=connection.id,
data=auth_data,
data=connection_data.auth_data.dict(),
)

credentials = await unit_of_work.credentials.read(connection.id)
return ReadConnectionSchema(
id=connection.id,
group_id=connection.group_id,
name=connection.name,
description=connection.description,
data=connection.data,
auth_data=auth_data,
auth_data=credentials,
)


Expand Down Expand Up @@ -171,7 +164,7 @@ async def read_connection(
@router.patch("/connections/{connection_id}")
async def update_connection(
connection_id: int,
connection_data: UpdateConnectionSchema,
changes: UpdateConnectionSchema,
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> ReadConnectionSchema:
Expand All @@ -189,25 +182,25 @@ async def update_connection(
async with unit_of_work:
connection = await unit_of_work.connection.update(
connection_id=connection_id,
name=connection_data.name,
description=connection_data.description,
connection_data=connection_data.data.dict(exclude={"auth_data"}) if connection_data.data else {},
name=changes.name,
description=changes.description,
data=changes.data.dict(exclude={"auth_data"}) if changes.data else {},
)

if connection_data.auth_data:
if changes.auth_data:
await unit_of_work.credentials.update(
connection_id=connection_id,
credential_data=connection_data.auth_data.dict(),
data=changes.auth_data.dict(),
)

auth_data = await unit_of_work.credentials.read(connection_id)
credentials = await unit_of_work.credentials.read(connection_id)
return ReadConnectionSchema(
id=connection.id,
group_id=connection.group_id,
name=connection.name,
description=connection.description,
data=connection.data,
auth_data=auth_data,
auth_data=credentials,
)


Expand Down
8 changes: 4 additions & 4 deletions syncmaster/db/repositories/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ async def update(
connection_id: int,
name: str | None,
description: str | None,
connection_data: dict[str, Any],
data: dict[str, Any],
) -> Connection:
try:
connection = await self.read_by_id(connection_id=connection_id)
for key in connection.data:
if key not in connection_data or connection_data[key] is None:
connection_data[key] = connection.data[key]
data[key] = data.get(key, None) or connection.data[key]

return await self._update(
Connection.id == connection_id,
Connection.is_deleted.is_(False),
name=name or connection.name,
description=description or connection.description,
data=connection_data,
data=data,
)
except IntegrityError as e:
self._raise_error(e)
Expand Down
8 changes: 3 additions & 5 deletions syncmaster/db/repositories/credentials_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,15 @@ async def create(self, connection_id: int, data: dict) -> AuthData:
async def update(
self,
connection_id: int,
credential_data: dict,
data: dict,
) -> AuthData:
creds = await self.read(connection_id)
try:
for key in creds:
if key not in credential_data or credential_data[key] is None:
credential_data[key] = creds[key]

data[key] = data.get(key, None) or creds[key]
return await self._update(
AuthData.connection_id == connection_id,
value=encrypt_auth_data(value=credential_data, settings=self._settings),
value=encrypt_auth_data(value=data, settings=self._settings),
)
except IntegrityError as e:
self._raise_error(e)
Expand Down
24 changes: 18 additions & 6 deletions syncmaster/db/repositories/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json

from cryptography.fernet import Fernet
from pydantic import SecretStr

from syncmaster.config import Settings

Expand All @@ -11,15 +12,26 @@ def decrypt_auth_data(
value: str,
settings: Settings,
) -> dict:
f = Fernet(settings.CRYPTO_KEY)
return json.loads(f.decrypt(value))
decryptor = Fernet(settings.CRYPTO_KEY)
decrypted = decryptor.decrypt(value)
return json.loads(decrypted)


def _json_default(value):
if isinstance(value, SecretStr):
return value.get_secret_value()


def encrypt_auth_data(
value: dict,
settings: Settings,
) -> str:
key = str.encode(settings.CRYPTO_KEY)
f = Fernet(key)
token = f.encrypt(str.encode(json.dumps(value)))
return token.decode(encoding="utf-8")
encryptor = Fernet(settings.CRYPTO_KEY)
serialized = json.dumps(
value,
ensure_ascii=False,
sort_keys=True,
default=_json_default,
)
encrypted = encryptor.encrypt(serialized.encode("utf-8"))
return encrypted.decode("utf-8")
37 changes: 36 additions & 1 deletion tests/test_unit/test_connections/test_update_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,42 @@ async def test_update_connection_data_fields(
assert result.status_code == 200


async def test_update_connection_auth_data_fields(
async def test_update_connection_auth_data_all_felds(
client: AsyncClient,
group_connection: MockConnection,
role_developer_plus: UserTestRoles,
):
# Arrange
user = group_connection.owner_group.get_member_of_role(role_developer_plus)
# Act
result = await client.patch(
f"v1/connections/{group_connection.id}",
headers={"Authorization": f"Bearer {user.token}"},
json={"auth_data": {"type": "postgres", "user": "new_user", "password": "new_password"}},
)

# Assert
assert result.json() == {
"id": group_connection.id,
"name": group_connection.name,
"description": group_connection.description,
"group_id": group_connection.group_id,
"connection_data": {
"type": group_connection.data["type"],
"host": "127.0.0.1",
"port": group_connection.data["port"],
"additional_params": group_connection.data["additional_params"],
"database_name": group_connection.data["database_name"],
},
"auth_data": {
"type": group_connection.credentials.value["type"],
"user": "new_user",
},
}
assert result.status_code == 200


async def test_update_connection_auth_data_partial(
client: AsyncClient,
group_connection: MockConnection,
role_developer_plus: UserTestRoles,
Expand Down