diff --git a/openapi.yaml b/openapi.yaml index 629bf406..b24a811c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -172,12 +172,20 @@ paths: description: Number of results to return per page. schema: type: integer + - in: query + name: name + schema: + type: string - name: offset required: false in: query description: The initial index from which to return the results. schema: type: integer + - in: query + name: system__slug + schema: + type: string tags: - product security: @@ -328,36 +336,18 @@ components: id: type: integer readOnly: true - deleted_on: - type: string - format: date-time - readOnly: true - nullable: true - deleted_by_cascade: - type: boolean - readOnly: true - created_on: - type: string - format: date-time - readOnly: true - updated_on: - type: string - format: date-time - readOnly: true name: type: string maxLength: 255 - description: + slug: type: string - api_key: + nullable: true + maxLength: 80 + description: type: string required: - - created_on - - deleted_by_cascade - - deleted_on - id - name - - updated_on PaginatedIntegratedSystemList: type: object properties: @@ -405,28 +395,14 @@ components: id: type: integer readOnly: true - deleted_on: - type: string - format: date-time - readOnly: true - nullable: true - deleted_by_cascade: - type: boolean - readOnly: true - created_on: - type: string - format: date-time - readOnly: true - updated_on: - type: string - format: date-time - readOnly: true name: type: string maxLength: 255 - description: + slug: type: string - api_key: + nullable: true + maxLength: 80 + description: type: string PatchedProduct: type: object diff --git a/system_meta/management/commands/generate_test_data.py b/system_meta/management/commands/generate_test_data.py new file mode 100644 index 00000000..7931b28e --- /dev/null +++ b/system_meta/management/commands/generate_test_data.py @@ -0,0 +1,234 @@ +""" +Adds some test data to the system. This includes three IntegratedSystems with three +products each. + +Ignoring A003 because "help" is valid for argparse. +Ignoring S311 because it's complaining about the faker package. +""" +# ruff: noqa: A003, S311 + +import random +import uuid +from decimal import Decimal + +import faker +from django.core.management import BaseCommand +from django.core.management.base import CommandParser +from django.db import transaction + +from system_meta.models import IntegratedSystem, Product + + +def get_input(text): + """Wrap the internal input function so we can test it later.""" + + return input(text) + + +def fake_courseware_id(courseware_type: str, **kwargs) -> str: + """ + Generate a fake courseware id. + + Courseware IDs generally are in the format: + -v1:+(+) + + Type is either "course" or "program", depending on what you specify. School ID is + one of "MITx", "MITxT", "edX", "xPRO", or "Sample". Courseware ID is a set of + numbers: a number < 100, a number < 1000 with a leading zero, and an optional + number < 10, separated by periods. Courseware ID is followed by an "x". This + should be pretty like the IDs that are on MITx Online now (but pretty unlike the + xPRO ones, which usually use a text courseware ID, but that's fine since these + are fake). + + Arguments: + - courseware_type (str): "course" or "program"; the type of + courseware id to generate. + + Keyword Arguments: + - include_run_tag (bool): include the run tag. Defaults to False. + + Returns: + - str: The generated courseware id, in the normal format. + """ + fake = faker.Faker() + + school_id = random.choice(["MITx", "MITxT", "edX", "xPRO", "Sample"]) + courseware_id = f"{random.randint(0, 99)}.{random.randint(0, 999):03d}" + courseware_type = courseware_type.lower() + optional_third_digit = random.randint(0, 9) if fake.boolean() else "" + optional_run_tag = ( + f"+{random.randint(1,3)}T{fake.date_this_decade().year}" + if kwargs.get("include_run_tag", False) + else "" + ) + + return ( + f"{courseware_type}-v1:{school_id}+{courseware_id}" + f"{optional_third_digit}x{optional_run_tag}" + ) + + +class Command(BaseCommand): + """Adds some test data to the system.""" + + def add_arguments(self, parser: CommandParser) -> None: + """Add arguments to the command parser.""" + parser.add_argument( + "--remove", + action="store_true", + help="Remove the test data. This is potentially dangerous.", + ) + + parser.add_argument( + "--only-systems", + action="store_true", + help="Only add test systems. Will add two active and one inactive system.", + ) + + parser.add_argument( + "--only-products", + action="store_true", + help="Only add test products.", + ) + + parser.add_argument( + "--system-slug", + type=str, + help=( + "The slug of the system to add products to." + " Only used with --only-products." + ), + nargs="?", + ) + + def add_test_systems(self) -> None: + """Add the test systems.""" + max_systems = 3 + for i in range(1, max_systems + 1): + system = IntegratedSystem.objects.create( + name=f"Test System {i}", + description=f"Test System {i} description.", + api_key=uuid.uuid4(), + ) + self.stdout.write(f"Created system {system.name} - {system.slug}") + + def add_test_products(self, system_slug: str) -> None: + """Add the test products to the specified system.""" + self.stdout.write(f"Creating test products for {system_slug}") + + if not IntegratedSystem.objects.filter(slug=system_slug).exists(): + self.stdout.write( + self.style.ERROR(f"Integrated system {system_slug} does not exist.") + ) + return + + system = IntegratedSystem.objects.get(slug=system_slug) + + max_products = 3 + for i in range(1, max_products + 1): + product_sku = fake_courseware_id("course", include_run_tag=True) + product = Product.objects.create( + name=f"Test Product {i}", + description=f"Test Product {i} description.", + sku=product_sku, + system=system, + price=Decimal(random.random() * 10000).quantize(Decimal("0.01")), + system_data={ + "courserun": product_sku, + "program": fake_courseware_id("program"), + }, + ) + self.stdout.write(f"Created product {product.id} - {product.sku}") + + def remove_test_data(self) -> None: + """Remove the test data.""" + + test_systems = ( + IntegratedSystem.all_objects.prefetch_related("products") + .filter(name__startswith="Test System") + .all() + ) + + self.stdout.write( + self.style.WARNING("This command will remove these systems and products:") + ) + + for system in test_systems: + self.stdout.write( + self.style.WARNING(f"System: {system.name} ({system.id})") + ) + + for product in system.products.all(): + self.stdout.write( + self.style.WARNING(f"\tProduct: {product.name} ({product.id})") + ) + + self.stdout.write( + self.style.WARNING( + "This will ACTUALLY DELETE these records." + " Are you sure you want to do this?" + ) + ) + + if get_input("Type 'yes' to continue: ") != "yes": + self.stdout.write(self.style.ERROR("Aborting.")) + return + + for system in test_systems: + Product.all_objects.filter( + pk__in=[product.id for product in system.products.all()] + ).delete() + IntegratedSystem.all_objects.filter(pk=system.id).delete() + + self.stdout.write(self.style.SUCCESS("Test data removed.")) + + def handle(self, *args, **options) -> None: # noqa: ARG002 + """Handle the command.""" + remove = options["remove"] + only_systems = options["only_systems"] + only_products = options["only_products"] + systems = [] + + with transaction.atomic(): + if remove: + self.remove_test_data() + return + + if not only_products: + self.add_test_systems() + systems = [ + system.slug + for system in ( + IntegratedSystem.all_objects.filter( + name__startswith="Test System" + ).all() + ) + ] + + if only_systems: + IntegratedSystem.objects.filter( + name__startswith="Test System" + ).last().delete() + return + + if only_products: + if not options["system_slug"] or len(options["system_slug"]) == 0: + self.stdout.write( + self.style.ERROR( + "You must specify a system when using --only-products." + ) + ) + return + else: + systems = [options["system_slug"]] + + self.stdout.write(f"we are creating products now {systems}") + + [self.add_test_products(system) for system in systems] + + if not only_products: + IntegratedSystem.objects.filter( + name__startswith="Test System" + ).last().delete() + + return diff --git a/system_meta/management/tests/generate_test_data_test.py b/system_meta/management/tests/generate_test_data_test.py new file mode 100644 index 00000000..9729fa3a --- /dev/null +++ b/system_meta/management/tests/generate_test_data_test.py @@ -0,0 +1,165 @@ +"""Tests for the manage_product command""" +# ruff: noqa: W605 + +import re +from io import StringIO + +import faker +import pytest +from django.core.management import call_command + +from system_meta.factories import IntegratedSystemFactory, ProductFactory +from system_meta.management.commands.generate_test_data import fake_courseware_id +from system_meta.models import IntegratedSystem, Product + +pytestmark = pytest.mark.django_db +FAKE = faker.Factory.create() + + +@pytest.mark.parametrize( + ("include_run_tag", "program_or_course"), + [(True, True), (True, False), (False, True), (False, False)], +) +def test_generate_fake_courseware_id(include_run_tag, program_or_course): + """ + Tests that generate_fake_courseware_id generates a courseware id in the proper format. + """ + + courseware_id = fake_courseware_id( + "program" if program_or_course else "course", include_run_tag=include_run_tag + ) + + assert ( + courseware_id.startswith("program-v1:") + if program_or_course + else courseware_id.startswith("course-v1:") + ) + + if program_or_course: + if include_run_tag: + courseware_re = re.compile( + r"program-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x\+\d{1,3}T\d{4}" + ) + else: + courseware_re = re.compile( + r"program-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x" + ) + else: # noqa: PLR5501 + if include_run_tag: + courseware_re = re.compile( + r"course-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x\+\d{1,3}T\d{4}" + ) + else: + courseware_re = re.compile( + r"course-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x" + ) + + assert courseware_re.match(courseware_id) + + +def test_add_only_systems(): + """Test that only-systems adds test systems with no products.""" + out = StringIO() + + call_command("generate_test_data", only_systems=True, stdout=out) + + assert ( + IntegratedSystem.all_objects.filter(name__startswith="Test System").count() == 3 + ) + assert IntegratedSystem.objects.filter(name__startswith="Test System").count() == 2 + + +@pytest.mark.parametrize("system_type", ["normal", "skip", "wrong"]) +def test_add_only_products(system_type): + """ + Test that only-product adds products for the specified system. + + Args: + - system_type (str): type of system to try + normal is make a system normally (i.e. run the command correctly) + skip is don't make a system at all (should generate a specific error) + wrong is make a system, but specify something incorrect + """ + out = StringIO() + + if system_type == "skip": + call_command("generate_test_data", only_products=True, stdout=out) + + assert "You must specify a system" in out.getvalue() + elif system_type == "wrong": + bad_slug = "some nonsense" + call_command( + "generate_test_data", only_products=True, system_slug=bad_slug, stdout=out + ) + + assert f"Integrated system {bad_slug} does not exist" in out.getvalue() + else: + system = IntegratedSystemFactory.create() + + call_command( + "generate_test_data", + only_products=True, + system_slug=system.slug, + stdout=out, + ) + + assert "Created product" in out.getvalue() + + assert ( + Product.objects.filter( + name__startswith="Test Product", system=system + ).count() + == 3 + ) + + +def test_add_all(): + """Test that just running the command generates expected output.""" + + call_command("generate_test_data") + + assert ( + IntegratedSystem.all_objects.filter(name__startswith="Test System").count() == 3 + ) + assert IntegratedSystem.objects.filter(name__startswith="Test System").count() == 2 + + for system in IntegratedSystem.all_objects.all(): + assert ( + Product.all_objects.filter( + system=system, name__startswith="Test Product" + ).count() + == 3 + ) + + +def test_remove_test_data(mocker): + """Test that the remove test data command does the best it can to just remove test data.""" + + input_mock = mocker.patch( + "system_meta.management.commands.generate_test_data.get_input", + return_value="yes", + ) + out = StringIO() + + ProductFactory.create_batch(3) + before_system_count = IntegratedSystem.all_objects.count() + before_product_count = Product.all_objects.count() + + call_command("generate_test_data") + + assert before_system_count < IntegratedSystem.all_objects.count() + assert before_product_count < Product.all_objects.count() + + call_command("generate_test_data", "--remove", stdout=out) + + assert "Test data removed" in out.getvalue() + input_mock.assert_called() + + # Checking these two ways for a reason: + # - Make sure the test data wasn't soft deleted + # - Make sure it didn't soft-delete any of the products/systems we created + + assert IntegratedSystem.all_objects.count() == before_system_count + assert Product.all_objects.count() == before_product_count + assert IntegratedSystem.objects.count() == before_system_count + assert Product.objects.count() == before_product_count diff --git a/system_meta/migrations/0002_add_system_slug.py b/system_meta/migrations/0002_add_system_slug.py new file mode 100644 index 00000000..779b7331 --- /dev/null +++ b/system_meta/migrations/0002_add_system_slug.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.8 on 2024-01-11 16:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("system_meta", "0001_add_integrated_system_and_product_models"), + ] + + operations = [ + migrations.AddField( + model_name="integratedsystem", + name="slug", + field=models.CharField(blank=True, max_length=80, null=True, unique=True), + ), + ] diff --git a/system_meta/models.py b/system_meta/models.py index f2a17e1d..7d37c83d 100644 --- a/system_meta/models.py +++ b/system_meta/models.py @@ -8,6 +8,7 @@ from mitol.common.models import TimestampedModel from safedelete.managers import SafeDeleteManager from safedelete.models import SafeDeleteModel +from slugify import slugify from unified_ecommerce.utils import SoftDeleteActiveModel @@ -19,6 +20,7 @@ class IntegratedSystem(SafeDeleteModel, SoftDeleteActiveModel, TimestampedModel) """Represents an integrated system""" name = models.CharField(max_length=255, unique=True) + slug = models.CharField(max_length=80, unique=True, blank=True, null=True) description = models.TextField(blank=True) api_key = models.TextField(blank=True) @@ -29,6 +31,12 @@ def __str__(self): """Return string representation of the system""" return f"{self.name} ({self.id})" + def save(self, *args, **kwargs): + """Save the product. Create a slug if it doesn't already exist.""" + if not self.slug: + self.slug = slugify(self.name) + super().save(*args, **kwargs) + @reversion.register(exclude=("created_on", "updated_on")) class Product(SafeDeleteModel, SoftDeleteActiveModel, TimestampedModel): diff --git a/system_meta/serializers.py b/system_meta/serializers.py index 20dc2d51..256285c9 100644 --- a/system_meta/serializers.py +++ b/system_meta/serializers.py @@ -8,6 +8,16 @@ class IntegratedSystemSerializer(serializers.ModelSerializer): """Serializer for IntegratedSystem model.""" + class Meta: + """Meta class for serializer.""" + + model = IntegratedSystem + fields = ["id", "name", "slug", "description"] + + +class AdminIntegratedSystemSerializer(serializers.ModelSerializer): + """Serializer for IntegratedSystem model.""" + class Meta: """Meta class for serializer.""" diff --git a/system_meta/serializers_test.py b/system_meta/serializers_test.py index 2560aa79..c9afa24b 100644 --- a/system_meta/serializers_test.py +++ b/system_meta/serializers_test.py @@ -4,12 +4,25 @@ from system_meta.factories import IntegratedSystemFactory, ProductFactory from system_meta.models import IntegratedSystem, Product -from system_meta.serializers import IntegratedSystemSerializer, ProductSerializer +from system_meta.serializers import ( + AdminIntegratedSystemSerializer, + IntegratedSystemSerializer, + ProductSerializer, +) from unified_ecommerce.test_utils import BaseSerializerTest pytestmark = pytest.mark.django_db +class TestAdminIntegratedSystemSerializer(BaseSerializerTest): + """Tests for the IntegratedSystemSerializer.""" + + serializer_class = AdminIntegratedSystemSerializer + factory_class = IntegratedSystemFactory + model_class = IntegratedSystem + queryset = IntegratedSystem.all_objects + + class TestIntegratedSystemSerializer(BaseSerializerTest): """Tests for the IntegratedSystemSerializer.""" @@ -17,6 +30,7 @@ class TestIntegratedSystemSerializer(BaseSerializerTest): factory_class = IntegratedSystemFactory model_class = IntegratedSystem queryset = IntegratedSystem.all_objects + only_fields = ["id", "name", "slug", "description"] class TestProductSerializer(BaseSerializerTest): diff --git a/system_meta/views.py b/system_meta/views.py index cbda5d83..d5087045 100644 --- a/system_meta/views.py +++ b/system_meta/views.py @@ -2,7 +2,8 @@ import logging -from rest_framework import status, viewsets +from django_filters.rest_framework import DjangoFilterBackend +from rest_framework import status from rest_framework.decorators import ( api_view, authentication_classes, @@ -14,29 +15,51 @@ from rest_framework.response import Response from system_meta.models import IntegratedSystem, Product -from system_meta.serializers import IntegratedSystemSerializer, ProductSerializer +from system_meta.serializers import ( + AdminIntegratedSystemSerializer, + IntegratedSystemSerializer, + ProductSerializer, +) from unified_ecommerce.authentication import ( ApiGatewayAuthentication, ) +from unified_ecommerce.permissions import ( + IsAdminUserOrReadOnly, +) from unified_ecommerce.utils import decode_x_header +from unified_ecommerce.viewsets import AuthVariegatedModelViewSet log = logging.getLogger(__name__) -class IntegratedSystemViewSet(viewsets.ModelViewSet): +class IntegratedSystemViewSet(AuthVariegatedModelViewSet): """Viewset for IntegratedSystem model.""" queryset = IntegratedSystem.objects.all() - serializer_class = IntegratedSystemSerializer - permission_classes = (IsAuthenticated,) + read_write_serializer_class = AdminIntegratedSystemSerializer + read_only_serializer_class = IntegratedSystemSerializer + permission_classes = [ + IsAdminUserOrReadOnly, + ] -class ProductViewSet(viewsets.ModelViewSet): +class ProductViewSet(AuthVariegatedModelViewSet): """Viewset for Product model.""" queryset = Product.objects.all() serializer_class = ProductSerializer - permission_classes = (IsAuthenticated,) + permission_classes = [ + IsAdminUserOrReadOnly, + ] + read_write_serializer_class = ProductSerializer + read_only_serializer_class = ProductSerializer + filter_backends = [ + DjangoFilterBackend, + ] + filterset_fields = [ + "name", + "system__slug", + ] @api_view(["GET"]) diff --git a/system_meta/views_test.py b/system_meta/views_test.py index aa249b1d..92ee4cf6 100644 --- a/system_meta/views_test.py +++ b/system_meta/views_test.py @@ -10,25 +10,50 @@ IntegratedSystemFactory, ) from system_meta.models import IntegratedSystem, Product +from system_meta.serializers import ( + AdminIntegratedSystemSerializer, + IntegratedSystemSerializer, + ProductSerializer, +) from system_meta.views import IntegratedSystemViewSet, ProductViewSet -from unified_ecommerce.test_utils import BaseViewSetTest +from unified_ecommerce.test_utils import AuthVariegatedModelViewSetTest pytestmark = pytest.mark.django_db -class TestIntegratedSystemViewSet(BaseViewSetTest): +class TestIntegratedSystemViewSet(AuthVariegatedModelViewSetTest): """Tests for the IntegratedSystemViewSet.""" viewset_class = IntegratedSystemViewSet factory_class = ActiveIntegratedSystemFactory - queryset = IntegratedSystem.objects.all() + queryset = IntegratedSystem.all_objects.all() list_url = "/api/v0/meta/integrated_system/" object_url = "/api/v0/meta/integrated_system/{}/" - @pytest.mark.parametrize("is_logged_in", [True, False]) - @pytest.mark.parametrize("is_active_system", [True, False]) - def test_retrieve(self, is_logged_in, is_active_system, client, user_client): + read_only_serializer_class = IntegratedSystemSerializer + read_write_serializer_class = AdminIntegratedSystemSerializer + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_system"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_retrieve( # noqa: PLR0913 + self, + is_active_system, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """ Test that the viewset can retrieve an object that is either active or possibly inactive. @@ -40,11 +65,30 @@ def test_retrieve(self, is_logged_in, is_active_system, client, user_client): else InactiveIntegratedSystemFactory ) - super().test_retrieve(is_logged_in, client, user_client) + super().test_retrieve( + is_logged_in, use_staff_user, client, user_client, staff_client + ) - @pytest.mark.parametrize("is_active_system", [True, False]) - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_update(self, is_active_system, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_system"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_update( # noqa: PLR0913 + self, + is_active_system, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can update an object.""" self.factory_class = ( ActiveIntegratedSystemFactory @@ -54,10 +98,10 @@ def test_update(self, is_active_system, is_logged_in, client, user_client): update_data = {"name": "Updated Name"} (instance, response) = super().test_update( - update_data, is_logged_in, client, user_client + update_data, is_logged_in, use_staff_user, client, user_client, staff_client ) - if is_logged_in: + if is_logged_in and use_staff_user: if not is_active_system: assert instance.name != update_data["name"] assert response.status_code == 404 @@ -68,21 +112,73 @@ def test_update(self, is_active_system, is_logged_in, client, user_client): else: assert instance.name != update_data["name"] - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_delete(self, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) + def test_delete( # noqa: PLR0913 + self, is_logged_in, use_staff_user, client, user_client, staff_client + ): """Test that the viewset can delete an object.""" - (instance, response) = super().test_delete(is_logged_in, client, user_client) - if is_logged_in: + self.queryset = IntegratedSystem.objects.all() + (instance, response) = super().test_delete( + is_logged_in, use_staff_user, client, user_client, staff_client + ) + + if is_logged_in and use_staff_user: instance.refresh_from_db() assert response.status_code == 204 assert not instance.is_active else: assert instance.is_active - @pytest.mark.parametrize("is_logged_in", [True, False]) + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) @pytest.mark.parametrize("with_bad_data", [True, False]) - def test_create(self, with_bad_data, is_logged_in, client, user_client): + def test_create( # noqa: PLR0913 + self, + with_bad_data, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can create an object.""" create_data = { "description": "a description", @@ -92,20 +188,21 @@ def test_create(self, with_bad_data, is_logged_in, client, user_client): if not with_bad_data: create_data["name"] = "System Name" - response = super().test_create(create_data, is_logged_in, client, user_client) + response = super().test_create( + create_data, is_logged_in, use_staff_user, client, user_client, staff_client + ) - if is_logged_in: + if is_logged_in and use_staff_user: assert response.status_code == 201 if not with_bad_data else 400 assert ( IntegratedSystem.objects.filter(name="System Name").exists() is not with_bad_data ) else: - assert response.data["detail"].code == "not_authenticated" assert response.status_code == 403 -class TestProductViewSet(BaseViewSetTest): +class TestProductViewSet(AuthVariegatedModelViewSetTest): """Tests for the ProductViewSet.""" viewset_class = ProductViewSet @@ -115,9 +212,29 @@ class TestProductViewSet(BaseViewSetTest): list_url = "/api/v0/meta/product/" object_url = "/api/v0/meta/product/{}/" - @pytest.mark.parametrize("is_logged_in", [True, False]) - @pytest.mark.parametrize("is_active_product", [True, False]) - def test_retrieve(self, is_logged_in, is_active_product, client, user_client): + read_only_serializer_class = ProductSerializer + read_write_serializer_class = ProductSerializer + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_product"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_retrieve( # noqa: PLR0913 + self, + is_active_product, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """ Test that the viewset can retrieve an object that is either active or possibly inactive. @@ -127,11 +244,30 @@ def test_retrieve(self, is_logged_in, is_active_product, client, user_client): ActiveProductFactory if is_active_product else InactiveProductFactory ) - super().test_retrieve(is_logged_in, client, user_client) + super().test_retrieve( + is_logged_in, use_staff_user, client, user_client, staff_client + ) - @pytest.mark.parametrize("is_active_product", [True, False]) - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_update(self, is_active_product, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_product"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_update( # noqa: PLR0913 + self, + is_active_product, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can update an object.""" self.factory_class = ( ActiveProductFactory if is_active_product else InactiveProductFactory @@ -139,10 +275,10 @@ def test_update(self, is_active_product, is_logged_in, client, user_client): update_data = {"name": "Updated Name"} (instance, response) = super().test_update( - update_data, is_logged_in, client, user_client + update_data, is_logged_in, use_staff_user, client, user_client, staff_client ) - if is_logged_in: + if is_logged_in and use_staff_user: if not is_active_product: assert instance.name != update_data["name"] assert response.status_code == 404 @@ -153,21 +289,72 @@ def test_update(self, is_active_product, is_logged_in, client, user_client): else: assert instance.name != update_data["name"] - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_delete(self, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) + def test_delete( # noqa: PLR0913 + self, is_logged_in, use_staff_user, client, user_client, staff_client + ): """Test that the viewset can delete an object.""" - (instance, response) = super().test_delete(is_logged_in, client, user_client) + self.queryset = Product.objects.all() + (instance, response) = super().test_delete( + is_logged_in, use_staff_user, client, user_client, staff_client + ) - if is_logged_in: + if is_logged_in and use_staff_user: instance.refresh_from_db() assert response.status_code == 204 assert not instance.is_active else: assert instance.is_active - @pytest.mark.parametrize("is_logged_in", [True, False]) + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) @pytest.mark.parametrize("with_bad_data", [True, False]) - def test_create(self, with_bad_data, is_logged_in, client, user_client): + def test_create( # noqa: PLR0913 + self, + with_bad_data, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can create an object.""" system = IntegratedSystemFactory.create() create_data = { @@ -180,11 +367,12 @@ def test_create(self, with_bad_data, is_logged_in, client, user_client): if not with_bad_data: create_data["description"] = "a description" - response = super().test_create(create_data, is_logged_in, client, user_client) + response = super().test_create( + create_data, is_logged_in, use_staff_user, client, user_client, staff_client + ) - if is_logged_in: + if is_logged_in and use_staff_user: assert response.status_code == 201 if not with_bad_data else 400 assert Product.objects.filter(name="New Name").exists() is not with_bad_data else: - assert response.data["detail"].code == "not_authenticated" assert response.status_code == 403 diff --git a/unified_ecommerce/permissions.py b/unified_ecommerce/permissions.py new file mode 100644 index 00000000..1c30574b --- /dev/null +++ b/unified_ecommerce/permissions.py @@ -0,0 +1,20 @@ +"""Custom DRF permissions.""" + +from rest_framework import permissions + + +class IsAdminUserOrReadOnly(permissions.BasePermission): + """Determines if the user owns the object""" + + def has_permission(self, request, view): # noqa: ARG002 + """ + Return True if the user is an admin user requesting a write operation, + or if the user is logged in. Otherwise, return False. + """ + + if request.method in permissions.SAFE_METHODS or ( + request.user.is_authenticated and request.user.is_staff + ): + return True + + return False diff --git a/unified_ecommerce/test_utils.py b/unified_ecommerce/test_utils.py index ef19d2f2..1631814e 100644 --- a/unified_ecommerce/test_utils.py +++ b/unified_ecommerce/test_utils.py @@ -13,6 +13,7 @@ from django.conf import settings from django.core.serializers import serialize from django.core.serializers.json import DjangoJSONEncoder +from django.http import HttpRequest from django.http.response import HttpResponse from rest_framework.renderers import JSONRenderer @@ -213,6 +214,27 @@ def make_timestamps_matchable(objs, **kwargs): ] +def generate_mocked_request(user): + """ + Generate a mocked request for test_process_cybersource_*. + + The RequestFactory misses some stuff, so instead just make a full-fat + HttpRequest and add the things in that we need. + + Args: + - user (User): The user to set in the request. + + Returns: + - HttpRequest: The mocked request. + """ + mocked_request = HttpRequest() + mocked_request.user = user + mocked_request.META["REMOTE_ADDR"] = "127.0.0.1" + mocked_request.META["HTTP_HOST"] = "localhost" + + return mocked_request + + class PickleableMock(Mock): """ A Mock that can be passed to pickle.dumps() @@ -232,12 +254,22 @@ class ViewSetNotConfiguredError(Exception): class BaseSerializerTest: - """Base class for serializer tests.""" + """ + Base class for serializer tests. + + Class variables: + - model_class (class): the model class to test + - serializer_class (class): the serializer class to test + - factory_class (class): the factory class to use for creating instances + - queryset (QuerySet): the queryset to use, if you need a custom one + - only_fields (list): a list of field names that should be included in the serialized output + """ model_class = None serializer_class = None factory_class = None queryset = None + only_fields = None def test_serialize(self): """Test that the serializer can serialize an instance.""" @@ -249,7 +281,11 @@ def test_serialize(self): instance_qs = self.model_class.objects.filter(pk=instance.pk) serializer = self.serializer_class(instance) - dj_serializer = queryset_to_json(instance_qs) + + if self.only_fields: + dj_serializer = instance_qs.values(*self.only_fields).get() + else: + dj_serializer = queryset_to_json(instance_qs) assert_json_equal(*make_timestamps_matchable([serializer.data, dj_serializer])) @@ -281,6 +317,9 @@ def _test_retrieval(self, api_client, url, url_name, **kwargs): """ Test that hitting the specified URL works with the specified client. + You still need to test for the proper response code; this will just + check for 500s and 403. + Args: - api_client (APIClient): the client to use - url (str): the URL to test with @@ -298,7 +337,8 @@ def _test_retrieval(self, api_client, url, url_name, **kwargs): response = api_client.get(url) assert response.status_code < 500 - assert response.status_code == 403 if kwargs["test_non_authenticated"] else 200 + if kwargs["test_non_authenticated"]: + assert response.status_code == 403 return response def test_get_queryset(self): @@ -441,3 +481,243 @@ def test_create(self, create_data, is_logged_in, client, user_client): assert response.status_code == 403 return response + + +class AuthVariegatedModelViewSetTest(BaseViewSetTest): + """ + Extends the BaseViewSetTest class to add support for auth variegated + viewsets. These viewsets use different serializers based on the user's + session. + + Set read_only_viewset_class and read_write_viewset_class to the appropriate + values for your viewset. + """ + + read_only_serializer_class = None + read_write_serializer_class = None + + @pytest.mark.parametrize("is_logged_in", [True, False]) + def test_get_serializer_class(self, is_logged_in, user): + """Test the viewset returns the right serializer class.""" + + viewset = self.viewset_class() + + if is_logged_in: + # Forcing staff/superuser to get the read-write serializer. + user.is_staff = True + user.is_superuser = True + viewset.request = generate_mocked_request(user) + assert viewset.get_serializer_class() == self.read_write_serializer_class + else: + assert viewset.get_serializer_class() == self.read_only_serializer_class + + def _test_retrieval(self, api_client, url, url_name, **kwargs): # noqa: ARG002 + """ + Test that hitting the specified URL works with the specified client. + This will expect a 200 regardless of what you're requesting as you'll + get a different result set and not a denial for these viewsets. + + noqa ARG002 so the API matches the parent class. + + Args: + - api_client (APIClient): the client to use + - url (str): the URL to test with + - url_name (str): the name of the URL (used for the skip message if it's not defined) + + Keyword Args: + - test_non_authenticated (bool): whether or not this should expect a 403 + + Returns: + - response (Response): the response from the API + """ + if not url: + exception_string = f"{url_name} is not defined" + raise ViewSetNotConfiguredError(exception_string) + + response = api_client.get(url) + assert response.status_code < 500 + return response + + def _determine_client_wrapper( # noqa: PLR0913 + self, is_logged_in, use_staff_user, client, user_client, staff_client + ): + """ + Determine the client to use for the test. + + Args: + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - client (APIClient): the client to use for the test + """ + if is_logged_in: + if use_staff_user: + return staff_client + else: + return user_client + else: + return client + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_retrieve(self, *args): + """ + Test that the viewset can retrieve an object, and returns the correct + dataset depending on the status of the user. + + If the user is anonymous _or_ a regular user, we should receive a + dataset that matches the read-only serializer class. Otherwise, we + should receive a dataset that matches the read-write serializer class. + + Args: + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + """ + instance = self.factory_class() + + is_logged_in = args[0] + use_staff_user = args[1] + + use_client = self._determine_client_wrapper(*args) + + response = self._test_retrieval( + use_client, + self.object_url.format(instance.pk), + "object_url", + ) + + if instance.is_active: + if is_logged_in and use_staff_user: + dj_serializer = self.read_write_serializer_class(instance).data + else: + dj_serializer = self.read_only_serializer_class(instance).data + + assert_json_equal( + *make_timestamps_matchable([response.data, dj_serializer]) + ) + + if not instance.is_active: + # Deleted items should return a 404. + assert response.status_code == 404 + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_update(self, update_data, *args): + """ + Test that the viewset can update an object. + + Args: + - update_data (dict): the data to use for the update + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - tuple of (instance, response): the instance that was updated and the response + """ + instance = self.factory_class() + + use_client = self._determine_client_wrapper(*args) + + is_logged_in = args[0] + use_staff_user = args[1] + + response = use_client.patch( + self.object_url.format(instance.pk), data=update_data + ) + + assert response.status_code < 500 + + if not is_logged_in or not use_staff_user: + assert response.status_code == 403 + + return (instance, response) + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_delete(self, *args): + """ + Test that the viewset can delete an object. Note that this will actually test + deletion. + + Args: + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - tuple of (instance, response): the instance that was deleted and the response + """ + instance = self.factory_class() + + before_count = self.queryset.count() + + use_client = self._determine_client_wrapper(*args) + is_logged_in = args[0] + use_staff_user = args[1] + + response = use_client.delete(self.object_url.format(instance.pk)) + + assert response.status_code < 500 + assert ( + response.status_code == 403 + if not is_logged_in or not use_staff_user + else 204 + ) + + assert ( + self.queryset.count() == before_count - 1 + if is_logged_in and use_staff_user + else before_count + ) + + return (instance, response) + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_create(self, create_data, *args): + """ + Test that the viewset can create an object. + + Args: + - create_data (dict): the data to use for the update + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - response (Response): the response from the API + """ + use_client = self._determine_client_wrapper(*args) + is_logged_in = args[0] + use_staff_user = args[1] + + response = use_client.post(self.list_url, data=create_data) + + assert response.status_code < 500 + + if not is_logged_in or not use_staff_user: + assert response.status_code == 403 + + return response diff --git a/unified_ecommerce/viewsets.py b/unified_ecommerce/viewsets.py new file mode 100644 index 00000000..32de394e --- /dev/null +++ b/unified_ecommerce/viewsets.py @@ -0,0 +1,39 @@ +"""Common viewsets for Unified Ecommerce.""" + +import logging + +from rest_framework import viewsets + +log = logging.getLogger(__name__) + + +class AuthVariegatedModelViewSet(viewsets.ModelViewSet): + """ + Viewset with customizable serializer based on user authentication. + + This bifurcates the ModelViewSet so that if the user is a read-only user (i.e. + not a staff or superuser, or not logged in), they get a separate "read-only" + serializer. Otherwise, we use a regular serializer. The read-only serializer can + then have different fields so you can hide irrelevant data from anonymous users. + + You will need to enforce the read-onlyness of the API yourself; use something like + the IsAuthenticatedOrReadOnly permission class or do something in the serializer. + + Set read_write_serializer_class to the serializer you want to use for admins and + set read_only_serializer_class to the one for regular users. + """ + + read_write_serializer_class = None + read_only_serializer_class = None + + def get_serializer_class(self): + """Get the serializer class for the route.""" + + if hasattr(self, "request") and ( + self.request.user.is_staff or self.request.user.is_superuser + ): + log.debug("get_serializer_class returning the Admin one") + return self.read_write_serializer_class + + log.debug("get_serializer_class returning the regular one") + return self.read_only_serializer_class