Skip to content

Commit

Permalink
Code refactoring, adding type annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
doumdi committed Nov 19, 2024
1 parent 0b1bf7f commit 7aacc3a
Show file tree
Hide file tree
Showing 16 changed files with 326 additions and 293 deletions.
2 changes: 1 addition & 1 deletion teraserver/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/opentera/OpenTeraServerVersion.py.in
add_custom_target(python-server-version SOURCES opentera/OpenTeraServerVersion.py.in)

# Will always be considered out of date...
add_custom_target(python-all DEPENDS python-env python-messages python-server-version python-translations-compile-only opentera/OpenTeraServerVersion.py)
add_custom_target(python-all ALL DEPENDS python-env python-messages python-server-version python-translations-compile-only opentera/OpenTeraServerVersion.py)

# Build this target if you want to update translations too...
add_custom_target(python-all-with-translations DEPENDS python-all python-translations)
Expand Down
10 changes: 5 additions & 5 deletions teraserver/python/modules/DatabaseModule/DBManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def start_cleanup_task(self) -> task:
return task.deferLater(reactor, seconds_to_midnight, self.cleanup_database)
# return task.deferLater(reactor, 5, self.cleanup_database)

def setup_events_for_2fa_sites(self):
def setup_events_for_2fa_sites(self) -> None:
"""
We need to validate that 2FA is enabled for all users in the site when the flag is set.
This can occur on multiple occasions : when the site is created, updated and also when user
Expand Down Expand Up @@ -279,22 +279,22 @@ def base_model_inserted(mapper, connection, target):
self.publish(event_message.header.topic, event_message.SerializeToString())

@staticmethod
def userAccess(user: TeraUser):
def userAccess(user: TeraUser) -> DBManagerTeraUserAccess:
access = DBManagerTeraUserAccess(user=user)
return access

@staticmethod
def deviceAccess(device: TeraDevice):
def deviceAccess(device: TeraDevice) -> DBManagerTeraDeviceAccess:
access = DBManagerTeraDeviceAccess(device=device)
return access

@staticmethod
def participantAccess(participant: TeraParticipant):
def participantAccess(participant: TeraParticipant) -> DBManagerTeraParticipantAccess:
access = DBManagerTeraParticipantAccess(participant=participant)
return access

@staticmethod
def serviceAccess(service: TeraService):
def serviceAccess(service: TeraService) -> DBManagerTeraServiceAccess:
access = DBManagerTeraServiceAccess(service=service)
return access

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from opentera.db.models.TeraProject import TeraProject
from opentera.db.models.TeraSessionType import TeraSessionType
from opentera.db.models.TeraTestType import TeraTestType
from opentera.db.models.TeraDevice import TeraDevice
from opentera.db.models.TeraSession import TeraSession
from opentera.db.models.TeraParticipant import TeraParticipant
from opentera.db.models.TeraAsset import TeraAsset
from opentera.db.models.TeraService import TeraService
from opentera.db.models.TeraServiceProject import TeraServiceProject

from sqlalchemy import func

Expand All @@ -26,7 +31,7 @@ def query_session(self, session_id: int) -> TeraSession | None:
# return sessions

def query_existing_session(self, session_name: str, session_type_id: int, session_date: datetime,
participant_uuids: list):
participant_uuids: list) -> TeraSession | None:
sessions = TeraSession.query.filter(TeraSession.id_creator_device == self.device.id_device).\
filter(TeraSession.session_name == session_name).filter(TeraSession.id_session_type == session_type_id).\
filter(func.date(TeraSession.session_start_datetime) == session_date.date()).\
Expand All @@ -39,26 +44,26 @@ def query_existing_session(self, session_name: str, session_type_id: int, sessio
return session
return None

def get_accessible_sessions(self):
def get_accessible_sessions(self) -> list[TeraSession]:
query = TeraSession.query.filter(TeraSession.id_creator_device == self.device.id_device)
return query.all()

def get_accessible_sessions_ids(self):
def get_accessible_sessions_ids(self) -> list[int]:
sessions = self.get_accessible_sessions()
return [session.id_session for session in sessions]

def get_accessible_participants(self, admin_only=False):
def get_accessible_participants(self, admin_only=False) -> list[TeraParticipant]:
return self.device.device_participants

def get_accessible_participants_ids(self, admin_only=False):
def get_accessible_participants_ids(self, admin_only=False) -> list[int]:
parts = []

for part in self.get_accessible_participants(admin_only=admin_only):
parts.append(part.id_participant)

return parts

def get_accessible_session_types(self):
def get_accessible_session_types(self) -> list[TeraSessionType]:

# participants = self.get_accessible_participants()
# project_list = []
Expand All @@ -72,14 +77,15 @@ def get_accessible_session_types(self):

return session_types

def get_accessible_session_types_ids(self):
def get_accessible_session_types_ids(self) -> list[int]:
types = []
for my_type in self.get_accessible_session_types():
types.append(my_type.id_session_type)
return types

def get_accessible_assets(self, id_asset: int = None, uuid_asset: str = None, session_id: int = None):
from opentera.db.models.TeraAsset import TeraAsset
def get_accessible_assets(self, id_asset: int = None, uuid_asset: str = None,
session_id: int = None) -> list[TeraAsset]:

query = TeraAsset.query.filter(TeraAsset.id_device == self.device.id_device)
if id_asset:
query = query.filter(TeraAsset.id_asset == id_asset)
Expand All @@ -90,9 +96,7 @@ def get_accessible_assets(self, id_asset: int = None, uuid_asset: str = None, se

return query.all()

def get_accessible_services(self):
from opentera.db.models.TeraService import TeraService
from opentera.db.models.TeraServiceProject import TeraServiceProject
def get_accessible_services(self) -> list[TeraService]:

accessible_projects_ids = [proj.id_project for proj in self.device.device_projects]

Expand All @@ -101,3 +105,12 @@ def get_accessible_services(self):

return query.all()

def get_accessible_tests_types(self) -> list[TeraTestType]:
accessible_projects_ids = [proj.id_project for proj in self.device.device_projects]
query = TeraTestType.query.join(TeraServiceProject).filter(
TeraServiceProject.id_project.in_(accessible_projects_ids)).group_by(TeraTestType.id_test_type)
return query.all()

def get_accessible_tests_types_ids(self) -> list[int]:
test_types = self.get_accessible_tests_types()
return [test_type.id_test_type for test_type in test_types]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
from opentera.db.models.TeraDeviceParticipant import TeraDeviceParticipant
from opentera.db.models.TeraSessionType import TeraSessionType
from opentera.db.models.TeraSession import TeraSession
from opentera.db.models.TeraTestTypeProject import TeraTestTypeProject
from opentera.db.models.TeraServiceProject import TeraServiceProject
from opentera.db.models.TeraTestType import TeraTestType
from opentera.db.models.TeraService import TeraService
from opentera.db.models.TeraAsset import TeraAsset


from sqlalchemy import func

Expand All @@ -12,29 +18,14 @@ class DBManagerTeraParticipantAccess:
def __init__(self, participant: TeraParticipant):
self.participant = participant

# def query_session(self, limit: int = None, offset: int = None):
# Make sure you filter results with id_participant to return TeraDevices
# that are accessible by current participant
# query = TeraSession.query.filter_by(**filters).join(TeraSessionParticipants). \
# filter_by(id_participant=self.participant.id_participant)
# if limit:
# query = query.limit(limit)
#
# if offset:
# query = query.offset(offset)
#
# return query.all()

def query_device(self, filters: dict):
def query_device(self, filters: dict) -> list[TeraDevice]:
# Make sure you filter results with id_participant to return TeraDevices
# that are accessible by current participant
result = TeraDevice.query.filter_by(**filters).join(TeraDeviceParticipant).\
filter_by(id_participant=self.participant.id_participant).all()
return result

def get_accessible_assets(self, id_asset: int = None, uuid_asset: str = None, session_id: int = None):
from opentera.db.models.TeraAsset import TeraAsset

def get_accessible_assets(self, id_asset: int = None, uuid_asset: str = None, session_id: int = None) -> list[TeraAsset]:
# A participant can only have access to assets that are directly assigned to them (where id_participant is set
# to their value)
query = TeraAsset.query.filter(TeraAsset.id_participant == self.participant.id_participant)
Expand All @@ -47,26 +38,23 @@ def get_accessible_assets(self, id_asset: int = None, uuid_asset: str = None, se

return query.all()

def get_accessible_services(self):
from opentera.db.models.TeraServiceProject import TeraServiceProject
def get_accessible_services(self) -> list[TeraService]:
service_projects = TeraServiceProject.get_services_for_project(id_project=self.participant.id_project)

return [service_project.service_project_service for service_project in service_projects]

def get_accessible_session_types(self):
def get_accessible_session_types(self) -> list[TeraSessionType]:
session_types = TeraSessionType.query.join(TeraSessionType.session_type_projects)\
.filter_by(id_project=self.participant.id_project).all()

return session_types

def get_accessible_session_types_ids(self):
def get_accessible_session_types_ids(self) -> list[int]:
types = []
for my_type in self.get_accessible_session_types():
types.append(my_type.id_session_type)
return types

def query_existing_session(self, session_name: str, session_type_id: int, session_date: datetime,
participant_uuids: list):
participant_uuids: list) -> TeraSession | None:
sessions = TeraSession.query.filter(TeraSession.id_creator_participant == self.participant.id_participant).\
filter(TeraSession.session_name == session_name).filter(TeraSession.id_session_type == session_type_id).\
filter(func.date(TeraSession.session_start_datetime) == session_date.date()).\
Expand All @@ -79,10 +67,18 @@ def query_existing_session(self, session_name: str, session_type_id: int, sessio
return session
return None

def get_accessible_sessions(self):
def get_accessible_sessions(self) -> list[TeraSession]:
query = TeraSession.query.filter(TeraSession.id_creator_participant == self.participant.id_participant)
return query.all()

def get_accessible_sessions_ids(self):
def get_accessible_sessions_ids(self) -> list[int]:
sessions = self.get_accessible_sessions()
return [session.id_session for session in sessions]

def get_accessible_tests_types(self) -> list[TeraTestType]:
test_types = TeraTestType.query.join(TeraTestTypeProject).filter_by(id_project=self.participant.id_project).all()
return test_types

def get_accessible_tests_types_ids(self) -> list[int]:
test_types = self.get_accessible_tests_types()
return [test_type.id_test_type for test_type in test_types]
Loading

0 comments on commit 7aacc3a

Please sign in to comment.