def higher_project_user_role(role_1, role_2):
    """Return the "higher" (more privileged) of two project user roles.

    Roles are ranked by name, from least to most privileged: 'User',
    'Manager', 'Principal Investigator'. When both roles have the same
    rank, role_1 is returned.

    Parameters:
        - role_1 (ProjectUserRoleChoice): the first role
        - role_2 (ProjectUserRoleChoice): the second role

    Returns:
        - ProjectUserRoleChoice: whichever argument ranks higher

    Raises:
        - TypeError, if either argument is not a ProjectUserRoleChoice.
        - ValueError, if either role's name is not one of the ranked
          names.
    """
    roles_ascending = ('User', 'Manager', 'Principal Investigator')
    roles = (role_1, role_2)

    # Validate explicitly rather than with assert statements, which are
    # removed when Python runs with -O.
    for role in roles:
        if not isinstance(role, ProjectUserRoleChoice):
            raise TypeError(
                f'Expected a ProjectUserRoleChoice, got '
                f'{type(role).__name__}.')
    for role in roles:
        if role.name not in roles_ascending:
            raise ValueError(f'Unexpected role name "{role.name}".')

    role_1_rank = roles_ascending.index(role_1.name)
    role_2_rank = roles_ascending.index(role_2.name)
    # Ties go to role_1.
    return role_1 if role_1_rank >= role_2_rank else role_2
class Command(BaseCommand):
    """An admin command that merges two Users into one."""

    help = (
        'Merge two Users into one. The command chooses one instance, transfers '
        'that instance\'s relationships, requests, etc. to the other, and then '
        'deletes it.')

    logger = logging.getLogger(__name__)

    def add_arguments(self, parser):
        """Register the dry-run flag and the two positional usernames."""
        add_argparse_dry_run_argument(parser)
        parser.add_argument(
            'username_1', help='The username of the first user.', type=str)
        parser.add_argument(
            'username_2', help='The username of the second user.', type=str)

    def handle(self, *args, **options):
        """Merge the two named Users, or only simulate it on a dry run.

        Outside of a dry run, the operator is asked to confirm first.
        Problems (a missing User, the same User given twice, an
        unsupported combination of Users, or a failure during the merge)
        are written to stderr rather than being silently dropped.
        """
        dry_run = options['dry_run']
        if not dry_run:
            user_confirmation = input(
                'Are you sure you wish to proceed? [Y/y/N/n]: ')
            if user_confirmation.strip().lower() != 'y':
                self.stdout.write(self.style.WARNING('Merge aborted.'))
                sys.exit(0)

        users = []
        for username in (options['username_1'], options['username_2']):
            try:
                users.append(User.objects.get(username=username))
            except User.DoesNotExist:
                self.stderr.write(
                    self.style.ERROR(f'User "{username}" does not exist.'))
                return
        user_1, user_2 = users

        # Merging a User into itself would end with that User being deleted.
        if user_1.pk == user_2.pk:
            self.stderr.write(
                self.style.ERROR(
                    'The two usernames refer to the same User. Nothing to '
                    'merge.'))
            return

        write_via_command_strategy = WriteViaCommandStrategy(self)
        enqueue_for_logging_strategy = EnqueueForLoggingStrategy(self.logger)
        reporting_strategies = [
            write_via_command_strategy, enqueue_for_logging_strategy]

        try:
            user_merge_runner = UserMergeRunner(
                user_1, user_2, reporting_strategies=reporting_strategies)
        except NotImplementedError as e:
            # The runner refuses combinations of Users it cannot handle.
            self.stderr.write(self.style.ERROR(str(e)))
            return

        src_user_str = self._user_str(user_merge_runner.src_user)
        dst_user_str = self._user_str(user_merge_runner.dst_user)

        self.stdout.write(self.style.WARNING(f'Source: {src_user_str}'))
        self.stdout.write(self.style.WARNING(f'Destination: {dst_user_str}'))

        if dry_run:
            user_merge_runner.dry_run()
            self.stdout.write(self.style.WARNING('Dry run of merge complete.'))
            return

        enqueue_for_logging_strategy.warning(
            f'Initiating a merge of source User {src_user_str} into '
            f'destination User {dst_user_str}.')
        try:
            user_merge_runner.run()
        except Exception as e:
            # Top-level boundary: report the failure instead of swallowing
            # it, then still flush the queued log messages below.
            message = (
                f'Failed to merge source User {src_user_str} into '
                f'destination User {dst_user_str}. Details:\n{e}')
            self.stderr.write(self.style.ERROR(message))
            enqueue_for_logging_strategy.error(message)
        else:
            self.stdout.write(self.style.SUCCESS('Merge complete.'))
            enqueue_for_logging_strategy.success(
                f'Successfully merged source User {src_user_str} into '
                f'destination User {dst_user_str}.')
        enqueue_for_logging_strategy.log_queued_messages()

    @staticmethod
    def _user_str(user):
        """Return a short description of the given User for messages."""
        return (
            f'{user.username} ({user.pk}, {user.first_name} '
            f'{user.last_name})')
