diff --git a/registrar/apps/api/v1/tests/test_views.py b/registrar/apps/api/v1/tests/test_views.py index 75b1f2f4..20c3dc75 100644 --- a/registrar/apps/api/v1/tests/test_views.py +++ b/registrar/apps/api/v1/tests/test_views.py @@ -1,4 +1,5 @@ """ Tests for API views. """ +from contextlib import contextmanager import csv import json import logging @@ -6,6 +7,8 @@ from io import StringIO from posixpath import join as urljoin +from waffle import get_waffle_flag_model + import boto3 import ddt import mock @@ -70,6 +73,19 @@ INACTIVE_CURRICULUM_UUID = '66666666-4444-2222-1111-000000000000' +@contextmanager +def get_waffle_flag(flag_name, is_active, groups): + waffle_model = get_waffle_flag_model() + waffle_flag = waffle_model.objects.create(name=flag_name) + for group in groups: + waffle_flag.groups.add(group) + waffle_flag.everyone = is_active + waffle_flag.save() + waffle_flag.flush() + yield waffle_flag + waffle_flag.delete() + + class RegistrarAPITestCase(TrackTestMixin, APITestCase): """ Base for tests of the Registrar API """ @@ -1497,10 +1513,11 @@ def get_url(self, program_key=None, course_id=None): def mock_course_enrollments_response(self, method, expected_response, response_code=200): self.mock_api_response(self.lms_request_url, expected_response, method=method, response_code=response_code) - def student_course_enrollment(self, status, student_key=None): + def student_course_enrollment(self, status, student_key=None, course_staff=None): return { 'status': status, - 'student_key': student_key or uuid.uuid4().hex[0:10] + 'student_key': student_key or uuid.uuid4().hex[0:10], + 'course_staff': course_staff } def test_program_unauthorized_at_organization(self): @@ -1618,14 +1635,66 @@ def test_successful_program_course_enrollment_write(self, use_external_course_ke { 'status': 'active', 'student_key': '001', + 'course_staff': None, + }, + { + 'status': 'active', + 'student_key': '002', + 'course_staff': None, + }, + { + 'status': 'inactive', + 'student_key': '003', + 'course_staff': None, + } + ]) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.data, expected_lms_response) + + @mock_oauth_login + @responses.activate + @ddt.data(False, True) + def test_successful_program_course_enrollment_write_with_course_staff(self, use_external_course_key): + course_id = self.external_course_key if use_external_course_key else self.course_id + expected_lms_response = { + '001': 'active', + '002': 'active', + '003': 'inactive' + } + self.mock_course_enrollments_response(self.method, expected_lms_response) + + req_data = [ + self.student_course_enrollment('active', '001', True), + self.student_course_enrollment('active', '002', False), + self.student_course_enrollment('inactive', '003', True), + ] + + with self.assert_tracking( + user=self.stem_admin, + program_key=self.cs_program.key, + course_id=course_id, + ): + with get_waffle_flag('enable_course_role_management', True, [self.stem_admin_group]) as overrider: + response = self.request( + self.method, self.get_url(course_id=course_id), self.stem_admin, req_data + ) + + lms_request_body = json.loads(responses.calls[-1].request.body.decode('utf-8')) + self.assertCountEqual(lms_request_body, [ + { + 'status': 'active', + 'student_key': '001', + 'course_staff': True, }, { 'status': 'active', 'student_key': '002', + 'course_staff': False, }, { 'status': 'inactive', 'student_key': '003', + 'course_staff': True, } ]) self.assertEqual(response.status_code, 200)