From 6ef13ad0d1f09a6818fe6c0605a95c7a170b14a3 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Thu, 29 Feb 2024 07:43:56 -0600 Subject: [PATCH 1/8] Adds the auth variegated viewset and adds system slug to the integrated system --- .../management/commands/generate_test_data.py | 218 ++++++++++++++++++ .../migrations/0002_add_system_slug.py | 17 ++ unified_ecommerce/permissions.py | 20 ++ unified_ecommerce/viewsets.py | 39 ++++ 4 files changed, 294 insertions(+) create mode 100644 system_meta/management/commands/generate_test_data.py create mode 100644 system_meta/migrations/0002_add_system_slug.py create mode 100644 unified_ecommerce/permissions.py create mode 100644 unified_ecommerce/viewsets.py diff --git a/system_meta/management/commands/generate_test_data.py b/system_meta/management/commands/generate_test_data.py new file mode 100644 index 00000000..ea4fa59b --- /dev/null +++ b/system_meta/management/commands/generate_test_data.py @@ -0,0 +1,218 @@ +""" +Adds some test data to the system. This includes three IntegratedSystems with three +products each. + +Ignoring A003 because "help" is valid for argparse. +Ignoring S311 because it's complaining about the faker package. +""" +# ruff: noqa: A003, S311 + +import random +import uuid +from decimal import Decimal + +import faker +from django.core.management import BaseCommand +from django.core.management.base import CommandParser +from django.db import transaction + +from system_meta.models import IntegratedSystem, Product + + +def fake_courseware_id(courseware_type: str, **kwargs) -> str: + """ + Generate a fake courseware id. + + Courseware IDs generally are in the format: + -v1:+(+) + + Type is either "course" or "program", depending on what you specify. School ID is + one of "MITx", "MITxT", "edX", "xPRO", or "Sample". Courseware ID is a set of + numbers: a number < 100, a number < 1000 with a leading zero, and an optional + number < 10, separated by periods. Courseware ID is followed by an "x". This + should be pretty like the IDs that are on MITx Online now (but pretty unlike the + xPRO ones, which usually use a text courseware ID, but that's fine since these + are fake). + + Arguments: + - courseware_type (str): "course" or "program"; the type of + courseware id to generate. + + Keyword Arguments: + - include_run_tag (bool): include the run tag. Defaults to False. + + Returns: + - str: The generated courseware id, in the normal format. + """ + fake = faker.Faker() + + school_id = random.choice(["MITx", "MITxT", "edX", "xPRO", "Sample"]) + courseware_id = f"{random.randint(0, 99)}.{random.randint(0, 999):03d}" + courseware_type = courseware_type.lower() + optional_third_digit = random.randint(0, 9) if fake.boolean() else "" + optional_run_tag = ( + f"+{random.randint(1,3)}T{fake.date_this_decade().year}" + if kwargs["include_run_tag"] + else "" + ) + + return ( + f"{courseware_type}-v1:{school_id}+{courseware_id}" + f"{optional_third_digit}x{optional_run_tag}" + ) + + +class Command(BaseCommand): + """Adds some test data to the system.""" + + def add_arguments(self, parser: CommandParser) -> None: + """Add arguments to the command parser.""" + parser.add_argument( + "--remove", + action="store_true", + help="Remove the test data. This is potentially dangerous.", + ) + + parser.add_argument( + "--only-systems", + action="store_true", + help="Only add test systems.", + ) + + parser.add_argument( + "--only-products", + action="store_true", + help="Only add test products.", + ) + + parser.add_argument( + "--system", + type=str, + help=( + "The name of the system to add products to." + " Only used with --only-products." + ), + nargs="?", + ) + + def add_test_systems(self) -> None: + """Add the test systems.""" + max_systems = 3 + for i in range(1, max_systems + 1): + IntegratedSystem.objects.create( + name=f"Test System {i}", + description=f"Test System {i} description.", + is_active=True, + api_key=uuid.uuid4(), + ) + + def add_test_products(self, system: str) -> None: + """Add the test products to the specified system.""" + + if not IntegratedSystem.objects.filter(name=system).exists(): + self.stdout.write( + self.style.ERROR(f"Integrated system {system} does not exist.") + ) + return + + system = IntegratedSystem.objects.get(name=system) + + for i in range(1, 4): + product_sku = fake_courseware_id("course", include_run_tag=True) + Product.objects.create( + name=f"Test Product {i}", + description=f"Test Product {i} description.", + sku=product_sku, + system=system, + is_active=True, + price=Decimal(random.random() * 10000).quantize(Decimal("0.01")), + system_data={ + "courserun": product_sku, + "program": fake_courseware_id("program"), + }, + ) + + def remove_test_data(self) -> None: + """Remove the test data.""" + + test_systems = ( + IntegratedSystem.all_objects.prefetch_related("products") + .filter(name__startswith="Test System") + .all() + ) + + self.stdout.write( + self.style.WARNING("This command will remove these systems and products:") + ) + + for system in test_systems: + self.stdout.write( + self.style.WARNING(f"System: {system.name} ({system.id})") + ) + + for product in system.products.all(): + self.stdout.write( + self.style.WARNING(f"\tProduct: {product.name} ({product.id})") + ) + + self.stdout.write( + self.style.WARNING( + "This will ACTUALLY DELETE these records." + " Are you sure you want to do this?" + ) + ) + + if input("Type 'yes' to continue: ") != "yes": + self.stdout.write(self.style.ERROR("Aborting.")) + return + + for system in test_systems: + Product.all_objects.filter( + pk__in=[product.id for product in system.products.all()] + ).delete() + IntegratedSystem.all_objects.filter(pk=system.id).delete() + + self.stdout.write(self.style.SUCCESS("Test data removed.")) + + def handle(self, *args, **options) -> None: # noqa: ARG002 + """Handle the command.""" + remove = options["remove"] + only_systems = options["only_systems"] + only_products = options["only_products"] + systems = [options["system"]] if options["system"] else [] + + with transaction.atomic(): + if remove: + self.remove_test_data() + return + + if not only_products: + self.add_test_systems() + + if not only_systems: + if only_products and len(systems) == 0: + self.stdout.write( + self.style.ERROR( + "You must specify a system when using --only-products." + ) + ) + return + else: + systems = [ + system.name + for system in ( + IntegratedSystem.all_objects.filter( + name__startswith="Test System" + ).all() + ) + ] + + [self.add_test_products(system) for system in systems] + return + + if not only_products: + third_test_system = IntegratedSystem.all_objects.filter( + name__startswith="Test System" + ).get() + third_test_system.is_active = False + third_test_system.save(update_fields=("is_active",)) diff --git a/system_meta/migrations/0002_add_system_slug.py b/system_meta/migrations/0002_add_system_slug.py new file mode 100644 index 00000000..779b7331 --- /dev/null +++ b/system_meta/migrations/0002_add_system_slug.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.8 on 2024-01-11 16:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("system_meta", "0001_add_integrated_system_and_product_models"), + ] + + operations = [ + migrations.AddField( + model_name="integratedsystem", + name="slug", + field=models.CharField(blank=True, max_length=80, null=True, unique=True), + ), + ] diff --git a/unified_ecommerce/permissions.py b/unified_ecommerce/permissions.py new file mode 100644 index 00000000..1c30574b --- /dev/null +++ b/unified_ecommerce/permissions.py @@ -0,0 +1,20 @@ +"""Custom DRF permissions.""" + +from rest_framework import permissions + + +class IsAdminUserOrReadOnly(permissions.BasePermission): + """Determines if the user owns the object""" + + def has_permission(self, request, view): # noqa: ARG002 + """ + Return True if the user is an admin user requesting a write operation, + or if the user is logged in. Otherwise, return False. + """ + + if request.method in permissions.SAFE_METHODS or ( + request.user.is_authenticated and request.user.is_staff + ): + return True + + return False diff --git a/unified_ecommerce/viewsets.py b/unified_ecommerce/viewsets.py new file mode 100644 index 00000000..32de394e --- /dev/null +++ b/unified_ecommerce/viewsets.py @@ -0,0 +1,39 @@ +"""Common viewsets for Unified Ecommerce.""" + +import logging + +from rest_framework import viewsets + +log = logging.getLogger(__name__) + + +class AuthVariegatedModelViewSet(viewsets.ModelViewSet): + """ + Viewset with customizable serializer based on user authentication. + + This bifurcates the ModelViewSet so that if the user is a read-only user (i.e. + not a staff or superuser, or not logged in), they get a separate "read-only" + serializer. Otherwise, we use a regular serializer. The read-only serializer can + then have different fields so you can hide irrelevant data from anonymous users. + + You will need to enforce the read-onlyness of the API yourself; use something like + the IsAuthenticatedOrReadOnly permission class or do something in the serializer. + + Set read_write_serializer_class to the serializer you want to use for admins and + set read_only_serializer_class to the one for regular users. + """ + + read_write_serializer_class = None + read_only_serializer_class = None + + def get_serializer_class(self): + """Get the serializer class for the route.""" + + if hasattr(self, "request") and ( + self.request.user.is_staff or self.request.user.is_superuser + ): + log.debug("get_serializer_class returning the Admin one") + return self.read_write_serializer_class + + log.debug("get_serializer_class returning the regular one") + return self.read_only_serializer_class From 48a0ae53a247c19c155ac70420afad12e05f990d Mon Sep 17 00:00:00 2001 From: James Kachel Date: Thu, 29 Feb 2024 07:52:53 -0600 Subject: [PATCH 2/8] Adding more support stuff, migrating in other changes that needed to happen for this --- system_meta/models.py | 8 + system_meta/serializers.py | 10 ++ system_meta/serializers_test.py | 16 +- system_meta/views.py | 31 +++- system_meta/views_test.py | 268 +++++++++++++++++++++++++----- unified_ecommerce/test_utils.py | 286 +++++++++++++++++++++++++++++++- 6 files changed, 570 insertions(+), 49 deletions(-) diff --git a/system_meta/models.py b/system_meta/models.py index f2a17e1d..7d37c83d 100644 --- a/system_meta/models.py +++ b/system_meta/models.py @@ -8,6 +8,7 @@ from mitol.common.models import TimestampedModel from safedelete.managers import SafeDeleteManager from safedelete.models import SafeDeleteModel +from slugify import slugify from unified_ecommerce.utils import SoftDeleteActiveModel @@ -19,6 +20,7 @@ class IntegratedSystem(SafeDeleteModel, SoftDeleteActiveModel, TimestampedModel) """Represents an integrated system""" name = models.CharField(max_length=255, unique=True) + slug = models.CharField(max_length=80, unique=True, blank=True, null=True) description = models.TextField(blank=True) api_key = models.TextField(blank=True) @@ -29,6 +31,12 @@ def __str__(self): """Return string representation of the system""" return f"{self.name} ({self.id})" + def save(self, *args, **kwargs): + """Save the product. Create a slug if it doesn't already exist.""" + if not self.slug: + self.slug = slugify(self.name) + super().save(*args, **kwargs) + @reversion.register(exclude=("created_on", "updated_on")) class Product(SafeDeleteModel, SoftDeleteActiveModel, TimestampedModel): diff --git a/system_meta/serializers.py b/system_meta/serializers.py index 20dc2d51..256285c9 100644 --- a/system_meta/serializers.py +++ b/system_meta/serializers.py @@ -8,6 +8,16 @@ class IntegratedSystemSerializer(serializers.ModelSerializer): """Serializer for IntegratedSystem model.""" + class Meta: + """Meta class for serializer.""" + + model = IntegratedSystem + fields = ["id", "name", "slug", "description"] + + +class AdminIntegratedSystemSerializer(serializers.ModelSerializer): + """Serializer for IntegratedSystem model.""" + class Meta: """Meta class for serializer.""" diff --git a/system_meta/serializers_test.py b/system_meta/serializers_test.py index 2560aa79..c9afa24b 100644 --- a/system_meta/serializers_test.py +++ b/system_meta/serializers_test.py @@ -4,12 +4,25 @@ from system_meta.factories import IntegratedSystemFactory, ProductFactory from system_meta.models import IntegratedSystem, Product -from system_meta.serializers import IntegratedSystemSerializer, ProductSerializer +from system_meta.serializers import ( + AdminIntegratedSystemSerializer, + IntegratedSystemSerializer, + ProductSerializer, +) from unified_ecommerce.test_utils import BaseSerializerTest pytestmark = pytest.mark.django_db +class TestAdminIntegratedSystemSerializer(BaseSerializerTest): + """Tests for the IntegratedSystemSerializer.""" + + serializer_class = AdminIntegratedSystemSerializer + factory_class = IntegratedSystemFactory + model_class = IntegratedSystem + queryset = IntegratedSystem.all_objects + + class TestIntegratedSystemSerializer(BaseSerializerTest): """Tests for the IntegratedSystemSerializer.""" @@ -17,6 +30,7 @@ class TestIntegratedSystemSerializer(BaseSerializerTest): factory_class = IntegratedSystemFactory model_class = IntegratedSystem queryset = IntegratedSystem.all_objects + only_fields = ["id", "name", "slug", "description"] class TestProductSerializer(BaseSerializerTest): diff --git a/system_meta/views.py b/system_meta/views.py index cbda5d83..0f1d94f1 100644 --- a/system_meta/views.py +++ b/system_meta/views.py @@ -2,6 +2,7 @@ import logging +from django_filters.rest_framework import DjangoFilterBackend from rest_framework import status, viewsets from rest_framework.decorators import ( api_view, @@ -14,29 +15,49 @@ from rest_framework.response import Response from system_meta.models import IntegratedSystem, Product -from system_meta.serializers import IntegratedSystemSerializer, ProductSerializer +from system_meta.serializers import ( + AdminIntegratedSystemSerializer, + IntegratedSystemSerializer, + ProductSerializer, +) from unified_ecommerce.authentication import ( ApiGatewayAuthentication, ) +from unified_ecommerce.permissions import ( + IsAdminUserOrReadOnly, +) from unified_ecommerce.utils import decode_x_header +from unified_ecommerce.viewsets import AuthVariegatedModelViewSet log = logging.getLogger(__name__) -class IntegratedSystemViewSet(viewsets.ModelViewSet): +class IntegratedSystemViewSet(AuthVariegatedModelViewSet): """Viewset for IntegratedSystem model.""" queryset = IntegratedSystem.objects.all() - serializer_class = IntegratedSystemSerializer - permission_classes = (IsAuthenticated,) + read_write_serializer_class = AdminIntegratedSystemSerializer + read_only_serializer_class = IntegratedSystemSerializer + permission_classes = [ + IsAdminUserOrReadOnly, + ] -class ProductViewSet(viewsets.ModelViewSet): +class ProductViewSet(AuthVariegatedModelViewSet): """Viewset for Product model.""" queryset = Product.objects.all() serializer_class = ProductSerializer permission_classes = (IsAuthenticated,) + read_write_serializer_class = ProductSerializer + read_only_serializer_class = ProductSerializer + filter_backends = [ + DjangoFilterBackend, + ] + filterset_fields = [ + "name", + "system__slug", + ] @api_view(["GET"]) diff --git a/system_meta/views_test.py b/system_meta/views_test.py index aa249b1d..92ee4cf6 100644 --- a/system_meta/views_test.py +++ b/system_meta/views_test.py @@ -10,25 +10,50 @@ IntegratedSystemFactory, ) from system_meta.models import IntegratedSystem, Product +from system_meta.serializers import ( + AdminIntegratedSystemSerializer, + IntegratedSystemSerializer, + ProductSerializer, +) from system_meta.views import IntegratedSystemViewSet, ProductViewSet -from unified_ecommerce.test_utils import BaseViewSetTest +from unified_ecommerce.test_utils import AuthVariegatedModelViewSetTest pytestmark = pytest.mark.django_db -class TestIntegratedSystemViewSet(BaseViewSetTest): +class TestIntegratedSystemViewSet(AuthVariegatedModelViewSetTest): """Tests for the IntegratedSystemViewSet.""" viewset_class = IntegratedSystemViewSet factory_class = ActiveIntegratedSystemFactory - queryset = IntegratedSystem.objects.all() + queryset = IntegratedSystem.all_objects.all() list_url = "/api/v0/meta/integrated_system/" object_url = "/api/v0/meta/integrated_system/{}/" - @pytest.mark.parametrize("is_logged_in", [True, False]) - @pytest.mark.parametrize("is_active_system", [True, False]) - def test_retrieve(self, is_logged_in, is_active_system, client, user_client): + read_only_serializer_class = IntegratedSystemSerializer + read_write_serializer_class = AdminIntegratedSystemSerializer + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_system"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_retrieve( # noqa: PLR0913 + self, + is_active_system, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """ Test that the viewset can retrieve an object that is either active or possibly inactive. @@ -40,11 +65,30 @@ def test_retrieve(self, is_logged_in, is_active_system, client, user_client): else InactiveIntegratedSystemFactory ) - super().test_retrieve(is_logged_in, client, user_client) + super().test_retrieve( + is_logged_in, use_staff_user, client, user_client, staff_client + ) - @pytest.mark.parametrize("is_active_system", [True, False]) - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_update(self, is_active_system, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_system"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_update( # noqa: PLR0913 + self, + is_active_system, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can update an object.""" self.factory_class = ( ActiveIntegratedSystemFactory @@ -54,10 +98,10 @@ def test_update(self, is_active_system, is_logged_in, client, user_client): update_data = {"name": "Updated Name"} (instance, response) = super().test_update( - update_data, is_logged_in, client, user_client + update_data, is_logged_in, use_staff_user, client, user_client, staff_client ) - if is_logged_in: + if is_logged_in and use_staff_user: if not is_active_system: assert instance.name != update_data["name"] assert response.status_code == 404 @@ -68,21 +112,73 @@ def test_update(self, is_active_system, is_logged_in, client, user_client): else: assert instance.name != update_data["name"] - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_delete(self, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) + def test_delete( # noqa: PLR0913 + self, is_logged_in, use_staff_user, client, user_client, staff_client + ): """Test that the viewset can delete an object.""" - (instance, response) = super().test_delete(is_logged_in, client, user_client) - if is_logged_in: + self.queryset = IntegratedSystem.objects.all() + (instance, response) = super().test_delete( + is_logged_in, use_staff_user, client, user_client, staff_client + ) + + if is_logged_in and use_staff_user: instance.refresh_from_db() assert response.status_code == 204 assert not instance.is_active else: assert instance.is_active - @pytest.mark.parametrize("is_logged_in", [True, False]) + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) @pytest.mark.parametrize("with_bad_data", [True, False]) - def test_create(self, with_bad_data, is_logged_in, client, user_client): + def test_create( # noqa: PLR0913 + self, + with_bad_data, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can create an object.""" create_data = { "description": "a description", @@ -92,20 +188,21 @@ def test_create(self, with_bad_data, is_logged_in, client, user_client): if not with_bad_data: create_data["name"] = "System Name" - response = super().test_create(create_data, is_logged_in, client, user_client) + response = super().test_create( + create_data, is_logged_in, use_staff_user, client, user_client, staff_client + ) - if is_logged_in: + if is_logged_in and use_staff_user: assert response.status_code == 201 if not with_bad_data else 400 assert ( IntegratedSystem.objects.filter(name="System Name").exists() is not with_bad_data ) else: - assert response.data["detail"].code == "not_authenticated" assert response.status_code == 403 -class TestProductViewSet(BaseViewSetTest): +class TestProductViewSet(AuthVariegatedModelViewSetTest): """Tests for the ProductViewSet.""" viewset_class = ProductViewSet @@ -115,9 +212,29 @@ class TestProductViewSet(BaseViewSetTest): list_url = "/api/v0/meta/product/" object_url = "/api/v0/meta/product/{}/" - @pytest.mark.parametrize("is_logged_in", [True, False]) - @pytest.mark.parametrize("is_active_product", [True, False]) - def test_retrieve(self, is_logged_in, is_active_product, client, user_client): + read_only_serializer_class = ProductSerializer + read_write_serializer_class = ProductSerializer + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_product"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_retrieve( # noqa: PLR0913 + self, + is_active_product, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """ Test that the viewset can retrieve an object that is either active or possibly inactive. @@ -127,11 +244,30 @@ def test_retrieve(self, is_logged_in, is_active_product, client, user_client): ActiveProductFactory if is_active_product else InactiveProductFactory ) - super().test_retrieve(is_logged_in, client, user_client) + super().test_retrieve( + is_logged_in, use_staff_user, client, user_client, staff_client + ) - @pytest.mark.parametrize("is_active_product", [True, False]) - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_update(self, is_active_product, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user", "is_active_product"), + [ + (True, False, True), + (True, True, True), + (False, False, True), + (True, False, False), + (True, True, False), + (False, False, False), + ], + ) + def test_update( # noqa: PLR0913 + self, + is_active_product, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can update an object.""" self.factory_class = ( ActiveProductFactory if is_active_product else InactiveProductFactory @@ -139,10 +275,10 @@ def test_update(self, is_active_product, is_logged_in, client, user_client): update_data = {"name": "Updated Name"} (instance, response) = super().test_update( - update_data, is_logged_in, client, user_client + update_data, is_logged_in, use_staff_user, client, user_client, staff_client ) - if is_logged_in: + if is_logged_in and use_staff_user: if not is_active_product: assert instance.name != update_data["name"] assert response.status_code == 404 @@ -153,21 +289,72 @@ def test_update(self, is_active_product, is_logged_in, client, user_client): else: assert instance.name != update_data["name"] - @pytest.mark.parametrize("is_logged_in", [True, False]) - def test_delete(self, is_logged_in, client, user_client): + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) + def test_delete( # noqa: PLR0913 + self, is_logged_in, use_staff_user, client, user_client, staff_client + ): """Test that the viewset can delete an object.""" - (instance, response) = super().test_delete(is_logged_in, client, user_client) + self.queryset = Product.objects.all() + (instance, response) = super().test_delete( + is_logged_in, use_staff_user, client, user_client, staff_client + ) - if is_logged_in: + if is_logged_in and use_staff_user: instance.refresh_from_db() assert response.status_code == 204 assert not instance.is_active else: assert instance.is_active - @pytest.mark.parametrize("is_logged_in", [True, False]) + @pytest.mark.parametrize( + ( + "is_logged_in", + "use_staff_user", + ), + [ + ( + True, + False, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) @pytest.mark.parametrize("with_bad_data", [True, False]) - def test_create(self, with_bad_data, is_logged_in, client, user_client): + def test_create( # noqa: PLR0913 + self, + with_bad_data, + is_logged_in, + use_staff_user, + client, + user_client, + staff_client, + ): """Test that the viewset can create an object.""" system = IntegratedSystemFactory.create() create_data = { @@ -180,11 +367,12 @@ def test_create(self, with_bad_data, is_logged_in, client, user_client): if not with_bad_data: create_data["description"] = "a description" - response = super().test_create(create_data, is_logged_in, client, user_client) + response = super().test_create( + create_data, is_logged_in, use_staff_user, client, user_client, staff_client + ) - if is_logged_in: + if is_logged_in and use_staff_user: assert response.status_code == 201 if not with_bad_data else 400 assert Product.objects.filter(name="New Name").exists() is not with_bad_data else: - assert response.data["detail"].code == "not_authenticated" assert response.status_code == 403 diff --git a/unified_ecommerce/test_utils.py b/unified_ecommerce/test_utils.py index ef19d2f2..1631814e 100644 --- a/unified_ecommerce/test_utils.py +++ b/unified_ecommerce/test_utils.py @@ -13,6 +13,7 @@ from django.conf import settings from django.core.serializers import serialize from django.core.serializers.json import DjangoJSONEncoder +from django.http import HttpRequest from django.http.response import HttpResponse from rest_framework.renderers import JSONRenderer @@ -213,6 +214,27 @@ def make_timestamps_matchable(objs, **kwargs): ] +def generate_mocked_request(user): + """ + Generate a mocked request for test_process_cybersource_*. + + The RequestFactory misses some stuff, so instead just make a full-fat + HttpRequest and add the things in that we need. + + Args: + - user (User): The user to set in the request. + + Returns: + - HttpRequest: The mocked request. + """ + mocked_request = HttpRequest() + mocked_request.user = user + mocked_request.META["REMOTE_ADDR"] = "127.0.0.1" + mocked_request.META["HTTP_HOST"] = "localhost" + + return mocked_request + + class PickleableMock(Mock): """ A Mock that can be passed to pickle.dumps() @@ -232,12 +254,22 @@ class ViewSetNotConfiguredError(Exception): class BaseSerializerTest: - """Base class for serializer tests.""" + """ + Base class for serializer tests. + + Class variables: + - model_class (class): the model class to test + - serializer_class (class): the serializer class to test + - factory_class (class): the factory class to use for creating instances + - queryset (QuerySet): the queryset to use, if you need a custom one + - only_fields (list): a list of field names that should be included in the serialized output + """ model_class = None serializer_class = None factory_class = None queryset = None + only_fields = None def test_serialize(self): """Test that the serializer can serialize an instance.""" @@ -249,7 +281,11 @@ def test_serialize(self): instance_qs = self.model_class.objects.filter(pk=instance.pk) serializer = self.serializer_class(instance) - dj_serializer = queryset_to_json(instance_qs) + + if self.only_fields: + dj_serializer = instance_qs.values(*self.only_fields).get() + else: + dj_serializer = queryset_to_json(instance_qs) assert_json_equal(*make_timestamps_matchable([serializer.data, dj_serializer])) @@ -281,6 +317,9 @@ def _test_retrieval(self, api_client, url, url_name, **kwargs): """ Test that hitting the specified URL works with the specified client. + You still need to test for the proper response code; this will just + check for 500s and 403. + Args: - api_client (APIClient): the client to use - url (str): the URL to test with @@ -298,7 +337,8 @@ def _test_retrieval(self, api_client, url, url_name, **kwargs): response = api_client.get(url) assert response.status_code < 500 - assert response.status_code == 403 if kwargs["test_non_authenticated"] else 200 + if kwargs["test_non_authenticated"]: + assert response.status_code == 403 return response def test_get_queryset(self): @@ -441,3 +481,243 @@ def test_create(self, create_data, is_logged_in, client, user_client): assert response.status_code == 403 return response + + +class AuthVariegatedModelViewSetTest(BaseViewSetTest): + """ + Extends the BaseViewSetTest class to add support for auth variegated + viewsets. These viewsets use different serializers based on the user's + session. + + Set read_only_viewset_class and read_write_viewset_class to the appropriate + values for your viewset. + """ + + read_only_serializer_class = None + read_write_serializer_class = None + + @pytest.mark.parametrize("is_logged_in", [True, False]) + def test_get_serializer_class(self, is_logged_in, user): + """Test the viewset returns the right serializer class.""" + + viewset = self.viewset_class() + + if is_logged_in: + # Forcing staff/superuser to get the read-write serializer. + user.is_staff = True + user.is_superuser = True + viewset.request = generate_mocked_request(user) + assert viewset.get_serializer_class() == self.read_write_serializer_class + else: + assert viewset.get_serializer_class() == self.read_only_serializer_class + + def _test_retrieval(self, api_client, url, url_name, **kwargs): # noqa: ARG002 + """ + Test that hitting the specified URL works with the specified client. + This will expect a 200 regardless of what you're requesting as you'll + get a different result set and not a denial for these viewsets. + + noqa ARG002 so the API matches the parent class. + + Args: + - api_client (APIClient): the client to use + - url (str): the URL to test with + - url_name (str): the name of the URL (used for the skip message if it's not defined) + + Keyword Args: + - test_non_authenticated (bool): whether or not this should expect a 403 + + Returns: + - response (Response): the response from the API + """ + if not url: + exception_string = f"{url_name} is not defined" + raise ViewSetNotConfiguredError(exception_string) + + response = api_client.get(url) + assert response.status_code < 500 + return response + + def _determine_client_wrapper( # noqa: PLR0913 + self, is_logged_in, use_staff_user, client, user_client, staff_client + ): + """ + Determine the client to use for the test. + + Args: + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - client (APIClient): the client to use for the test + """ + if is_logged_in: + if use_staff_user: + return staff_client + else: + return user_client + else: + return client + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_retrieve(self, *args): + """ + Test that the viewset can retrieve an object, and returns the correct + dataset depending on the status of the user. + + If the user is anonymous _or_ a regular user, we should receive a + dataset that matches the read-only serializer class. Otherwise, we + should receive a dataset that matches the read-write serializer class. + + Args: + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + """ + instance = self.factory_class() + + is_logged_in = args[0] + use_staff_user = args[1] + + use_client = self._determine_client_wrapper(*args) + + response = self._test_retrieval( + use_client, + self.object_url.format(instance.pk), + "object_url", + ) + + if instance.is_active: + if is_logged_in and use_staff_user: + dj_serializer = self.read_write_serializer_class(instance).data + else: + dj_serializer = self.read_only_serializer_class(instance).data + + assert_json_equal( + *make_timestamps_matchable([response.data, dj_serializer]) + ) + + if not instance.is_active: + # Deleted items should return a 404. + assert response.status_code == 404 + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_update(self, update_data, *args): + """ + Test that the viewset can update an object. + + Args: + - update_data (dict): the data to use for the update + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - tuple of (instance, response): the instance that was updated and the response + """ + instance = self.factory_class() + + use_client = self._determine_client_wrapper(*args) + + is_logged_in = args[0] + use_staff_user = args[1] + + response = use_client.patch( + self.object_url.format(instance.pk), data=update_data + ) + + assert response.status_code < 500 + + if not is_logged_in or not use_staff_user: + assert response.status_code == 403 + + return (instance, response) + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_delete(self, *args): + """ + Test that the viewset can delete an object. Note that this will actually test + deletion. + + Args: + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - tuple of (instance, response): the instance that was deleted and the response + """ + instance = self.factory_class() + + before_count = self.queryset.count() + + use_client = self._determine_client_wrapper(*args) + is_logged_in = args[0] + use_staff_user = args[1] + + response = use_client.delete(self.object_url.format(instance.pk)) + + assert response.status_code < 500 + assert ( + response.status_code == 403 + if not is_logged_in or not use_staff_user + else 204 + ) + + assert ( + self.queryset.count() == before_count - 1 + if is_logged_in and use_staff_user + else before_count + ) + + return (instance, response) + + @pytest.mark.parametrize( + ("is_logged_in", "use_staff_user"), + [(True, False), (True, True), (False, False)], + ) + def test_create(self, create_data, *args): + """ + Test that the viewset can create an object. + + Args: + - create_data (dict): the data to use for the update + - is_logged_in (bool): whether or not the client is logged in + - use_staff_user (bool): use the staff user client, rather than the regular user client + - client (APIClient): the client to use for non-logged in requests + - user_client (APIClient): the client to use for logged in regular user requests + - staff_client (APIClient): the client to use for staff user requests + + Returns: + - response (Response): the response from the API + """ + use_client = self._determine_client_wrapper(*args) + is_logged_in = args[0] + use_staff_user = args[1] + + response = use_client.post(self.list_url, data=create_data) + + assert response.status_code < 500 + + if not is_logged_in or not use_staff_user: + assert response.status_code == 403 + + return response From 10c7262e19ed605e1fbcc4d7248cc9b9b0632027 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Thu, 29 Feb 2024 09:33:33 -0600 Subject: [PATCH 3/8] Updating OpenAPI, getting tests to run properly --- openapi.yaml | 56 +++++++++++++++----------------------------------- pyproject.toml | 1 + 2 files changed, 17 insertions(+), 40 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 629bf406..b24a811c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -172,12 +172,20 @@ paths: description: Number of results to return per page. schema: type: integer + - in: query + name: name + schema: + type: string - name: offset required: false in: query description: The initial index from which to return the results. schema: type: integer + - in: query + name: system__slug + schema: + type: string tags: - product security: @@ -328,36 +336,18 @@ components: id: type: integer readOnly: true - deleted_on: - type: string - format: date-time - readOnly: true - nullable: true - deleted_by_cascade: - type: boolean - readOnly: true - created_on: - type: string - format: date-time - readOnly: true - updated_on: - type: string - format: date-time - readOnly: true name: type: string maxLength: 255 - description: + slug: type: string - api_key: + nullable: true + maxLength: 80 + description: type: string required: - - created_on - - deleted_by_cascade - - deleted_on - id - name - - updated_on PaginatedIntegratedSystemList: type: object properties: @@ -405,28 +395,14 @@ components: id: type: integer readOnly: true - deleted_on: - type: string - format: date-time - readOnly: true - nullable: true - deleted_by_cascade: - type: boolean - readOnly: true - created_on: - type: string - format: date-time - readOnly: true - updated_on: - type: string - format: date-time - readOnly: true name: type: string maxLength: 255 - description: + slug: type: string - api_key: + nullable: true + maxLength: 80 + description: type: string PatchedProduct: type: object diff --git a/pyproject.toml b/pyproject.toml index 18da4762..f3a889eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ python-slugify = "^8.0.1" django-oauth-toolkit = "^2.3.0" requests-oauthlib = "^1.3.1" oauthlib = "^3.2.2" +python-slugify = "^8.0.1" [tool.poetry.group.dev.dependencies] bpython = "^0.24" From 70ea4dd990e08ece6954f3751ccc90b48691370c Mon Sep 17 00:00:00 2001 From: James Kachel Date: Thu, 29 Feb 2024 09:49:13 -0600 Subject: [PATCH 4/8] Fixing generate_test_data command --- system_meta/management/commands/generate_test_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/system_meta/management/commands/generate_test_data.py b/system_meta/management/commands/generate_test_data.py index ea4fa59b..2a5d58a6 100644 --- a/system_meta/management/commands/generate_test_data.py +++ b/system_meta/management/commands/generate_test_data.py @@ -52,7 +52,7 @@ def fake_courseware_id(courseware_type: str, **kwargs) -> str: optional_third_digit = random.randint(0, 9) if fake.boolean() else "" optional_run_tag = ( f"+{random.randint(1,3)}T{fake.date_this_decade().year}" - if kwargs["include_run_tag"] + if kwargs.get("include_run_tag", False) else "" ) @@ -102,7 +102,6 @@ def add_test_systems(self) -> None: IntegratedSystem.objects.create( name=f"Test System {i}", description=f"Test System {i} description.", - is_active=True, api_key=uuid.uuid4(), ) @@ -124,7 +123,6 @@ def add_test_products(self, system: str) -> None: description=f"Test Product {i} description.", sku=product_sku, system=system, - is_active=True, price=Decimal(random.random() * 10000).quantize(Decimal("0.01")), system_data={ "courserun": product_sku, From bd086144303d7da67c76dbd1615100b4164326f6 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Wed, 27 Mar 2024 15:18:00 -0500 Subject: [PATCH 5/8] Fixing perms on ProductViewSet so tests pass (and so it works right too) --- pyproject.toml | 1 - system_meta/views.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3a889eb..18da4762 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,6 @@ python-slugify = "^8.0.1" django-oauth-toolkit = "^2.3.0" requests-oauthlib = "^1.3.1" oauthlib = "^3.2.2" -python-slugify = "^8.0.1" [tool.poetry.group.dev.dependencies] bpython = "^0.24" diff --git a/system_meta/views.py b/system_meta/views.py index 0f1d94f1..65d666b7 100644 --- a/system_meta/views.py +++ b/system_meta/views.py @@ -48,7 +48,9 @@ class ProductViewSet(AuthVariegatedModelViewSet): queryset = Product.objects.all() serializer_class = ProductSerializer - permission_classes = (IsAuthenticated,) + permission_classes = [ + IsAdminUserOrReadOnly, + ] read_write_serializer_class = ProductSerializer read_only_serializer_class = ProductSerializer filter_backends = [ From 239042aa39f866c63f767938331a6fe8813463d9 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Wed, 27 Mar 2024 15:41:08 -0500 Subject: [PATCH 6/8] Clearing unused import --- system_meta/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/system_meta/views.py b/system_meta/views.py index 65d666b7..d5087045 100644 --- a/system_meta/views.py +++ b/system_meta/views.py @@ -3,7 +3,7 @@ import logging from django_filters.rest_framework import DjangoFilterBackend -from rest_framework import status, viewsets +from rest_framework import status from rest_framework.decorators import ( api_view, authentication_classes, From 161ac463f7ba8b7f2223f0473a0059e62c5ec414 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Fri, 29 Mar 2024 14:54:12 -0500 Subject: [PATCH 7/8] Added tests for generate_test_data command, fixed some logic issues in there as well --- .../management/commands/generate_test_data.py | 71 +++++--- .../tests/generate_test_data_test.py | 165 ++++++++++++++++++ 2 files changed, 208 insertions(+), 28 deletions(-) create mode 100644 system_meta/management/tests/generate_test_data_test.py diff --git a/system_meta/management/commands/generate_test_data.py b/system_meta/management/commands/generate_test_data.py index 2a5d58a6..3d4bbf97 100644 --- a/system_meta/management/commands/generate_test_data.py +++ b/system_meta/management/commands/generate_test_data.py @@ -19,6 +19,12 @@ from system_meta.models import IntegratedSystem, Product +def get_input(text): + """Wrap the internal input function so we can test it later.""" + + return input(text) + + def fake_courseware_id(courseware_type: str, **kwargs) -> str: """ Generate a fake courseware id. @@ -76,7 +82,7 @@ def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( "--only-systems", action="store_true", - help="Only add test systems.", + help="Only add test systems. Will add two active and one inactive system.", ) parser.add_argument( @@ -86,10 +92,10 @@ def add_arguments(self, parser: CommandParser) -> None: ) parser.add_argument( - "--system", + "--system-slug", type=str, help=( - "The name of the system to add products to." + "The slug of the system to add products to." " Only used with --only-products." ), nargs="?", @@ -99,26 +105,29 @@ def add_test_systems(self) -> None: """Add the test systems.""" max_systems = 3 for i in range(1, max_systems + 1): - IntegratedSystem.objects.create( + system = IntegratedSystem.objects.create( name=f"Test System {i}", description=f"Test System {i} description.", api_key=uuid.uuid4(), ) + self.stdout.write(f"Created system {system.name} - {system.slug}") - def add_test_products(self, system: str) -> None: + def add_test_products(self, system_slug: str) -> None: """Add the test products to the specified system.""" + self.stdout.write(f"Creating test products for {system_slug}") - if not IntegratedSystem.objects.filter(name=system).exists(): + if not IntegratedSystem.objects.filter(slug=system_slug).exists(): self.stdout.write( - self.style.ERROR(f"Integrated system {system} does not exist.") + self.style.ERROR(f"Integrated system {system_slug} does not exist.") ) return - system = IntegratedSystem.objects.get(name=system) + system = IntegratedSystem.objects.get(slug=system_slug) - for i in range(1, 4): + max_products = 3 + for i in range(1, max_products + 1): product_sku = fake_courseware_id("course", include_run_tag=True) - Product.objects.create( + product = Product.objects.create( name=f"Test Product {i}", description=f"Test Product {i} description.", sku=product_sku, @@ -129,6 +138,7 @@ def add_test_products(self, system: str) -> None: "program": fake_courseware_id("program"), }, ) + self.stdout.write(f"Created product {product.id} - {product.sku}") def remove_test_data(self) -> None: """Remove the test data.""" @@ -160,7 +170,7 @@ def remove_test_data(self) -> None: ) ) - if input("Type 'yes' to continue: ") != "yes": + if get_input("Type 'yes' to continue: ") != "yes": self.stdout.write(self.style.ERROR("Aborting.")) return @@ -177,7 +187,7 @@ def handle(self, *args, **options) -> None: # noqa: ARG002 remove = options["remove"] only_systems = options["only_systems"] only_products = options["only_products"] - systems = [options["system"]] if options["system"] else [] + systems = [] with transaction.atomic(): if remove: @@ -186,9 +196,20 @@ def handle(self, *args, **options) -> None: # noqa: ARG002 if not only_products: self.add_test_systems() + systems = [ + system.slug + for system in ( + IntegratedSystem.all_objects.filter( + name__startswith="Test System" + ).all() + ) + ] + + if only_systems: + return - if not only_systems: - if only_products and len(systems) == 0: + if only_products: + if not options["system_slug"] or len(options["system_slug"]) == 0: self.stdout.write( self.style.ERROR( "You must specify a system when using --only-products." @@ -196,21 +217,15 @@ def handle(self, *args, **options) -> None: # noqa: ARG002 ) return else: - systems = [ - system.name - for system in ( - IntegratedSystem.all_objects.filter( - name__startswith="Test System" - ).all() - ) - ] + systems = [options["system_slug"]] - [self.add_test_products(system) for system in systems] - return + self.stdout.write(f"we are creating products now {systems}") + + [self.add_test_products(system) for system in systems] if not only_products: - third_test_system = IntegratedSystem.all_objects.filter( + IntegratedSystem.all_objects.filter( name__startswith="Test System" - ).get() - third_test_system.is_active = False - third_test_system.save(update_fields=("is_active",)) + ).last().delete() + + return diff --git a/system_meta/management/tests/generate_test_data_test.py b/system_meta/management/tests/generate_test_data_test.py new file mode 100644 index 00000000..9729fa3a --- /dev/null +++ b/system_meta/management/tests/generate_test_data_test.py @@ -0,0 +1,165 @@ +"""Tests for the manage_product command""" +# ruff: noqa: W605 + +import re +from io import StringIO + +import faker +import pytest +from django.core.management import call_command + +from system_meta.factories import IntegratedSystemFactory, ProductFactory +from system_meta.management.commands.generate_test_data import fake_courseware_id +from system_meta.models import IntegratedSystem, Product + +pytestmark = pytest.mark.django_db +FAKE = faker.Factory.create() + + +@pytest.mark.parametrize( + ("include_run_tag", "program_or_course"), + [(True, True), (True, False), (False, True), (False, False)], +) +def test_generate_fake_courseware_id(include_run_tag, program_or_course): + """ + Tests that generate_fake_courseware_id generates a courseware id in the proper format. + """ + + courseware_id = fake_courseware_id( + "program" if program_or_course else "course", include_run_tag=include_run_tag + ) + + assert ( + courseware_id.startswith("program-v1:") + if program_or_course + else courseware_id.startswith("course-v1:") + ) + + if program_or_course: + if include_run_tag: + courseware_re = re.compile( + r"program-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x\+\d{1,3}T\d{4}" + ) + else: + courseware_re = re.compile( + r"program-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x" + ) + else: # noqa: PLR5501 + if include_run_tag: + courseware_re = re.compile( + r"course-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x\+\d{1,3}T\d{4}" + ) + else: + courseware_re = re.compile( + r"course-v1:(MITx|MITxT|edX|xPRO|Sample)\+\d{1,2}\.\d{3}\d?x" + ) + + assert courseware_re.match(courseware_id) + + +def test_add_only_systems(): + """Test that only-systems adds test systems with no products.""" + out = StringIO() + + call_command("generate_test_data", only_systems=True, stdout=out) + + assert ( + IntegratedSystem.all_objects.filter(name__startswith="Test System").count() == 3 + ) + assert IntegratedSystem.objects.filter(name__startswith="Test System").count() == 2 + + +@pytest.mark.parametrize("system_type", ["normal", "skip", "wrong"]) +def test_add_only_products(system_type): + """ + Test that only-product adds products for the specified system. + + Args: + - system_type (str): type of system to try + normal is make a system normally (i.e. run the command correctly) + skip is don't make a system at all (should generate a specific error) + wrong is make a system, but specify something incorrect + """ + out = StringIO() + + if system_type == "skip": + call_command("generate_test_data", only_products=True, stdout=out) + + assert "You must specify a system" in out.getvalue() + elif system_type == "wrong": + bad_slug = "some nonsense" + call_command( + "generate_test_data", only_products=True, system_slug=bad_slug, stdout=out + ) + + assert f"Integrated system {bad_slug} does not exist" in out.getvalue() + else: + system = IntegratedSystemFactory.create() + + call_command( + "generate_test_data", + only_products=True, + system_slug=system.slug, + stdout=out, + ) + + assert "Created product" in out.getvalue() + + assert ( + Product.objects.filter( + name__startswith="Test Product", system=system + ).count() + == 3 + ) + + +def test_add_all(): + """Test that just running the command generates expected output.""" + + call_command("generate_test_data") + + assert ( + IntegratedSystem.all_objects.filter(name__startswith="Test System").count() == 3 + ) + assert IntegratedSystem.objects.filter(name__startswith="Test System").count() == 2 + + for system in IntegratedSystem.all_objects.all(): + assert ( + Product.all_objects.filter( + system=system, name__startswith="Test Product" + ).count() + == 3 + ) + + +def test_remove_test_data(mocker): + """Test that the remove test data command does the best it can to just remove test data.""" + + input_mock = mocker.patch( + "system_meta.management.commands.generate_test_data.get_input", + return_value="yes", + ) + out = StringIO() + + ProductFactory.create_batch(3) + before_system_count = IntegratedSystem.all_objects.count() + before_product_count = Product.all_objects.count() + + call_command("generate_test_data") + + assert before_system_count < IntegratedSystem.all_objects.count() + assert before_product_count < Product.all_objects.count() + + call_command("generate_test_data", "--remove", stdout=out) + + assert "Test data removed" in out.getvalue() + input_mock.assert_called() + + # Checking these two ways for a reason: + # - Make sure the test data wasn't soft deleted + # - Make sure it didn't soft-delete any of the products/systems we created + + assert IntegratedSystem.all_objects.count() == before_system_count + assert Product.all_objects.count() == before_product_count + assert IntegratedSystem.objects.count() == before_system_count + assert Product.objects.count() == before_product_count From 24488119b0012fd00e73cfbfa41f77c4b4f8a821 Mon Sep 17 00:00:00 2001 From: James Kachel Date: Fri, 29 Mar 2024 15:04:16 -0500 Subject: [PATCH 8/8] fixing integrated system deletion stuff --- system_meta/management/commands/generate_test_data.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/system_meta/management/commands/generate_test_data.py b/system_meta/management/commands/generate_test_data.py index 3d4bbf97..7931b28e 100644 --- a/system_meta/management/commands/generate_test_data.py +++ b/system_meta/management/commands/generate_test_data.py @@ -206,6 +206,9 @@ def handle(self, *args, **options) -> None: # noqa: ARG002 ] if only_systems: + IntegratedSystem.objects.filter( + name__startswith="Test System" + ).last().delete() return if only_products: @@ -224,7 +227,7 @@ def handle(self, *args, **options) -> None: # noqa: ARG002 [self.add_test_products(system) for system in systems] if not only_products: - IntegratedSystem.all_objects.filter( + IntegratedSystem.objects.filter( name__startswith="Test System" ).last().delete()