class ClassHandlerFactory(object):
    """A factory for returning a concrete instance of ClassHandler for a
    particular class."""

    def get_handler(self, klass, *args, **kwargs):
        """Return an instantiated handler for the given class with the
        given arguments and keyword arguments.

        Raises:
            - TypeError, if klass is not a class.
            - ValueError, if no handler exists for klass.
        """
        if not inspect.isclass(klass):
            raise TypeError(f'Expected a class, got {klass!r}.')
        return self._get_handler_class(klass)(*args, **kwargs)

    @staticmethod
    def _get_handler_class(klass):
        """Return the appropriate handler class for the given class. If
        none are applicable, raise a ValueError.

        The handler is looked up by name ("<ClassName>Handler") among
        this module's globals. Only concrete subclasses of ClassHandler
        are accepted, so an unrelated global with a matching name (or the
        abstract ClassHandler itself) is never returned.
        """
        handler_class_name = f'{klass.__name__}Handler'
        handler_class = globals().get(handler_class_name)
        is_concrete_handler = (
            inspect.isclass(handler_class) and
            issubclass(handler_class, ClassHandler) and
            handler_class is not ClassHandler)
        if not is_concrete_handler:
            raise ValueError(f'No handler for class {klass.__name__}.')
        return handler_class


class ClassHandler(ABC):
    """A class that handles transferring data from a source object of a
    particular class, and which belongs to a source user, to a
    destination user, when merging User accounts."""

    @abstractmethod
    def __init__(self, src_user, dst_user, src_obj, reporting_strategies=None):
        """Store the users and source object.

        Because this method is abstract, every concrete child class must
        define its own __init__, even if it only calls this one.

        Parameters:
            - src_user: the user whose data is being merged away
            - dst_user: the user receiving the data
            - src_obj: the object, belonging to src_user, to process
            - reporting_strategies: a list or tuple of objects with a
              success(message) method; anything else is ignored
        """
        self._src_user = src_user
        self._dst_user = dst_user
        self._src_obj = src_obj
        # A corresponding object may or may not exist for the destination User.
        # Attempt to retrieve it in each concrete child class.
        self._dst_obj = None

        self._class_name = self._src_obj.__class__.__name__

        # Report messages using each of the given strategies.
        self._reporting_strategies = []
        if isinstance(reporting_strategies, (list, tuple)):
            self._reporting_strategies.extend(reporting_strategies)

    def run(self):
        """Transfer the source object from the source user to the
        destination user.

        All changes are made in one atomic block. If a destination object
        exists, it is saved at the end so that attribute changes made by
        the class-specific handling are persisted.
        """
        with transaction.atomic():
            if self._dst_obj:
                self._set_falsy_attrs()
            self._run_special_handling()
            if self._dst_obj:
                self._dst_obj.save()

    def _get_settable_if_falsy_attrs(self):
        """Return a list of attributes that, if falsy in the
        destination, should be updated to the value of the corresponding
        attribute in the source object."""
        return []

    def _handle_associated_object(self, transferred=False):
        """An object B may be associated with a User through a different
        object A. When A is deleted, B may be deleted with it. When A is
        transferred, B is transferred with it. Record that this has
        occurred."""
        if transferred:
            # The object was transferred to the destination user.
            self._record_update(
                self._src_obj.pk, 'user (indirectly associated)',
                self._src_user, self._dst_user)
            return
        # The object was deleted.
        message = (
            f'{self._class_name}({self._src_obj.pk}): indirectly deleted')
        self._report_success_message(message)

    def _report_success_message(self, message):
        """Record a success message with the given text to each of the
        reporting strategies."""
        for strategy in self._reporting_strategies:
            strategy.success(message)

    def _record_update(self, pk, attr_name, pre_value, post_value):
        """Record that the object of this class and with the given
        primary key had its attribute with the given name updated from
        pre_value to post_value."""
        message = (
            f'{self._class_name}({pk}).{attr_name}: {pre_value} --> '
            f'{post_value}')
        self._report_success_message(message)

    def _run_special_handling(self):
        """Run handling specific to a particular class, implemented by
        each child class."""
        raise NotImplementedError

    def _set_attr_if_falsy(self, attr_name):
        """If the attribute with the given name is falsy in the
        destination object but not in the source object, update the
        former's value to the latter's.

        Raises:
            - AttributeError, if either object lacks the attribute.
        """
        src_attr = getattr(self._src_obj, attr_name)
        dst_attr = getattr(self._dst_obj, attr_name)
        if src_attr and not dst_attr:
            setattr(self._dst_obj, attr_name, src_attr)
            self._record_update(self._dst_obj.pk, attr_name, dst_attr, src_attr)

    def _set_falsy_attrs(self):
        """For each attribute named by _get_settable_if_falsy_attrs,
        copy the source's value to the destination object if the
        destination's is falsy, then save the destination object."""
        for attr_name in self._get_settable_if_falsy_attrs():
            self._set_attr_if_falsy(attr_name)
        self._dst_obj.save()

    def _transfer_src_obj_to_dst_user(self, attr_name='user'):
        """Point the source object's attribute with the given name at
        the destination user, save the source object, and record the
        change."""
        setattr(self._src_obj, attr_name, self._dst_user)
        self._src_obj.save()
        self._record_update(
            self._src_obj.pk, attr_name, self._src_user, self._dst_user)


