Skip to content

Commit

Permalink
Created DBManagerXAccessTests, code refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
doumdi committed Nov 19, 2024
1 parent 7aacc3a commit ac7fd41
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_accessibles_sites(self) -> list[TeraSite]:
site_list.append(service_site.service_site_site)
return site_list

def get_accessibles_sites_ids(self) -> list[int]:
def get_accessible_sites_ids(self) -> list[int]:
return [site.id_site for site in self.get_accessibles_sites()]

def get_accessible_participants(self, admin_only=False) -> list[TeraParticipant]:
Expand Down Expand Up @@ -254,7 +254,7 @@ def query_sites_for_user(self, user_id: int, admin_only: bool = False) -> list[T
if user_id in self.get_accessible_users_ids():

user = TeraUser.get_user_by_id(id_user=user_id)
acc_sites_ids = self.get_accessibles_sites_ids()
acc_sites_ids = self.get_accessible_sites_ids()
if user.user_superadmin:
sites = TeraSite.query.order_by(TeraSite.site_name.asc()).all()
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get(self):
projects = [TeraProject.get_project_by_id(args['id_project'])]

if args['id_site']:
if args['id_site'] not in service_access.get_accessibles_sites_ids():
if args['id_site'] not in service_access.get_accessible_sites_ids():
return gettext('Forbidden'), 403
projects = TeraSite.get_site_by_id(args['id_site']).site_projects

Expand Down Expand Up @@ -101,7 +101,7 @@ def post(self):
if json_project['id_project'] == 0 and 'id_site' not in json_project:
return gettext('Missing id_site arguments'), 400

if 'id_site' in json_project and json_project['id_site'] not in service_access.get_accessibles_sites_ids():
if 'id_site' in json_project and json_project['id_site'] not in service_access.get_accessible_sites_ids():
return gettext('Forbidden'), 403

# Do the update!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get(self):

session_types = []
if args['id_site']:
if args['id_site'] in service_access.get_accessibles_sites_ids():
if args['id_site'] in service_access.get_accessible_sites_ids():
session_types = [st.session_type_site_session_type for st in
TeraSessionTypeSite.get_sessions_types_for_site(args['id_site'])]
elif args['id_project']:
Expand Down Expand Up @@ -122,7 +122,7 @@ def post(self):
# STEP 1) Verify access before doing anything
if 'id_sites' in service_session_type_info and \
isinstance(service_session_type_info['id_sites'], list):
accessible_sites = service_access.get_accessibles_sites_ids()
accessible_sites = service_access.get_accessible_sites_ids()
for id_site in service_session_type_info['id_sites']:
if id_site not in accessible_sites:
return gettext('Service doesn\'t have access to all listed sites'), 403
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get(self):

sites = []
if args['id_site']:
if args['id_site'] not in service_access.get_accessibles_sites_ids():
if args['id_site'] not in service_access.get_accessible_sites_ids():
return gettext('Forbidden'), 403
sites = [TeraSite.get_site_by_id(args['id_site'])]
elif args['id_user']:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get(self):

test_types = []
if args['id_site']:
if args['id_site'] in service_access.get_accessibles_sites_ids():
if args['id_site'] in service_access.get_accessible_sites_ids():
test_types = [tt.test_type_site_test_type for tt in
TeraTestTypeSite.get_tests_types_for_site(args['id_site'])]
elif args['id_project']:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get(self):
return gettext('Forbidden'), 403
user_groups = service_access.query_usergroups_for_project(args['id_project'])
elif args['id_site']:
if args['id_site'] not in service_access.get_accessibles_sites_ids():
if args['id_site'] not in service_access.get_accessible_sites_ids():
return gettext('Forbidden'), 403
user_groups = service_access.query_usergroups_for_site(args['id_site'])
else:
Expand Down Expand Up @@ -97,7 +97,7 @@ def post(self):
if 'user_group_services_access' in json_user_group:
json_access = json_user_group.pop('user_group_services_access')
accessible_projects_ids = service_access.get_accessible_projects_ids()
accessible_sites_ids = service_access.get_accessibles_sites_ids()
accessible_sites_ids = service_access.get_accessible_sites_ids()

# Check if access only are for the current service
for access in json_access:
Expand Down Expand Up @@ -233,4 +233,3 @@ def delete(self):
return gettext('Database error'), 500

return '', 200

Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from modules.DatabaseModule.DBManager import DBManager
from modules.DatabaseModule.DBManagerTeraDeviceAccess import DBManagerTeraDeviceAccess
from opentera.db.models.TeraUser import TeraUser
from opentera.db.models.TeraParticipant import TeraParticipant
from opentera.db.models.TeraParticipantGroup import TeraParticipantGroup
from opentera.db.models.TeraService import TeraService
from opentera.db.models.TeraDevice import TeraDevice
from opentera.db.models.TeraProject import TeraProject
from opentera.db.models.TeraSite import TeraSite
from opentera.db.models.TeraSession import TeraSession
from tests.opentera.db.models.BaseModelsTest import BaseModelsTest


class DBManagerTeraDeviceAccessTest(BaseModelsTest):
# TODO - Implement tests
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from modules.DatabaseModule.DBManager import DBManager
from opentera.db.models.TeraUser import TeraUser
from tests.opentera.db.models.BaseModelsTest import BaseModelsTest
from modules.DatabaseModule.DBManagerTeraParticipantAccess import DBManagerTeraParticipantAccess

class DBManagerTeraParticipantAccessTest(BaseModelsTest):
# TODO - Implement tests
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from modules.DatabaseModule.DBManager import DBManager
from modules.DatabaseModule.DBManagerTeraServiceAccess import DBManagerTeraServiceAccess
from opentera.db.models.TeraUser import TeraUser
from opentera.db.models.TeraParticipant import TeraParticipant
from opentera.db.models.TeraParticipantGroup import TeraParticipantGroup
from opentera.db.models.TeraService import TeraService
from opentera.db.models.TeraDevice import TeraDevice
from opentera.db.models.TeraProject import TeraProject
from opentera.db.models.TeraSite import TeraSite
from opentera.db.models.TeraSession import TeraSession
from tests.opentera.db.models.BaseModelsTest import BaseModelsTest


class DBManagerTeraServiceAccessTest(BaseModelsTest):

def test_video_rehab_service_get_accessible_devices_ids_and_uuids(self):
"""
This will test at the same time get_accessible_devices and get_accessible_devices_ids/uuids.
"""
with self._flask_app.app_context():
service = TeraService.get_service_by_key('VideoRehabService')
self.assertIsNotNone(service)
service_access = DBManager.serviceAccess(service)
devices_ids : set[int] = set(service_access.get_accessible_devices_ids())
devices_uuids : set[str] = set(service_access.get_accessible_devices_uuids())

all_devices = TeraDevice.query.all()
accessible_devices = set()
for device in all_devices:
for project in device.device_projects:
if project.id_project in service_access.get_accessible_projects_ids():
accessible_devices.add(device.id_device)

self.assertEqual(len(devices_ids), len(accessible_devices))
self.assertEqual(len(devices_uuids), len(devices_ids))
self.assertEqual(devices_ids, accessible_devices)

def test_video_rehab_service_get_accessible_projects_ids(self):
"""
This will test at the same time get_accessible_projects and get_accessible_projects_ids.
"""
with self._flask_app.app_context():
service = TeraService.get_service_by_key('VideoRehabService')
self.assertIsNotNone(service)
service_access = DBManager.serviceAccess(service)
projects_ids = set(service_access.get_accessible_projects_ids())

all_projects: list[TeraProject] = TeraProject.query.all()
accessible_projects = set()
for project in all_projects:
for project_service in project.project_services:
if project_service.id_service == service.id_service:
accessible_projects.add(project.id_project)
self.assertEqual(len(projects_ids), len(accessible_projects))
self.assertEqual(projects_ids, accessible_projects)

def test_video_rehab_service_get_accessible_sessions_ids(self):
"""
This will test at the same time get_accessible_sessions and get_accessible_sessions_ids.
"""
with self._flask_app.app_context():
service = TeraService.get_service_by_key('VideoRehabService')
self.assertIsNotNone(service)
service_access = DBManager.serviceAccess(service)
sessions_ids = set(service_access.get_accessible_sessions_ids())

all_sessions = TeraSession.query.all()
accessible_sessions = set()

for session in all_sessions:
# Creator
if session.id_creator_service == service.id_service:
accessible_sessions.add(session.id_session)
# Participants
for participant in session.session_participants:
if participant.id_participant in service_access.get_accessible_participants_ids():
accessible_sessions.add(session.id_session)
# Users
for user in session.session_users:
if user.id_user in service_access.get_accessible_users_ids():
accessible_sessions.add(session.id_session)
# Devices
for device in session.session_devices:
if device.id_device in service_access.get_accessible_devices_ids():
accessible_sessions.add(session.id_session)

self.assertEqual(len(sessions_ids), len(accessible_sessions))
self.assertEqual(sessions_ids, accessible_sessions)

def test_video_rehab_service_get_accessible_sites_id(self):
"""
This will test at the same time get_accessible_sites and get_accessible_sites_ids.
"""
with self._flask_app.app_context():
service = TeraService.get_service_by_key('VideoRehabService')
self.assertIsNotNone(service)
service_access = DBManager.serviceAccess(service)
sites_ids = set(service_access.get_accessible_sites_ids())

all_sites = TeraSite.query.all()
accessible_sites = set()
for site in all_sites:
for project in site.site_projects:
if project.id_project in service_access.get_accessible_projects_ids():
accessible_sites.add(site.id_site)
self.assertEqual(len(sites_ids), len(accessible_sites))
self.assertEqual(sites_ids, accessible_sites)

def test_video_rehab_service_get_accessible_participants_id(self):
"""
This will test at the same time get_accessible_participants and get_accessible_participants_ids.
"""
with self._flask_app.app_context():
service = TeraService.get_service_by_key('VideoRehabService')
self.assertIsNotNone(service)
service_access = DBManager.serviceAccess(service)
participants_ids = set(service_access.get_accessible_participants_ids())

all_participants = TeraParticipant.query.all()
accessible_participants = set()
for participant in all_participants:
if participant.id_project in service_access.get_accessible_projects_ids():
accessible_participants.add(participant.id_participant)
self.assertEqual(len(participants_ids), len(accessible_participants))
self.assertEqual(participants_ids, accessible_participants)

def test_video_rehab_service_get_accessible_participant_groups_id(self):
"""
This will test at the same time get_accessible_participant_groups and get_accessible_participant_groups_ids.
"""
with self._flask_app.app_context():
service = TeraService.get_service_by_key('VideoRehabService')
self.assertIsNotNone(service)
service_access = DBManager.serviceAccess(service)
participant_groups_ids = set(service_access.get_accessible_participants_groups_ids())

all_participant_groups = TeraParticipantGroup.query.all()
accessible_participant_groups = set()
for participant_group in all_participant_groups:
if participant_group.id_project in service_access.get_accessible_projects_ids():
accessible_participant_groups.add(participant_group.id_participant_group)
self.assertEqual(len(participant_groups_ids), len(accessible_participant_groups))
self.assertEqual(participant_groups_ids, accessible_participant_groups)
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from modules.DatabaseModule.DBManager import DBManager
from opentera.db.models.TeraUser import TeraUser
from tests.opentera.db.models.BaseModelsTest import BaseModelsTest

from modules.DatabaseModule.DBManagerTeraUserAccess import DBManagerTeraUserAccess

class DBManagerTeraUserAccessTest(BaseModelsTest):

def setUp(self):
super().setUp()

def test_admin_get_accessible_users_ids(self):
with self._flask_app.app_context():
admin_user = TeraUser.get_user_by_username('admin')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_get_endpoint_with_token_auth_no_params(self):
service: TeraService = TeraService.get_service_by_uuid(self.service_uuid)
from modules.DatabaseModule.DBManager import DBManager
service_access = DBManager.serviceAccess(service)
accessible_sites = service_access.get_accessibles_sites_ids()
accessible_sites = service_access.get_accessible_sites_ids()
for site_json in response.json:
id_site = site_json['id_site']
if id_site in accessible_sites:
Expand All @@ -50,7 +50,7 @@ def test_get_endpoint_with_token_auth_and_id_site(self):
service: TeraService = TeraService.get_service_by_uuid(self.service_uuid)
from modules.DatabaseModule.DBManager import DBManager
service_access = DBManager.serviceAccess(service)
accessible_sites = service_access.get_accessibles_sites_ids()
accessible_sites = service_access.get_accessible_sites_ids()

params = {'id_site': site.id_site}
response = self._get_with_service_token_auth(client=self.test_client, token=self.service_token,
Expand All @@ -72,7 +72,7 @@ def test_get_endpoint_with_token_auth_and_id_site(self):
service: TeraService = TeraService.get_service_by_uuid(self.service_uuid)
from modules.DatabaseModule.DBManager import DBManager
service_access = DBManager.serviceAccess(service)
accessible_sites = service_access.get_accessibles_sites_ids()
accessible_sites = service_access.get_accessible_sites_ids()

params = {'id_user': user.id_user}
response = self._get_with_service_token_auth(client=self.test_client, token=self.service_token,
Expand Down

0 comments on commit ac7fd41

Please sign in to comment.