diff --git a/main/factories.py b/main/factories.py index 499c599cdb..2ea8eccb1d 100644 --- a/main/factories.py +++ b/main/factories.py @@ -4,7 +4,7 @@ import ulid from django.conf import settings -from factory import LazyFunction, RelatedFactory, SubFactory, Trait +from factory import Faker, LazyFunction, RelatedFactory, SubFactory, Trait from factory.django import DjangoModelFactory from factory.fuzzy import FuzzyText from social_django.models import UserSocialAuth @@ -15,8 +15,8 @@ class UserFactory(DjangoModelFactory): username = LazyFunction(lambda: ulid.new().str) email = FuzzyText(suffix="@example.com") - first_name = FuzzyText() - last_name = FuzzyText() + first_name = Faker("first_name") + last_name = Faker("last_name") profile = RelatedFactory("profiles.factories.ProfileFactory", "user") diff --git a/main/settings.py b/main/settings.py index 55660d70e8..de6eb06ecb 100644 --- a/main/settings.py +++ b/main/settings.py @@ -122,6 +122,7 @@ "data_fixtures", "vector_search", "ai_chat", + "scim", ) if not get_bool("RUN_DATA_MIGRATIONS", default=False): @@ -141,9 +142,11 @@ "documentationUri": "", }, ], - "USER_ADAPTER": "profiles.scim.adapters.LearnSCIMUser", - "USER_MODEL_GETTER": "profiles.scim.adapters.get_user_model_for_scim", - "USER_FILTER_PARSER": "profiles.scim.filters.LearnUserFilterQuery", + "SERVICE_PROVIDER_CONFIG_MODEL": "scim.config.LearnSCIMServiceProviderConfig", + "USER_ADAPTER": "scim.adapters.LearnSCIMUser", + "USER_MODEL_GETTER": "scim.adapters.get_user_model_for_scim", + "USER_FILTER_PARSER": "scim.filters.LearnUserFilterQuery", + "GET_IS_AUTHENTICATED_PREDICATE": "scim.utils.is_authenticated_predicate", } diff --git a/main/settings_celery.py b/main/settings_celery.py index b043acc0f3..227c5e032d 100644 --- a/main/settings_celery.py +++ b/main/settings_celery.py @@ -131,6 +131,12 @@ "schedule": crontab(minute=30, hour=18), # 2:30pm EST "kwargs": {"period": "daily", "subscription_type": "channel_subscription_type"}, }, + 
"daily_embed_new_learning_resources": { + "task": "vector_search.tasks.embed_new_learning_resources", + "schedule": get_int( + "EMBED_NEW_RESOURCES_SCHEDULE_SECONDS", 60 * 30 + ), # default is every 30 minutes + }, "send-search-subscription-emails-every-1-days": { "task": "learning_resources_search.tasks.send_subscription_emails", "schedule": crontab(minute=0, hour=19), # 3:00pm EST diff --git a/main/urls.py b/main/urls.py index ef45361d81..b34382d7ba 100644 --- a/main/urls.py +++ b/main/urls.py @@ -17,7 +17,7 @@ from django.conf import settings from django.conf.urls.static import static from django.contrib import admin -from django.urls import include, path, re_path +from django.urls import include, re_path from django.views.generic.base import RedirectView from rest_framework.routers import DefaultRouter @@ -41,7 +41,6 @@ urlpatterns = ( [ # noqa: RUF005 - path("scim/v2/", include("django_scim.urls")), re_path(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), re_path(r"^admin/", admin.site.urls), re_path(r"", include("authentication.urls")), @@ -58,6 +57,7 @@ re_path(r"", include("articles.urls")), re_path(r"", include("testimonials.urls")), re_path(r"", include("news_events.urls")), + re_path(r"", include("scim.urls")), re_path(r"", include(features_router.urls)), re_path(r"^app", RedirectView.as_view(url=settings.APP_BASE_URL)), # Hijack diff --git a/poetry.lock b/poetry.lock index e6542321b4..d5138cb9c1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1209,6 +1209,20 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "deepmerge" +version = "2.0" +description = "A toolset for deeply merging Python dictionaries." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "deepmerge-2.0-py3-none-any.whl", hash = "sha256:6de9ce507115cff0bed95ff0ce9ecc31088ef50cbdf09bc90a09349a318b3d00"}, + {file = "deepmerge-2.0.tar.gz", hash = "sha256:5c3d86081fbebd04dd5de03626a0607b809a98fb6ccba5770b62466fe940ff20"}, +] + +[package.extras] +dev = ["black", "build", "mypy", "pytest", "pyupgrade", "twine", "validate-pyproject[all]"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -7831,4 +7845,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.12.6" -content-hash = "21a25778ed407405a95e83d4acf4bad0046b0098d8ac525ac533dcf9b3cfb35d" +content-hash = "8d53049656967aae32771f78b8df08e006d52f3da491995076f3d3cc981f74a1" diff --git a/profiles/factories.py b/profiles/factories.py index 502cd599f5..7803c232f3 100644 --- a/profiles/factories.py +++ b/profiles/factories.py @@ -1,6 +1,8 @@ """Factories for making test data""" -from factory import Faker, Sequence, SubFactory +import uuid + +from factory import Faker, LazyFunction, SelfAttribute, Sequence, SubFactory from factory.django import DjangoModelFactory from factory.fuzzy import FuzzyChoice from faker.providers import BaseProvider @@ -49,6 +51,9 @@ class ProfileFactory(DjangoModelFactory): [Profile.CertificateDesired.YES.value, Profile.CertificateDesired.NO.value] ) + scim_external_id = LazyFunction(uuid.uuid4) + scim_username = SelfAttribute("user.email") + class Meta: model = Profile diff --git a/profiles/scim/views_test.py b/profiles/scim/views_test.py deleted file mode 100644 index 5a340dac79..0000000000 --- a/profiles/scim/views_test.py +++ /dev/null @@ -1,117 +0,0 @@ -import json - -from django.contrib.auth import get_user_model -from django.urls import reverse -from django_scim import constants - -User = get_user_model() - - -def test_scim_post_user(staff_client): - """Test that we can create a user via SCIM API""" - user_q = 
User.objects.filter(profile__scim_external_id="1") - assert not user_q.exists() - - resp = staff_client.post( - reverse("scim:users"), - content_type="application/scim+json", - data=json.dumps( - { - "schemas": [constants.SchemaURI.USER], - "emails": [{"value": "jdoe@example.com", "primary": True}], - "active": True, - "userName": "jdoe", - "externalId": "1", - "name": { - "familyName": "Doe", - "givenName": "John", - }, - "fullName": "John Smith Doe", - "emailOptIn": 1, - } - ), - ) - - assert resp.status_code == 201, f"Error response: {resp.content}" - - user = user_q.first() - - assert user is not None - assert user.email == "jdoe@example.com" - assert user.username == "jdoe" - assert user.first_name == "John" - assert user.last_name == "Doe" - assert user.profile.name == "John Smith Doe" - assert user.profile.email_optin is True - - # test an update - resp = staff_client.put( - f"{reverse('scim:users')}/{user.profile.scim_id}", - content_type="application/scim+json", - data=json.dumps( - { - "schemas": [constants.SchemaURI.USER], - "emails": [{"value": "jsmith@example.com", "primary": True}], - "active": True, - "userName": "jsmith", - "externalId": "1", - "name": { - "familyName": "Smith", - "givenName": "Jimmy", - }, - "fullName": "Jimmy Smith", - "emailOptIn": 0, - } - ), - ) - - assert resp.status_code == 200, f"Error response: {resp.content}" - - user = user_q.first() - - assert user is not None - assert user.email == "jsmith@example.com" - assert user.username == "jsmith" - assert user.first_name == "Jimmy" - assert user.last_name == "Smith" - assert user.profile.name == "Jimmy Smith" - assert user.profile.email_optin is False - - resp = staff_client.patch( - f"{reverse('scim:users')}/{user.profile.scim_id}", - content_type="application/scim+json", - data=json.dumps( - { - "schemas": [constants.SchemaURI.PATCH_OP], - "Operations": [ - { - "op": "replace", - # yes, the value we get from scim-for-keycloak is a JSON encoded string...inside JSON... 
- "value": json.dumps( - { - "schemas": [constants.SchemaURI.USER], - "emailOptIn": 1, - "fullName": "Billy Bob", - "name": { - "givenName": "Billy", - "familyName": "Bob", - }, - } - ), - } - ], - } - ), - ) - - assert resp.status_code == 200, f"Error response: {resp.content}" - - user = user_q.first() - - assert user is not None - assert user.email == "jsmith@example.com" - assert user.username == "jsmith" - assert user.first_name == "Billy" - assert user.last_name == "Bob" - assert user.profile.name == "Billy Bob" - assert user.profile.email_optin is True diff --git a/pyproject.toml b/pyproject.toml index d16a073a9a..1d4e9d6af6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ llama-index-llms-openai = "^0.3.12" llama-index-agent-openai = "^0.4.1" langchain-experimental = "^0.3.4" langchain-openai = "^0.3.2" +deepmerge = "^2.0" [tool.poetry.group.dev.dependencies] diff --git a/scim/README.md b/scim/README.md new file mode 100644 index 0000000000..017948ca5a --- /dev/null +++ b/scim/README.md @@ -0,0 +1,36 @@ +## SCIM + +## Prerequisites + +- You need the following a local [Keycloak](https://www.keycloak.org/) instance running. Note which major version you are running (should be at least 26.x). 
+ - You should have custom user profile fields set up on your `olapps` realm: + - `fullName`: required, otherwise defaults + - `emailOptIn`: defaults + +## Install the scim-for-keycloak plugin + +Sign up for an account on https://scim-for-keycloak.de and follow the instructions here: https://scim-for-keycloak.de/documentation/installation/install + +## Configure SCIM + +In the SCIM admin console, do the following: + +### Configure Remote SCIM Provider + +- In django-admin, go to OAuth Toolkit and create a new access token +- Go to Remote SCIM Provider +- Click the `+` button +- Specify a base URL for your learn API backend, substituting your backend's hostname or IP for `LEARN_HOST`: `http://LEARN_HOST:8063/scim/v2/` +- At the bottom of the page, click "Use default configuration" +- Add a new authentication method: + - Type: Long Life Bearer Token + - Bearer Token: the access token you created above +- On the Schemas tab, edit the User schema and add these custom attributes: + - Add a `fullName` attribute and set the Custom Attribute Name to `fullName` + - Add an attribute named `emailOptIn` with the following settings: + - Type: integer + - Custom Attribute Name: `emailOptIn` +- On the Realm Assignments tab, assign to the `olapps` realm +- Go to the Synchronization tab and perform a synchronization with these settings: + - Identifier attribute: email + - Synchronization strategy: Search and Bulk diff --git a/profiles/scim/__init__.py b/scim/__init__.py similarity index 100% rename from profiles/scim/__init__.py rename to scim/__init__.py diff --git a/profiles/scim/adapters.py b/scim/adapters.py similarity index 84% rename from profiles/scim/adapters.py rename to scim/adapters.py index 94c4b012f9..6983d4c480 100644 --- a/profiles/scim/adapters.py +++ b/scim/adapters.py @@ -44,6 +44,7 @@ class LearnSCIMUser(SCIMUser): ("active", None, None): "is_active", ("name", "givenName", None): "first_name", ("name", "familyName", None): "last_name", + ("userName", None, None): "username", } IGNORED_PATHS = { @@ -158,7 +159,7 @@ def delete(self): """ self.obj.is_active = False 
self.obj.save() - logger.info("Deactivated user id %i", self.obj.user.id) + logger.info("Deactivated user id %i", self.obj.id) def handle_add( self, @@ -193,7 +194,7 @@ def parse_scim_for_keycloak_payload(self, payload: str) -> dict: if isinstance(value, dict): for nested_key, nested_value in value.items(): - result[f"{key}.{nested_key}"] = nested_value + result[self.split_path(f"{key}.{nested_key}")] = nested_value else: result[key] = value @@ -202,11 +203,32 @@ def parse_scim_for_keycloak_payload(self, payload: str) -> dict: def parse_path_and_values( self, path: Optional[str], value: Union[str, list, dict] ) -> list: - if not path and isinstance(value, str): + """Parse the incoming value(s)""" + if isinstance(value, str): # scim-for-keycloak sends this as a noncompliant JSON-encoded string - value = self.parse_scim_for_keycloak_payload(value) + if path is None: + val = json.loads(value) + else: + msg = "Called with a non-null path and a str value" + raise ValueError(msg) + else: + val = value + + results = [] + + for attr_path, attr_value in val.items(): + if isinstance(attr_value, dict): + # nested object, we want to recursively flatten it to `first.second` + results.extend(self.parse_path_and_values(attr_path, attr_value)) + else: + flattened_path = ( + f"{path}.{attr_path}" if path is not None else attr_path + ) + new_path = self.split_path(flattened_path) + new_value = attr_value + results.append((new_path, new_value)) - return super().parse_path_and_values(path, value) + return results def handle_replace( self, @@ -219,22 +241,20 @@ def handle_replace( All operations happen within an atomic transaction. """ + if not isinstance(value, dict): # Restructure for use in loop below. 
value = {path: value} for nested_path, nested_value in (value or {}).items(): if nested_path.first_path in self.ATTR_MAP: - setattr( - self.obj, self.ATTR_MAP.get(nested_path.first_path), nested_value - ) - + setattr(self.obj, self.ATTR_MAP[nested_path.first_path], nested_value) elif nested_path.first_path == ("fullName", None, None): self.obj.profile.name = nested_value elif nested_path.first_path == ("emailOptIn", None, None): self.obj.profile.email_optin = nested_value == 1 elif nested_path.first_path == ("emails", None, None): - self.parse_emails(value) + self.parse_emails(nested_value) elif nested_path.first_path not in self.IGNORED_PATHS: logger.debug( "Ignoring SCIM update for path: %s", nested_path.first_path diff --git a/scim/apps.py b/scim/apps.py new file mode 100644 index 0000000000..7cfdae6bfa --- /dev/null +++ b/scim/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class ScimConfig(AppConfig): + name = "scim" diff --git a/scim/config.py b/scim/config.py new file mode 100644 index 0000000000..49da726497 --- /dev/null +++ b/scim/config.py @@ -0,0 +1,13 @@ +from django_scim.models import SCIMServiceProviderConfig + + +class LearnSCIMServiceProviderConfig(SCIMServiceProviderConfig): + """Custom provider config""" + + def to_dict(self): + result = super().to_dict() + + result["bulk"]["supported"] = True + result["filter"]["supported"] = True + + return result diff --git a/scim/constants.py b/scim/constants.py new file mode 100644 index 0000000000..c51546dabe --- /dev/null +++ b/scim/constants.py @@ -0,0 +1,7 @@ +"""SCIM constants""" + + +class SchemaURI: + BULK_REQUEST = "urn:ietf:params:scim:api:messages:2.0:BulkRequest" + + BULK_RESPONSE = "urn:ietf:params:scim:api:messages:2.0:BulkResponse" diff --git a/profiles/scim/filters.py b/scim/filters.py similarity index 51% rename from profiles/scim/filters.py rename to scim/filters.py index 699752be01..a971e7fe9f 100644 --- a/profiles/scim/filters.py +++ b/scim/filters.py @@ -2,16 +2,27 @@ from 
django_scim.filters import UserFilterQuery +from scim.parser.queries.sql import PatchedSQLQuery + class LearnUserFilterQuery(UserFilterQuery): """Filters for users""" + query_class = PatchedSQLQuery + attr_map: dict[tuple[Optional[str], Optional[str], Optional[str]], str] = { ("userName", None, None): "auth_user.username", + ("emails", "value", None): "auth_user.email", ("active", None, None): "auth_user.is_active", - ("name", "formatted", None): "profiles_profile.name", + ("fullName", None, None): "profiles_profile.name", + ("name", "givenName", None): "auth_user.first_name", + ("name", "familyName", None): "auth_user.last_name", } joins: tuple[str, ...] = ( "INNER JOIN profiles_profile ON profiles_profile.user_id = auth_user.id", ) + + @classmethod + def search(cls, filter_query, request=None): + return super().search(filter_query, request=request) diff --git a/scim/forms.py b/scim/forms.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/__init__.py b/scim/parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/queries/__init__.py b/scim/parser/queries/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/queries/sql.py b/scim/parser/queries/sql.py new file mode 100644 index 0000000000..d6d8fcde52 --- /dev/null +++ b/scim/parser/queries/sql.py @@ -0,0 +1,15 @@ +from scim2_filter_parser.lexer import SCIMLexer +from scim2_filter_parser.parser import SCIMParser +from scim2_filter_parser.queries.sql import SQLQuery + +from scim.parser.transpilers.sql import PatchedTranspiler + + +class PatchedSQLQuery(SQLQuery): + """Patched SQLQuery to use the patch transpiler""" + + def build_where_sql(self): + self.token_stream = SCIMLexer().tokenize(self.filter) + self.ast = SCIMParser().parse(self.token_stream) + self.transpiler = PatchedTranspiler(self.attr_map) + self.where_sql, self.params_dict = self.transpiler.transpile(self.ast) diff --git a/scim/parser/transpilers/__init__.py 
b/scim/parser/transpilers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scim/parser/transpilers/sql.py b/scim/parser/transpilers/sql.py new file mode 100644 index 0000000000..e095cce2b3 --- /dev/null +++ b/scim/parser/transpilers/sql.py @@ -0,0 +1,20 @@ +import string + +from scim2_filter_parser.transpilers.sql import Transpiler + + +class PatchedTranspiler(Transpiler): + """ + This is a fixed version of the upstream sql transpiler that converts SCIM queries + to SQL queries. + + Specifically it fixes the upper limit of 26 conditions for the search endpoint due + to the upstream library using the ascii alphabet for query parameters. + """ + + def get_next_id(self): + """Return a unique lowercase id for the next param (a..z, za..zz, zza...)""" + chars = string.ascii_lowercase + index = len(self.params) + + return (chars[-1] * (index // len(chars))) + chars[index % len(chars)] diff --git a/scim/urls.py b/scim/urls.py new file mode 100644 index 0000000000..88221d982e --- /dev/null +++ b/scim/urls.py @@ -0,0 +1,17 @@ +"""URL configurations for scim""" + +from django.urls import include, re_path + +from scim import views + +ol_scim_urls = ( + [ + re_path("^Bulk$", views.BulkView.as_view(), name="bulk"), + ], + "ol-scim", +) + +urlpatterns = [ + re_path("^scim/v2/", include(ol_scim_urls)), + re_path("^scim/v2/", include("django_scim.urls", namespace="scim")), +] diff --git a/scim/utils.py b/scim/utils.py new file mode 100644 index 0000000000..d6d0a4aee1 --- /dev/null +++ b/scim/utils.py @@ -0,0 +1,6 @@ +"""Utils""" + + +def is_authenticated_predicate(user): + """Verify that the user is active and staff""" + return user.is_authenticated and user.is_active and user.is_staff diff --git a/scim/views.py b/scim/views.py new file mode 100644 index 0000000000..72dcc09a65 --- /dev/null +++ b/scim/views.py @@ -0,0 +1,164 @@ +"""SCIM view customizations""" + +import copy +import json +import logging +from http import HTTPStatus +from urllib.parse import urlparse + +from 
django.http import HttpRequest, HttpResponse +from django.urls import Resolver404, resolve +from django_scim import constants as djs_constants +from django_scim import exceptions +from django_scim import views as djs_views + +from scim import constants + +log = logging.getLogger(__name__) + + +class InMemoryHttpRequest(HttpRequest): + """ + A spoofed HttpRequest that only exists in-memory. + It does not implement all features of HttpRequest and is only used + for the bulk SCIM operations here so we can reuse view implementations. + """ + + def __init__(self, request, path, method, body): + super().__init__() + + self.META = copy.deepcopy( + { + key: value + for key, value in request.META.items() + if not key.startswith(("wsgi", "uwsgi")) + } + ) + self.path = path + self.method = method + self.content_type = djs_constants.SCIM_CONTENT_TYPE + + # normally HttpRequest would read this in, but we already have the value + self._body = body + + +class BulkView(djs_views.SCIMView): + """Run each operation of a SCIM /Bulk request against the django_scim views""" + + http_method_names = ["post"] + + def post(self, request, *args, **kwargs): # noqa: ARG002 + """Validate a bulk request, run its operations and return a BulkResponse""" + body = self.load_body(request.body) + + if body.get("schemas") != [constants.SchemaURI.BULK_REQUEST]: + msg = "Invalid schema uri. Must be BulkRequest." + raise exceptions.BadRequestError(msg) + + fail_on_errors = body.get("failOnErrors", None) + + if fail_on_errors is not None and not isinstance(fail_on_errors, int): + msg = "Invalid failOnErrors. Must be an integer." 
+ raise exceptions.BadRequestError(msg) + + operations = body.get("Operations") + + results = self._attempt_operations(request, operations, fail_on_errors) + + response = { + "schemas": [constants.SchemaURI.BULK_RESPONSE], + "Operations": results, + } + + content = json.dumps(response) + + return HttpResponse( + content=content, + content_type=djs_constants.SCIM_CONTENT_TYPE, + status=HTTPStatus.OK, + ) + + def _attempt_operations(self, request, operations, fail_on_errors): + """Attempt to run the operations that were passed""" + responses = [] + num_errors = 0 + + for operation in operations: + # per-spec, if we've hit the error threshold stop processing and return + if fail_on_errors is not None and num_errors >= fail_on_errors: + break + + op_response = self._attempt_operation(request, operation) + + # if the operation returned a non-2xx status code, record it as a failure + if int(op_response.get("status")) >= HTTPStatus.MULTIPLE_CHOICES: + num_errors += 1 + + responses.append(op_response) + + return responses + + def _attempt_operation(self, bulk_request, operation): + """Attempt an operation as part of a bulk request""" + + method = operation.get("method") + bulk_id = operation.get("bulkId") + path = operation.get("path") + data = operation.get("data") + + try: + url_match = resolve(path, urlconf="django_scim.urls") + except Resolver404: + return self._operation_error( + method, + bulk_id, + HTTPStatus.NOT_IMPLEMENTED, + "Endpoint is not supported for /Bulk", + ) + + # this is an ephemeral request not tied to the real request directly + op_request = InMemoryHttpRequest( + bulk_request, path, method, json.dumps(data).encode(djs_constants.ENCODING) + ) + + op_response = url_match.func(op_request, *url_match.args, **url_match.kwargs) + result = { + "method": method, + "bulkId": bulk_id, + "status": str(op_response.status_code), + } + + location = None + + if op_response.status_code >= HTTPStatus.BAD_REQUEST and op_response.content: + result["response"] = 
json.loads(op_response.content.decode("utf-8")) + + location = op_response.headers.get("Location", None) + + if location is not None: + result["location"] = location + # this is a custom field that the scim-for-keycloak plugin requires + try: + path = urlparse(location).path + location_match = resolve(path) + # this URL will be something like /scim/v2/Users/12345 + # resolving it gives the uuid + result["id"] = location_match.kwargs["uuid"] + except Resolver404: + log.exception("Unable to resolve resource url: %s", location) + + return result + + def _operation_error(self, method, bulk_id, status_code, detail): + """Return a failure response""" + status_code = str(status_code) + return { + "method": method, + "status": status_code, + "bulkId": bulk_id, + "response": { + "schemas": [djs_constants.SchemaURI.ERROR], + "status": status_code, + "detail": detail, + }, + } diff --git a/scim/views_test.py b/scim/views_test.py new file mode 100644 index 0000000000..c1ee44f9db --- /dev/null +++ b/scim/views_test.py @@ -0,0 +1,415 @@ +import itertools +import json +import operator +import random +from collections.abc import Callable +from functools import reduce +from types import SimpleNamespace + +import pytest +from anys import ANY_STR +from deepmerge import always_merger +from django.contrib.auth import get_user_model +from django.test import Client +from django.urls import reverse +from django_scim import constants as djs_constants + +from main.factories import UserFactory +from scim import constants + +User = get_user_model() + + +@pytest.fixture +def scim_client(staff_user): + """Test client for scim""" + client = Client() + client.force_login(staff_user) + return client + + +def test_scim_user_post(scim_client): + """Test that we can create a user via SCIM API""" + user_q = User.objects.filter(profile__scim_external_id="1") + assert not user_q.exists() + + resp = scim_client.post( + reverse("scim:users"), + content_type="application/scim+json", + data=json.dumps( + { 
+ "schemas": [djs_constants.SchemaURI.USER], + "emails": [{"value": "jdoe@example.com", "primary": True}], + "active": True, + "userName": "jdoe", + "externalId": "1", + "name": { + "familyName": "Doe", + "givenName": "John", + }, + "fullName": "John Smith Doe", + "emailOptIn": 1, + } + ), + ) + + assert resp.status_code == 201, f"Error response: {resp.content}" + + user = user_q.first() + + assert user is not None + assert user.email == "jdoe@example.com" + assert user.username == "jdoe" + assert user.first_name == "John" + assert user.last_name == "Doe" + assert user.profile.name == "John Smith Doe" + assert user.profile.email_optin is True + + +def test_scim_user_put(scim_client): + """Test that a user can be updated via PUT""" + user = UserFactory.create() + + resp = scim_client.put( + f"{reverse('scim:users')}/{user.profile.scim_id}", + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [djs_constants.SchemaURI.USER], + "emails": [{"value": "jsmith@example.com", "primary": True}], + "active": True, + "userName": "jsmith", + "externalId": "1", + "name": { + "familyName": "Smith", + "givenName": "Jimmy", + }, + "fullName": "Jimmy Smith", + "emailOptIn": 0, + } + ), + ) + + assert resp.status_code == 200, f"Error response: {resp.content}" + + user.refresh_from_db() + + assert user.email == "jsmith@example.com" + assert user.username == "jsmith" + assert user.first_name == "Jimmy" + assert user.last_name == "Smith" + assert user.profile.name == "Jimmy Smith" + assert user.profile.email_optin is False + + +def test_scim_user_patch(scim_client): + """Test that a user can be updated via PATCH""" + user = UserFactory.create() + + resp = scim_client.patch( + f"{reverse('scim:users')}/{user.profile.scim_id}", + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [djs_constants.SchemaURI.PATCH_OP], + "Operations": [ + { + "op": "replace", + # yes, the value we get from scim-for-keycloak is a JSON encoded string...inside 
JSON... + "value": json.dumps( + { + "schemas": [djs_constants.SchemaURI.USER], + "emailOptIn": 1, + "fullName": "Billy Bob", + "name": { + "givenName": "Billy", + "familyName": "Bob", + }, + } + ), + } + ], + } + ), + ) + + assert resp.status_code == 200, f"Error response: {resp.content}" + + user_updated = User.objects.get(pk=user.id) + + assert user_updated.email == user.email + assert user_updated.username == user.username + assert user_updated.first_name == "Billy" + assert user_updated.last_name == "Bob" + assert user_updated.profile.name == "Billy Bob" + assert user_updated.profile.email_optin is True + + +def _user_to_scim_payload(user): + """Test util to serialize a user to a SCIM representation""" + return { + "schemas": [djs_constants.SchemaURI.USER], + "emails": [{"value": user.email, "primary": True}], + "userName": user.username, + "emailOptIn": 1 if user.profile.email_optin else 0, + "fullName": user.profile.name, + "name": { + "givenName": user.first_name, + "familyName": user.last_name, + }, + } + + +USER_FIELD_TYPES: dict[str, type] = { + "username": str, + "email": str, + "first_name": str, + "last_name": str, + "profile.email_optin": bool, + "profile.name": str, +} + +USER_FIELDS_TO_SCIM: dict[str, Callable] = { + "username": lambda value: {"userName": value}, + "email": lambda value: {"emails": [{"value": value, "primary": True}]}, + "first_name": lambda value: {"name": {"givenName": value}}, + "last_name": lambda value: {"name": {"familyName": value}}, + "profile.email_optin": lambda value: {"emailOptIn": 1 if value else 0}, + "profile.name": lambda value: {"fullName": value}, +} + + +def _post_operation(data, bulk_id_gen): + """Operation for a bulk POST""" + bulk_id = str(next(bulk_id_gen)) + return SimpleNamespace( + payload={ + "method": "post", + "bulkId": bulk_id, + "path": "/Users", + "data": _user_to_scim_payload(data), + }, + user=None, + expected_user_state=data, + expected_response={ + "method": "post", + "location": ANY_STR, + 
"bulkId": bulk_id, + "status": "201", + "id": ANY_STR, + }, + ) + + +def _put_operation(user, data, bulk_id_gen): + """Operation for a bulk PUT""" + bulk_id = str(next(bulk_id_gen)) + return SimpleNamespace( + payload={ + "method": "put", + "bulkId": bulk_id, + "path": f"/Users/{user.profile.scim_id}", + "data": _user_to_scim_payload(data), + }, + user=user, + expected_user_state=data, + expected_response={ + "method": "put", + "location": ANY_STR, + "bulkId": bulk_id, + "status": "200", + "id": str(user.profile.scim_id), + }, + ) + + +def _patch_operation(user, data, fields_to_patch, bulk_id_gen): + """Operation for a bulk PUT""" + + def _expected_patch_value(field): + field_getter = operator.attrgetter(field) + return field_getter(data if field in fields_to_patch else user) + + bulk_id = str(next(bulk_id_gen)) + field_updates = [ + mk_scim_value(operator.attrgetter(user_path)(data)) + for user_path, mk_scim_value in USER_FIELDS_TO_SCIM.items() + if user_path in fields_to_patch + ] + + return SimpleNamespace( + payload={ + "method": "patch", + "bulkId": bulk_id, + "path": f"/Users/{user.profile.scim_id}", + "data": { + "schemas": [djs_constants.SchemaURI.PATCH_OP], + "Operations": [ + { + "op": "replace", + "value": reduce(always_merger.merge, field_updates, {}), + } + ], + }, + }, + user=user, + expected_user_state=SimpleNamespace( + email=_expected_patch_value("email"), + username=_expected_patch_value("username"), + first_name=_expected_patch_value("first_name"), + last_name=_expected_patch_value("last_name"), + profile=SimpleNamespace( + name=_expected_patch_value("profile.name"), + email_optin=_expected_patch_value("profile.email_optin"), + ), + ), + expected_response={ + "method": "patch", + "location": ANY_STR, + "bulkId": bulk_id, + "status": "200", + "id": str(user.profile.scim_id), + }, + ) + + +def _delete_operation(user, bulk_id_gen): + """Operation for a bulk DELETE""" + bulk_id = str(next(bulk_id_gen)) + return SimpleNamespace( + payload={ + 
"method": "delete", + "bulkId": bulk_id, + "path": f"/Users/{user.profile.scim_id}", + }, + user=user, + expected_user_state=None, + expected_response={ + "method": "delete", + "bulkId": bulk_id, + "status": "204", + }, + ) + + +@pytest.fixture +def bulk_test_data(): + """Test data for the /Bulk API tests""" + existing_users = UserFactory.create_batch(500) + remaining_users = set(existing_users) + + users_to_put = random.sample(sorted(remaining_users, key=lambda user: user.id), 100) + remaining_users = remaining_users - set(users_to_put) + + users_to_patch = random.sample( + sorted(remaining_users, key=lambda user: user.id), 100 + ) + remaining_users = remaining_users - set(users_to_patch) + + users_to_delete = random.sample( + sorted(remaining_users, key=lambda user: user.id), 100 + ) + remaining_users = remaining_users - set(users_to_delete) + + user_post_data = UserFactory.build_batch(100) + user_put_data = UserFactory.build_batch(len(users_to_put)) + user_patch_data = UserFactory.build_batch(len(users_to_patch)) + + bulk_id_gen = itertools.count() + + post_operations = [_post_operation(data, bulk_id_gen) for data in user_post_data] + put_operations = [ + _put_operation(user, data, bulk_id_gen) + for user, data in zip(users_to_put, user_put_data) + ] + patch_operations = [ + _patch_operation(user, patch_data, fields_to_patch, bulk_id_gen) + for user, patch_data, fields_to_patch in [ + ( + user, + patch_data, + # random number of field updates + list( + random.sample( + list(USER_FIELDS_TO_SCIM.keys()), + random.randint(1, len(USER_FIELDS_TO_SCIM.keys())), # noqa: S311 + ) + ), + ) + for user, patch_data in zip(users_to_patch, user_patch_data) + ] + ] + delete_operations = [ + _delete_operation(user, bulk_id_gen) for user in users_to_delete + ] + + operations = [ + *post_operations, + *patch_operations, + *put_operations, + *delete_operations, + ] + random.shuffle(operations) + + return SimpleNamespace( + existing_users=existing_users, + 
remaining_users=remaining_users, + post_operations=post_operations, + patch_operations=patch_operations, + put_operations=put_operations, + delete_operations=delete_operations, + operations=operations, + ) + + +def test_bulk_post(scim_client, bulk_test_data): + """Verify that bulk operations work as expected""" + user_count = User.objects.count() + + resp = scim_client.post( + reverse("ol-scim:bulk"), + content_type="application/scim+json", + data=json.dumps( + { + "schemas": [constants.SchemaURI.BULK_REQUEST], + "Operations": [ + operation.payload for operation in bulk_test_data.operations + ], + } + ), + ) + + assert resp.status_code == 200 + + # singular user is the staff user + assert User.objects.count() == user_count + len(bulk_test_data.post_operations) + + results_by_bulk_id = { + op_result["bulkId"]: op_result for op_result in resp.json()["Operations"] + } + + for operation in bulk_test_data.operations: + assert ( + results_by_bulk_id[operation.payload["bulkId"]] + == operation.expected_response + ) + + if operation in bulk_test_data.delete_operations: + user = User.objects.get(id=operation.user.id) + assert not user.is_active + else: + if operation in bulk_test_data.post_operations: + user = User.objects.get(username=operation.expected_user_state.username) + else: + user = User.objects.get(id=operation.user.id) + + for key, key_type in USER_FIELD_TYPES.items(): + attr_getter = operator.attrgetter(key) + + actual_value = attr_getter(user) + expected_value = attr_getter(operation.expected_user_state) + + if key_type is bool or key_type is None: + assert actual_value is expected_value + else: + assert actual_value == expected_value