class UserProfileHandler(ClassHandler):
    """Merge the source user's UserProfile into the destination user's
    existing profile."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._dst_obj = self._dst_user.userprofile

    def _get_settable_if_falsy_attrs(self):
        return [
            'is_pi',
            # Only the destination user should have a cluster UID.
            # 'cluster_uid',
            'phone_number',
            'access_agreement_signed_date',
            'billing_activity',
        ]

    def _run_special_handling(self):
        self._set_host_user()

    def _set_host_user(self):
        """Handle the profile's host user. Currently a no-op, except
        that it refuses to proceed when the LRC_ONLY flag is enabled."""
        if flag_enabled('LRC_ONLY'):
            raise NotImplementedError(
                'Setting the host user is not supported when the LRC_ONLY '
                'flag is enabled.')
        # TODO
        # Deal with conflicts.
        # Handle LBL users.
        # self._set_attr_if_falsy('host_user')


class SocialAccountHandler(ClassHandler):
    """Transfer a SocialAccount to the destination user."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        self._transfer_src_obj_to_dst_user()


class EmailAddressHandler(ClassHandler):
    """Transfer an EmailAddress to the destination user as a non-primary
    address."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        self._src_obj.primary = False
        self._transfer_src_obj_to_dst_user()


class AllocationUserHandler(ClassHandler):
    """Merge an AllocationUser into the destination user's AllocationUser
    for the same Allocation, or transfer it if none exists."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            self._dst_obj = AllocationUser.objects.get(
                allocation=self._src_obj.allocation, user=self._dst_user)
        except ObjectDoesNotExist:
            self._dst_obj = None

    def _run_special_handling(self):
        allocation = self._src_obj.allocation

        # TODO: Note that only compute Allocations are handled for now.
        # These guards are explicit raises (not asserts) so that they are
        # still enforced when Python runs with -O.
        if allocation.resources.count() != 1:
            raise NotImplementedError(
                f'Allocation {allocation.pk} does not have exactly one '
                f'resource.')
        resource = allocation.resources.first()
        if not resource.name.endswith(' Compute'):
            raise NotImplementedError(
                f'Allocation {allocation.pk} is not a compute Allocation.')

        if not self._dst_obj:
            self._transfer_src_obj_to_dst_user()
            return
        if self._update_status():
            self._dst_obj.save()

    def _update_status(self):
        """Update the status of the destination if it is not "Active"
        but the source's is. Return whether an update occurred."""
        active_allocation_user_status = \
            AllocationUserStatusChoice.objects.get(name='Active')
        dst_obj_status = self._dst_obj.status
        if (dst_obj_status != active_allocation_user_status and
                self._src_obj.status == active_allocation_user_status):
            self._dst_obj.status = self._src_obj.status
            self._record_update(
                self._dst_obj.pk, 'status', dst_obj_status.name,
                'Active')
            return True
        return False


