diff --git a/backend/core/helpers.py b/backend/core/helpers.py index 4b4afd492..9653d254c 100644 --- a/backend/core/helpers.py +++ b/backend/core/helpers.py @@ -1175,3 +1175,94 @@ def handle(exc, context): exc = DRFValidationError(detail=data) return drf_exception_handler(exc, context) + + +def duplicate_related_objects( + source_object: models.Model, + duplicate_object: models.Model, + target_folder: Folder, + field_name: str, +): + """ + Duplicates related objects from a source object to a duplicate object, avoiding duplicates in the target folder. + + Parameters: + - source_object (object): The source object containing related objects to duplicate. + - duplicate_object (object): The object where duplicated objects will be linked. + - target_folder (Folder): The folder where duplicated objects will be stored. + - field_name (str): The field name representing the related objects in the source + """ + + def process_related_object( + obj, + duplicate_object, + target_folder, + target_parent_folders, + sub_folders, + field_name, + model_class, + ): + """ + Process a single related object: add, link, or duplicate it based on folder and existence checks. + """ + + # Check if the object already exists in the target folder + existing_obj = get_existing_object(obj, target_folder, model_class) + + if existing_obj: + # If the object exists in the target folder, link it to the duplicate object + link_existing_object(duplicate_object, existing_obj, field_name) + + elif obj.folder in target_parent_folders and obj.is_published: + # If the object's folder is a parent and it's published, link it + link_existing_object(duplicate_object, obj, field_name) + + elif obj.folder in sub_folders: + # If the object's folder is a subfolder of the target folder, link it + link_existing_object(duplicate_object, obj, field_name) + + else: + # Otherwise, duplicate the object and link it + duplicate_and_link_object(obj, duplicate_object, target_folder, field_name) + + def get_existing_object(obj, target_folder, model_class): + """ + Check if an object with the same name already exists in the target folder. + """ + return model_class.objects.filter(name=obj.name, folder=target_folder).first() + + def link_existing_object(duplicate_object, existing_obj, field_name): + """ + Link an existing object to the duplicate object by adding it to the related field. + """ + getattr(duplicate_object, field_name).add(existing_obj) + + def duplicate_and_link_object(new_obj, duplicate_object, target_folder, field_name): + """ + Duplicate an object and link it to the duplicate object. + """ + new_obj.pk = None + new_obj.folder = target_folder + new_obj.save() + link_existing_object(duplicate_object, new_obj, field_name) + + model_class = getattr(type(source_object), field_name).field.related_model + + # Get parent and sub-folders of the target folder + target_parent_folders = target_folder.get_parent_folders() + sub_folders = target_folder.sub_folders() + + # Get all related objects for the specified field + related_objects = getattr(source_object, field_name).all() + + # Process each related object + for obj in related_objects: + process_related_object( + obj, + duplicate_object, + target_folder, + target_parent_folders, + sub_folders, + field_name, + model_class, + ) diff --git a/backend/core/views.py b/backend/core/views.py index 59bd040c7..2d290e56e 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -598,49 +598,6 @@ def treatment_plan_pdf(self, request, pk): serializer_class=RiskAssessmentDuplicateSerializer, ) def duplicate(self, request, pk): - def duplicate_related_objects( - scenario, duplicate_scenario, target_folder, field_name, model_class - ): - """ - Duplicates related objects (e.g., controls, threats, assets) from a source scenario to a duplicate scenario, - ensuring that objects are not duplicated if they already exist in the/a target/parent domain (folder). - - Parameters: - - scenario (object): The source scenario containing the related objects to be duplicated. - - duplicate_scenario (object): The duplicate scenario where the objects will be added. - - target_folder (object): The target folder where the duplicated objects will be stored. - - field_name (str): The field name representing the related objects to duplicate in the scenario. - - model_class (class): The model class of the related objects to be processed. - """ - - # Get parent folders of the target folder - target_parent_folders = target_folder.get_parent_folders() - - # Fetch all related objects for the given field name - related_objects = getattr(scenario, field_name).all() - - for obj in related_objects: - # Check if an object with the same name already exists in the target folder - existing_obj = model_class.objects.filter( - name=obj.name, folder=target_folder - ).first() - - if existing_obj: - # If the object already exists in the targer folder, add the existing one to the duplicate scenario - getattr(duplicate_scenario, field_name).add(existing_obj) - - elif obj.folder in target_parent_folders: - # If the object's folder is a parent of the targert folder, add the object to the duplicate scenario - getattr(duplicate_scenario, field_name).add(obj) - - else: - # If the object doesn't exist, duplicate the object - duplicate_obj = obj - duplicate_obj.pk = None - duplicate_obj.folder = target_folder - duplicate_obj.save() - getattr(duplicate_scenario, field_name).add(duplicate_obj) - (object_ids_view, _, _) = RoleAssignment.get_accessible_object_ids( Folder.get_root_folder(), request.user, RiskAssessment ) @@ -679,27 +636,13 @@ def duplicate_related_objects( justification=scenario.justification, ) - duplicate_related_objects( - scenario, - duplicate_scenario, - duplicate_risk_assessment.project.folder, - "applied_controls", - AppliedControl, - ) - duplicate_related_objects( - scenario, - duplicate_scenario, - duplicate_risk_assessment.project.folder, - "threats", - Threat, - ) - duplicate_related_objects( - scenario, - duplicate_scenario, - duplicate_risk_assessment.project.folder, - "assets", - Asset, - ) + for field in ["applied_controls", "threats", "assets"]: + duplicate_related_objects( + scenario, + duplicate_scenario, + duplicate_risk_assessment.project.folder, + field, + ) if ( duplicate_risk_assessment.project.folder