diff --git a/authentik/lib/utils/http.py b/authentik/lib/utils/http.py index b90885ea39d3..a29b589fc0c1 100644 --- a/authentik/lib/utils/http.py +++ b/authentik/lib/utils/http.py @@ -21,7 +21,14 @@ class DebugSession(Session): def send(self, req: PreparedRequest, *args, **kwargs): request_id = str(uuid4()) - LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers) + LOGGER.debug( + "HTTP request sent", + uid=request_id, + url=req.url, + method=req.method, + headers=req.headers, + body=req.body, + ) resp = super().send(req, *args, **kwargs) LOGGER.debug( "HTTP response received", diff --git a/authentik/providers/scim/clients/groups.py b/authentik/providers/scim/clients/groups.py index 1f39eea8f52d..44b3405dfffa 100644 --- a/authentik/providers/scim/clients/groups.py +++ b/authentik/providers/scim/clients/groups.py @@ -2,9 +2,10 @@ from itertools import batched +from django.db import transaction from pydantic import ValidationError from pydanticscim.group import GroupMember -from pydanticscim.responses import PatchOp, PatchOperation +from pydanticscim.responses import PatchOp from authentik.core.models import Group from authentik.lib.sync.mapper import PropertyMappingManager @@ -19,7 +20,7 @@ from authentik.providers.scim.clients.exceptions import ( SCIMRequestException, ) -from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest +from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema from authentik.providers.scim.models import ( SCIMMapping, @@ -104,13 +105,47 @@ def create(self, group: Group): provider=self.provider, group=group, scim_id=scim_id ) users = list(group.users.order_by("id").values_list("id", flat=True)) - self._patch_add_users(group, users) + self._patch_add_users(connection, users) return connection def update(self, group: Group, connection: SCIMProviderGroup): """Update existing group""" scim_group = self.to_schema(group, connection) scim_group.id = connection.scim_id + try: + if self._config.patch.supported: + return self._update_patch(group, scim_group, connection) + return self._update_put(group, scim_group, connection) + except NotFoundSyncException: + # Resource missing is handled by self.write, which will re-create the group + raise + + def _update_patch( + self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup + ): + """Update a group via PATCH request""" + # Patch group's attributes instead of replacing it and re-adding users if we can + self._request( + "PATCH", + f"/Groups/{connection.scim_id}", + json=PatchRequest( + Operations=[ + PatchOperation( + op=PatchOp.replace, + path=None, + value=scim_group.model_dump(mode="json", exclude_unset=True), + ) + ] + ).model_dump( + mode="json", + exclude_unset=True, + exclude_none=True, + ), + ) + return self.patch_compare_users(group) + + def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup): + """Update a group via PUT request""" try: self._request( "PUT", @@ -120,33 +155,25 @@ def update(self, group: Group, connection: SCIMProviderGroup): exclude_unset=True, ), ) - users = list(group.users.order_by("id").values_list("id", flat=True)) - return self._patch_add_users(group, users) - except NotFoundSyncException: - # Resource missing is handled by self.write, which will re-create the group - raise + return self.patch_compare_users(group) except (SCIMRequestException, ObjectExistsSyncException): # Some providers don't support PUT on groups, so this is mainly a fix for the initial # sync, send patch add requests for all the users the group currently has - users = list(group.users.order_by("id").values_list("id", flat=True)) - self._patch_add_users(group, users) - # Also update the group name - return self._patch( - scim_group.id, - PatchOperation( - op=PatchOp.replace, - path="displayName", - value=scim_group.displayName, - ), - ) + return self._update_patch(group, scim_group, connection) def update_group(self, group: Group, action: Direction, users_set: set[int]): """Update a group, either using PUT to replace it or PATCH if supported""" + scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() + if not scim_group: + self.logger.warning( + "could not sync group membership, group does not exist", group=group + ) + return if self._config.patch.supported: if action == Direction.add: - return self._patch_add_users(group, users_set) + return self._patch_add_users(scim_group, users_set) if action == Direction.remove: - return self._patch_remove_users(group, users_set) + return self._patch_remove_users(scim_group, users_set) try: return self.write(group) except SCIMRequestException as exc: @@ -154,16 +181,19 @@ def update_group(self, group: Group, action: Direction, users_set: set[int]): # Assume that provider does not support PUT and also doesn't support # ServiceProviderConfig, so try PATCH as a fallback if action == Direction.add: - return self._patch_add_users(group, users_set) + return self._patch_add_users(scim_group, users_set) if action == Direction.remove: - return self._patch_remove_users(group, users_set) + return self._patch_remove_users(scim_group, users_set) raise exc - def _patch( + def _patch_chunked( self, group_id: str, *ops: PatchOperation, ): + """Helper function that chunks patch requests based on the maxOperations attribute. + This is not strictly according to specs but there's nothing in the schema that allows the + us to know what the maximum patch operations per request should be.""" chunk_size = self._config.bulk.maxOperations if chunk_size < 1: chunk_size = len(ops) @@ -177,16 +207,67 @@ def _patch( ), ) - def _patch_add_users(self, group: Group, users_set: set[int]): - """Add users in users_set to group""" - if len(users_set) < 1: - return + @transaction.atomic + def patch_compare_users(self, group: Group): + """Compare users with a SCIM group and add/remove any differences""" + # Get scim group first scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() if not scim_group: self.logger.warning( "could not sync group membership, group does not exist", group=group ) return + # Get a list of all users in the authentik group + raw_users_should = list(group.users.order_by("id").values_list("id", flat=True)) + # Lookup the SCIM IDs of the users + users_should: list[str] = list( + SCIMProviderUser.objects.filter( + user__pk__in=raw_users_should, provider=self.provider + ).values_list("scim_id", flat=True) + ) + if len(raw_users_should) != len(users_should): + self.logger.warning( + "User count mismatch, not all users in the group are synced to SCIM yet.", + group=group, + ) + # Get current group status + current_group = SCIMGroupSchema.model_validate( + self._request("GET", f"/Groups/{scim_group.scim_id}") + ) + users_to_add = [] + users_to_remove = [] + # Check users currently in group and if they shouldn't be in the group and remove them + for user in current_group.members: + if user.value not in users_should: + users_to_remove.append(user.value) + # Check users that should be in the group and add them + for user in users_should: + if len([x for x in current_group.members if x.value == user]) < 1: + users_to_add.append(user) + return self._patch_chunked( + scim_group.scim_id, + *[ + PatchOperation( + op=PatchOp.add, + path="members", + value=[{"value": x}], + ) + for x in users_to_add + ], + *[ + PatchOperation( + op=PatchOp.remove, + path="members", + value=[{"value": x}], + ) + for x in users_to_remove + ], + ) + + def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): + """Add users in users_set to group""" + if len(users_set) < 1: + return user_ids = list( SCIMProviderUser.objects.filter( user__pk__in=users_set, provider=self.provider @@ -194,7 +275,7 @@ def _patch_add_users(self, group: Group, users_set: set[int]): ) if len(user_ids) < 1: return - self._patch( + self._patch_chunked( scim_group.scim_id, *[ PatchOperation( @@ -206,16 +287,10 @@ def _patch_add_users(self, group: Group, users_set: set[int]): ], ) - def _patch_remove_users(self, group: Group, users_set: set[int]): + def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): """Remove users in users_set from group""" if len(users_set) < 1: return - scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() - if not scim_group: - self.logger.warning( - "could not sync group membership, group does not exist", group=group - ) - return user_ids = list( SCIMProviderUser.objects.filter( user__pk__in=users_set, provider=self.provider @@ -223,7 +298,7 @@ def _patch_remove_users(self, group: Group, users_set: set[int]): ) if len(user_ids) < 1: return - self._patch( + self._patch_chunked( scim_group.scim_id, *[ PatchOperation( diff --git a/authentik/providers/scim/clients/schema.py b/authentik/providers/scim/clients/schema.py index b4444b37346f..1cbf07145299 100644 --- a/authentik/providers/scim/clients/schema.py +++ b/authentik/providers/scim/clients/schema.py @@ -2,6 +2,7 @@ from pydantic import Field from pydanticscim.group import Group as BaseGroup +from pydanticscim.responses import PatchOperation as BasePatchOperation from pydanticscim.responses import PatchRequest as BasePatchRequest from pydanticscim.responses import SCIMError as BaseSCIMError from pydanticscim.service_provider import Bulk as BaseBulk @@ -68,6 +69,12 @@ class PatchRequest(BasePatchRequest): schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) +class PatchOperation(BasePatchOperation): + """PatchOperation with optional path""" + + path: str | None + + class SCIMError(BaseSCIMError): """SCIM error with optional status code""" diff --git a/authentik/providers/scim/tests/test_membership.py b/authentik/providers/scim/tests/test_membership.py index 8b2b0dc9b317..24084622fc5d 100644 --- a/authentik/providers/scim/tests/test_membership.py +++ b/authentik/providers/scim/tests/test_membership.py @@ -252,3 +252,118 @@ def test_member_remove(self): ], }, ) + + def test_member_add_save(self): + """Test member add + save""" + config = ServiceProviderConfiguration.default() + + config.patch.supported = True + user_scim_id = generate_id() + group_scim_id = generate_id() + uid = generate_id() + group = Group.objects.create( + name=uid, + ) + + user = User.objects.create(username=generate_id()) + + # Test initial sync of group creation + with Mocker() as mocker: + mocker.get( + "https://localhost/ServiceProviderConfig", + json=config.model_dump(), + ) + mocker.post( + "https://localhost/Users", + json={ + "id": user_scim_id, + }, + ) + mocker.post( + "https://localhost/Groups", + json={ + "id": group_scim_id, + }, + ) + + self.configure() + sync_tasks.trigger_single_task(self.provider, scim_sync).get() + + self.assertEqual(mocker.call_count, 6) + self.assertEqual(mocker.request_history[0].method, "GET") + self.assertEqual(mocker.request_history[1].method, "GET") + self.assertEqual(mocker.request_history[2].method, "GET") + self.assertEqual(mocker.request_history[3].method, "POST") + self.assertEqual(mocker.request_history[4].method, "GET") + self.assertEqual(mocker.request_history[5].method, "POST") + self.assertJSONEqual( + mocker.request_history[3].body, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "emails": [], + "active": True, + "externalId": user.uid, + "name": {"familyName": " ", "formatted": " ", "givenName": ""}, + "displayName": "", + "userName": user.username, + }, + ) + self.assertJSONEqual( + mocker.request_history[5].body, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], + "externalId": str(group.pk), + "displayName": group.name, + }, + ) + + with Mocker() as mocker: + mocker.get( + "https://localhost/ServiceProviderConfig", + json=config.model_dump(), + ) + mocker.get( + f"https://localhost/Groups/{group_scim_id}", + json={}, + ) + mocker.patch( + f"https://localhost/Groups/{group_scim_id}", + json={}, + ) + group.users.add(user) + group.save() + self.assertEqual(mocker.call_count, 5) + self.assertEqual(mocker.request_history[0].method, "GET") + self.assertEqual(mocker.request_history[1].method, "PATCH") + self.assertEqual(mocker.request_history[2].method, "GET") + self.assertEqual(mocker.request_history[3].method, "PATCH") + self.assertEqual(mocker.request_history[4].method, "GET") + self.assertJSONEqual( + mocker.request_history[1].body, + { + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], + "Operations": [ + { + "op": "add", + "path": "members", + "value": [{"value": user_scim_id}], + } + ], + }, + ) + self.assertJSONEqual( + mocker.request_history[3].body, + { + "Operations": [ + { + "op": "replace", + "value": { + "id": group_scim_id, + "displayName": group.name, + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], + "externalId": str(group.pk), + }, + } + ] + }, + ) diff --git a/web/src/admin/providers/scim/SCIMProviderForm.ts b/web/src/admin/providers/scim/SCIMProviderForm.ts index 0d82688ad37d..5bc15c2bf8cf 100644 --- a/web/src/admin/providers/scim/SCIMProviderForm.ts +++ b/web/src/admin/providers/scim/SCIMProviderForm.ts @@ -38,12 +38,15 @@ export async function scimPropertyMappingsProvider(page = 1, search = "") { }; } -export function makeSCIMPropertyMappingsSelector(instanceMappings: string[] | undefined) { +export function makeSCIMPropertyMappingsSelector( + instanceMappings: string[] | undefined, + defaultSelected: string, +) { const localMappings = instanceMappings ? new Set(instanceMappings) : undefined; return localMappings ? ([pk, _]: DualSelectPair) => localMappings.has(pk) : ([_0, _1, _2, mapping]: DualSelectPair) => - mapping?.managed === "goauthentik.io/providers/scim/user"; + mapping?.managed === defaultSelected; } @customElement("ak-provider-scim-form") @@ -189,6 +192,7 @@ export class SCIMProviderFormPage extends BaseProviderForm { .provider=${scimPropertyMappingsProvider} .selector=${makeSCIMPropertyMappingsSelector( this.instance?.propertyMappings, + "goauthentik.io/providers/scim/user", )} available-label=${msg("Available User Property Mappings")} selected-label=${msg("Selected User Property Mappings")} @@ -205,6 +209,7 @@ export class SCIMProviderFormPage extends BaseProviderForm { .provider=${scimPropertyMappingsProvider} .selector=${makeSCIMPropertyMappingsSelector( this.instance?.propertyMappingsGroup, + "goauthentik.io/providers/scim/group", )} available-label=${msg("Available Group Property Mappings")} selected-label=${msg("Selected Group Property Mappings")}