class AllocationUserAttributeHandler(ClassHandler):
    """Record what happened to an AllocationUserAttribute, which follows
    its AllocationUser."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        transferred = self._src_obj.allocation_user.user == self._dst_user
        self._handle_associated_object(transferred=transferred)


class AllocationUserAttributeUsageHandler(ClassHandler):
    """Record what happened to an AllocationUserAttributeUsage, which
    follows its AllocationUserAttribute's AllocationUser."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        allocation_user = (
            self._src_obj.allocation_user_attribute.allocation_user)
        transferred = allocation_user.user == self._dst_user
        self._handle_associated_object(transferred=transferred)


class ClusterAccessRequestHandler(ClassHandler):
    """Record what happened to a ClusterAccessRequest, which follows its
    AllocationUser."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        transferred = self._src_obj.allocation_user.user == self._dst_user
        self._handle_associated_object(transferred=transferred)


class ProjectUserHandler(ClassHandler):
    """Merge a ProjectUser into the destination user's ProjectUser for
    the same Project, or transfer it if none exists."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            self._dst_obj = ProjectUser.objects.get(
                project=self._src_obj.project, user=self._dst_user)
        except ObjectDoesNotExist:
            self._dst_obj = None

    def _run_special_handling(self):
        if not self._dst_obj:
            self._transfer_src_obj_to_dst_user()
            return
        # Evaluate both updates; do not short-circuit the second one.
        role_updated = self._update_role()
        status_updated = self._update_status()
        if role_updated or status_updated:
            self._dst_obj.save()

        # TODO: Run the runner?

    def _update_role(self):
        """Update the role of the destination if the source's is higher.
        Return whether an update occurred."""
        dst_obj_role = self._dst_obj.role
        self._dst_obj.role = higher_project_user_role(
            dst_obj_role, self._src_obj.role)
        if self._dst_obj.role != dst_obj_role:
            self._record_update(
                self._dst_obj.pk, 'role', dst_obj_role.name,
                self._dst_obj.role.name)
            return True
        return False

    def _update_status(self):
        """Update the status of the destination if it is not "Active"
        but the source's is. Return whether an update occurred."""
        active_project_user_status = ProjectUserStatusChoice.objects.get(
            name='Active')
        dst_obj_status = self._dst_obj.status
        if (dst_obj_status != active_project_user_status and
                self._src_obj.status == active_project_user_status):
            self._dst_obj.status = self._src_obj.status
            self._record_update(
                self._dst_obj.pk, 'status', dst_obj_status.name,
                'Active')
            return True
        return False


class ProjectUserJoinRequestHandler(ClassHandler):
    """Record what happened to a ProjectUserJoinRequest, which follows
    its ProjectUser."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        transferred = self._src_obj.project_user.user == self._dst_user
        self._handle_associated_object(transferred=transferred)


class SavioProjectAllocationRequestHandler(ClassHandler):
    """Transfer a SavioProjectAllocationRequest to the destination user
    wherever the source user is its requester and/or its PI."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _run_special_handling(self):
        # The source user may fill either role, or both.
        for attr_name in ('requester', 'pi'):
            if getattr(self._src_obj, attr_name) == self._src_user:
                self._transfer_src_obj_to_dst_user(attr_name=attr_name)
class UserMergeError(Exception):
    """Raised when two Users cannot be merged or an object belonging to
    the source User cannot be processed."""


class UserMergeRollback(Exception):
    """Raised internally at the end of a dry run to roll back the
    merge's transaction."""


class UserMergeRunner(object):
    """A class that merges two User objects into one.

    It identifies one User as a source and the other as a destination,
    merges the source's relationships, requests, etc. into the
    destination, and then deletes the source.

    It currently only supports merging when only one of the given Users
    has cluster access.
    """

    def __init__(self, user_1, user_2, reporting_strategies=None):
        """Identify which of the two Users should be merged into.

        Raises:
            - UserMergeError, if both arguments are the same User.
            - NotImplementedError, if both Users have cluster access.
        """
        self._dry = False
        # src_user's data will be merged into dst_user.
        self._src_user = None
        self._dst_user = None
        self._src_user_pk = None
        self._identify_src_and_dst_users(user_1, user_2)

        # Report messages using each of the given strategies.
        self._reporting_strategies = []
        if isinstance(reporting_strategies, (list, tuple)):
            self._reporting_strategies.extend(reporting_strategies)

    @property
    def dst_user(self):
        return self._dst_user

    @property
    def src_user(self):
        return self._src_user

    def dry_run(self):
        """Attempt to run the merge, but rollback before committing
        changes."""
        self._dry = True
        try:
            self.run()
        finally:
            # Do not leave the runner in dry mode: a later call to run()
            # must perform a real merge.
            self._dry = False

    def run(self):
        """Transfer dependencies from the source User to the destination
        User, then delete the source User.

        On a dry run, or if any step fails, the changes are rolled back
        and the in-memory User objects are reloaded from the database.
        Failures are re-raised.
        """
        # The outer block is equivalent to decorating this method with
        # @transaction.atomic.
        with transaction.atomic():
            try:
                with transaction.atomic():
                    self._select_users_for_update()
                    self._process_src_user_dependencies()
                    self._src_user.delete()
                    if self._dry:
                        self._rollback()
            except UserMergeRollback:
                # The dry run succeeded, and the transaction was rolled
                # back.
                self._reset_users()
            except Exception:
                self._reset_users()
                raise

    @staticmethod
    def _classes_to_ignore():
        """Return a set of classes for which no processing should be
        done."""
        return {
            OldEmailAddress,
        }

    def _identify_src_and_dst_users(self, user_1, user_2):
        """Given two Users, determine which should be the source (the
        one having its data merged and then deleted) and which should be
        the destination (the one having data merged into it).

        The User with cluster access is the destination. If neither has
        cluster access, user_1 is the destination.
        """
        # Merging a User into itself would end with that User being
        # deleted.
        if user_1.pk == user_2.pk:
            raise UserMergeError('A User cannot be merged into itself.')

        user_1_has_cluster_access = has_cluster_access(user_1)
        user_2_has_cluster_access = has_cluster_access(user_2)
        if user_1_has_cluster_access and user_2_has_cluster_access:
            raise NotImplementedError(
                'Both Users have cluster access. This case is not currently '
                'handled.')
        if user_2_has_cluster_access:
            src, dst = user_1, user_2
        else:
            src, dst = user_2, user_1
        self._src_user = src
        self._dst_user = dst
        # Store the primary key of src_user, used to restore the object after
        # dry run rollback.
        self._src_user_pk = self._src_user.pk

    def _process_src_user_dependencies(self):
        """Process each database object associated with the source User
        on a case-by-case basis.

        Raises:
            - UserMergeError, if an object has no handler or its handler
              fails.
        """
        collector = NestedObjects(using=DEFAULT_DB_ALIAS)
        collector.collect([self._src_user])
        objects = collector.nested()

        # Expected shape: [src_user, [objects depending on src_user...]].
        assert len(objects) == 2
        assert isinstance(objects[0], User)
        assert isinstance(objects[1], list)

        classes_to_ignore = self._classes_to_ignore()
        class_handler_factory = ClassHandlerFactory()

        for obj in self._yield_nested_objects(objects[1]):

            if obj.__class__ in classes_to_ignore:
                continue

            # Block other threads from retrieving this object until the end of
            # the transaction.
            obj = obj.__class__.objects.select_for_update().get(pk=obj.pk)

            # Resolve the handler separately from running it, so that a
            # ValueError raised while running is not misreported as a
            # missing handler.
            try:
                handler = class_handler_factory.get_handler(
                    obj.__class__, self._src_user, self._dst_user, obj,
                    reporting_strategies=self._reporting_strategies)
            except ValueError as e:
                raise UserMergeError(
                    f'No handler for object with class {obj.__class__}.'
                ) from e
            except Exception as e:
                raise UserMergeError(
                    f'Failed to process object with class {obj.__class__} and '
                    f'primary key {obj.pk}. Details:\n{e}') from e

            try:
                handler.run()
            except Exception as e:
                raise UserMergeError(
                    f'Failed to process object with class {obj.__class__} and '
                    f'primary key {obj.pk}. Details:\n{e}') from e

    def _reset_users(self):
        """Reset user objects because the values of a model's fields
        won't be reverted when a transaction rollback happens.

        Source: https://docs.djangoproject.com/en/3.2/topics/db/transactions/#controlling-transactions-explicitly
        """
        self._src_user = User.objects.get(pk=self._src_user_pk)
        self._dst_user.refresh_from_db()

    def _rollback(self):
        """Raise a UserMergeRollback exception to roll the enclosing
        transaction back."""
        raise UserMergeRollback('Rolling back.')

    def _select_users_for_update(self):
        """Block other threads from retrieving the users until the end
        of the transaction."""
        self._src_user = User.objects.select_for_update().get(
            pk=self._src_user.pk)
        self._dst_user = User.objects.select_for_update().get(
            pk=self._dst_user.pk)

    def _yield_nested_objects(self, objects):
        """Given a list that contains objects and lists of potentially-
        nested objects, return a generator that recursively yields
        objects."""
        for obj in objects:
            if isinstance(obj, list):
                yield from self._yield_nested_objects(obj)
            else:
                yield obj
class ReportMessageStrategy(ABC):
    """An interface that uses the Strategy design pattern to vary how
    messages are reported."""

    def error(self, message):
        """Report the given message as an error."""
        raise NotImplementedError

    def success(self, message):
        """Report the given message as a success."""
        raise NotImplementedError

    def warning(self, message):
        """Report the given message as a warning."""
        raise NotImplementedError


class EnqueueForLoggingStrategy(ReportMessageStrategy):
    """A strategy for enqueueing messages to be written to a logging instance
    later."""

    def __init__(self, logger):
        """Store the logger that queued messages will be written to.

        Raises:
            - TypeError, if logger is not a logging.Logger.
        """
        if not isinstance(logger, logging.Logger):
            raise TypeError(
                f'Expected a logging.Logger, got {type(logger).__name__}.')
        self._logger = logger
        # Tuples of the form (logging_func, message).
        self._queue = []

    def error(self, message):
        self._queue.append((self._logger.error, message))

    def success(self, message):
        # Successes are logged at the INFO level.
        self._queue.append((self._logger.info, message))

    def warning(self, message):
        self._queue.append((self._logger.warning, message))

    def log_queued_messages(self):
        """Write every queued message to the logger, in the order in
        which it was enqueued. The queue itself is left unchanged."""
        for logging_func, message in self._queue:
            logging_func(message)


class PrintStrategy(ReportMessageStrategy):
    """A strategy for printing messages."""

    def error(self, message):
        print(colorize(text=message, fg="red"))

    def success(self, message):
        print(colorize(text=message, fg="green"))

    def warning(self, message):
        print(colorize(text=message, fg="yellow"))


class WriteViaCommandStrategy(ReportMessageStrategy):
    """A strategy for writing messages to stdout/stderr via a Django
    management command."""

    def __init__(self, command):
        """Store the command whose streams and styles will be used.

        Raises:
            - TypeError, if command is not a BaseCommand.
        """
        if not isinstance(command, BaseCommand):
            raise TypeError(
                f'Expected a BaseCommand, got {type(command).__name__}.')
        self._command = command

    def error(self, message):
        # Errors are the only messages written to stderr.
        stream = self._command.stderr
        style = self._command.style.ERROR
        stream.write(style(message))

    def success(self, message):
        stream = self._command.stdout
        style = self._command.style.SUCCESS
        stream.write(style(message))

    def warning(self, message):
        stream = self._command.stdout
        style = self._command.style.WARNING
        stream.write(style(message))