From 3806b5d75d3fa61bba43ec2695cadc0fbd6f5de9 Mon Sep 17 00:00:00 2001
From: Roy Wiggins
Date: Mon, 25 Nov 2024 20:58:24 +0000
Subject: [PATCH] use fakeredis in tests

remove deprecated Connection(redis)
---
 pytest.ini                              |   2 +
 tests/test_query.py                     |  23 ++---
 tests/test_router.py                    |   2 +-
 webinterface/dashboards/query/jobs.py   | 112 ++++++++++++-----------
 webinterface/dashboards/query_routes.py | 116 +++++++++++-------------
 5 files changed, 127 insertions(+), 128 deletions(-)
 create mode 100644 pytest.ini

diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..6a7d170
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_default_fixture_loop_scope = function
\ No newline at end of file
diff --git a/tests/test_query.py b/tests/test_query.py
index 0a2cef2..9c57445 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -32,8 +32,9 @@
 
 @pytest.fixture(scope="module", autouse=True)
 def rq_connection():
-    with Connection(redis):
-        yield redis
+    my_redis = FakeStrictRedis()
+    # with Connection(my_redis):
+    yield my_redis
 
 @pytest.fixture(scope="module")
 def mock_node(receiver_port):
@@ -200,7 +201,7 @@ def test_query_job(dicom_server, tempdir, rq_connection,fs):
         except subprocess.CalledProcessError:
             pass
     fs.resume()
-    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, str(tempdir))
+    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, str(tempdir), redis_server=rq_connection)
 
     w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
@@ -227,7 +228,7 @@ def test_query_job_to_mercure(dicom_server, tempdir, rq_connection, fs, mercure_
             ).dict(),
         }
     })
-    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, None, False, "rule_to_force")
+    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, None, False, "rule_to_force", rq_connection)
 
     w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
@@ -264,12 +265,12 @@ def tree(path, prefix='', level=0) -> None:
         if entry.is_dir():
             tree(entry.path, prefix + (' ' if i == len(entries) - 1 else '│ '), level+1)
 
-def test_query_dicomweb(dicomweb_server, tempdir, dummy_datasets, fs):
+def test_query_dicomweb(dicomweb_server, tempdir, dummy_datasets, fs, rq_connection):
     (tempdir / "outdir").mkdir()
     ds = list(dummy_datasets.values())[0]
-    task = QueryPipeline.create([ds.AccessionNumber], {}, dicomweb_server, (tempdir / "outdir"))
+    task = QueryPipeline.create([ds.AccessionNumber], {}, dicomweb_server, (tempdir / "outdir"), redis_server=rq_connection)
     assert task
-    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=redis)
+    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
     # tree(tempdir / "outdir")
     outfile = (tempdir / "outdir" / task.id / ds.AccessionNumber / f"{ds.SOPInstanceUID}.dcm")
@@ -280,7 +281,7 @@ def test_query_dicomweb(dicomweb_server, tempdir, dummy_datasets, fs):
 
 def test_query_operations(dicomweb_server, tempdir, dummy_datasets, fs, rq_connection):
     (tempdir / "outdir").mkdir()
-    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, dicomweb_server, (tempdir / "outdir"))
+    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, dicomweb_server, (tempdir / "outdir"), redis_server=rq_connection)
     assert task
     assert task.meta['total'] == len(dummy_datasets)
     assert task.meta['completed'] == 0
@@ -290,7 +291,7 @@ def test_query_operations(dicomweb_server, tempdir, dummy_datasets, fs, rq_conne
         assert job.get_status() == "canceled"
     assert jobs
 
-    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=redis)
+    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
     outfile = (tempdir / "outdir" / task.id)
     task.get_meta()
@@ -313,10 +314,10 @@ def test_query_operations(dicomweb_server, tempdir, dummy_datasets, fs, rq_conne
 
 def test_query_retry(dicom_server_2: Tuple[DicomTarget,DummyDICOMServer], tempdir, dummy_datasets, fs, rq_connection):
     (tempdir / "outdir").mkdir()
     target, server = dicom_server_2
-    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, target, (tempdir / "outdir"))
+    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, target, (tempdir / "outdir"), redis_server=rq_connection)
     server.remaining_allowed_accessions = 1 # Only one accession is allowed to be retrieved
-    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=redis)
+    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
     task.get_meta()
     assert task.meta['completed'] == 1
diff --git a/tests/test_router.py b/tests/test_router.py
index f3a4b47..eac64cf 100755
--- a/tests/test_router.py
+++ b/tests/test_router.py
@@ -156,7 +156,7 @@ def test_route_series_fail4(fs: FakeFilesystem, mercure_config, mocked):
         task_id,
         0,
         "",
-        f"Problem while pushing file to outgoing {series_uid}#baz\nSource folder {config.incoming_folder}/{series_uid}\nTarget folder {config.outgoing_folder}/{task_id}",
+        f"Problem while pushing file to outgoing [{series_uid}#baz]\nSource folder {config.incoming_folder}/{series_uid}\nTarget folder {config.outgoing_folder}/{task_id}",
     )
 
     assert list(Path(config.outgoing_folder).glob("**/*.dcm")) == []
diff --git a/webinterface/dashboards/query/jobs.py b/webinterface/dashboards/query/jobs.py
index bf85934..2e49d76 100644
--- a/webinterface/dashboards/query/jobs.py
+++ b/webinterface/dashboards/query/jobs.py
@@ -93,13 +93,13 @@ class ClassBasedRQTask():
     _queue: str = ''
 
     @classmethod
-    def queue(cls) -> Queue:
-        return Queue(cls._queue, connection=redis)
+    def queue(cls, connection=None) -> Queue:
+        return Queue(cls._queue, connection=(connection or redis))
 
-    def create_job(self, rq_options={}, **kwargs) -> Job:
+    def create_job(self, connection, rq_options={}, **kwargs) -> Job:
         fields = dataclasses.fields(self)
         meta = {field.name: getattr(self, field.name) for field in fields}
-        return Job.create(self._execute, kwargs=kwargs, meta=meta, **rq_options)
+        return Job.create(self._execute, connection=connection, kwargs=kwargs, meta=meta, **rq_options)
 
     @classmethod
     def _execute(cls, **kwargs) -> Any:
@@ -305,65 +305,67 @@ def execute(self, *, accessions, subjobs, path:str, destination: Optional[str],
 
 class QueryPipeline():
     job: Job
-    def __init__(self, job: Union[Job,str]):
+    connection: Connection
+    def __init__(self, job: Union[Job,str], connection=redis):
+        self.connection = connection
         if isinstance(job, str):
-            if not (result:=Job.fetch(job)):
+            if not (result:=Job.fetch(job,connection=self.connection)):
                 raise Exception("Invalid Job ID")
             self.job = result
         else:
            self.job = job
-
        assert self.job.meta.get('type') == 'batch', f"Job type must be batch, got {self.job.meta['type']}"
 
     @classmethod
-    def create(cls, accessions: List[str], search_filters:Dict[str, List[str]], dicom_node: Union[DicomWebTarget, DicomTarget], destination_path: Optional[str], offpeak:bool=False, force_rule:Optional[str]=None) -> 'QueryPipeline':
+    def create(cls, accessions: List[str], search_filters:Dict[str, List[str]], dicom_node: Union[DicomWebTarget, DicomTarget], destination_path: Optional[str], offpeak:bool=False, force_rule:Optional[str]=None, redis_server=None) -> 'QueryPipeline':
         """
         Create a job to process the given accessions and store them in the specified destination path.
         """
-
-        with Connection(redis):
-            get_accession_jobs: List[Job] = []
-            check_job = CheckAccessionsTask().create_job(accessions=accessions, search_filters=search_filters, node=dicom_node)
-            for accession in accessions:
-                get_accession_task = GetAccessionTask(offpeak=offpeak).create_job(
-                    accession=str(accession),
-                    node=dicom_node,
-                    force_rule=force_rule,
-                    search_filters=search_filters,
-                    rq_options=dict(
-                        depends_on=cast(List[Union[Dependency, Job]],[check_job]),
-                        timeout=30*60,
-                        result_ttl=-1
-                    )
-                )
-                get_accession_jobs.append(get_accession_task)
-            depends = Dependency(
-                jobs=cast(List[Union[Job,str]],get_accession_jobs),
-                allow_failure=True, # allow_failure defaults to False
-            )
-            main_job = MainTask(total=len(get_accession_jobs), offpeak=offpeak).create_job(
-                accessions = accessions,
-                subjobs = [check_job.id]+[j.id for j in get_accession_jobs],
-                destination = destination_path,
+        connection = redis_server or redis
+        get_accession_jobs: List[Job] = []
+        check_job = CheckAccessionsTask().create_job(connection,accessions=accessions, search_filters=search_filters, node=dicom_node)
+        for accession in accessions:
+            get_accession_task = GetAccessionTask(offpeak=offpeak).create_job(
+                connection,
+                accession=str(accession),
                 node=dicom_node,
-                move_promptly = True,
-                rq_options = dict(depends_on=depends, timeout=-1, result_ttl=-1),
-                force_rule=force_rule
-            )
-            check_job.meta["parent"] = main_job.id
-            for j in get_accession_jobs:
-                j.meta["parent"] = main_job.id
-                j.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id) / j.kwargs['accession']
-                j.kwargs["path"].mkdir(parents=True)
-
-            main_job.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id)
-
-            CheckAccessionsTask.queue().enqueue_job(check_job)
+                force_rule=force_rule,
+                search_filters=search_filters,
+                rq_options=dict(
+                    depends_on=cast(List[Union[Dependency, Job]],[check_job]),
+                    timeout=30*60,
+                    result_ttl=-1
+                )
+            )
+            get_accession_jobs.append(get_accession_task)
+        depends = Dependency(
+            jobs=cast(List[Union[Job,str]],get_accession_jobs),
+            allow_failure=True, # allow_failure defaults to False
+        )
+        main_job = MainTask(total=len(get_accession_jobs), offpeak=offpeak).create_job(
+            connection,
+            accessions = accessions,
+            subjobs = [check_job.id]+[j.id for j in get_accession_jobs],
+            destination = destination_path,
+            node=dicom_node,
+            move_promptly = True,
+            rq_options = dict(depends_on=depends, timeout=-1, result_ttl=-1),
+            force_rule=force_rule
+        )
+        check_job.meta["parent"] = main_job.id
        for j in get_accession_jobs:
-            GetAccessionTask.queue().enqueue_job(j)
-            MainTask.queue().enqueue_job(main_job)
+            j.meta["parent"] = main_job.id
+            j.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id) / j.kwargs['accession']
+            j.kwargs["path"].mkdir(parents=True)
 
-        wrapped_job = cls(main_job)
+        main_job.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id)
+
+        CheckAccessionsTask.queue(connection).enqueue_job(check_job)
+        for j in get_accession_jobs:
+            GetAccessionTask.queue(connection).enqueue_job(j)
+        MainTask.queue(connection).enqueue_job(main_job)
+
+        wrapped_job = cls(main_job, connection)
 
         if offpeak and not helper._is_offpeak(config.mercure.offpeak_start, config.mercure.offpeak_end, datetime.now().time()):
             wrapped_job.pause()
@@ -377,7 +379,7 @@ def pause(self) -> None:
         Pause the current job, including all its subjobs.
         """
         for job_id in self.job.kwargs.get('subjobs',[]):
-            subjob = Job.fetch(job_id)
+            subjob = Job.fetch(job_id, connection=self.connection)
             if subjob and (subjob.is_deferred or subjob.is_queued):
                 logger.debug(f"Pausing {subjob}")
                 subjob.meta['paused'] = True
@@ -392,11 +394,11 @@ def resume(self) -> None:
         Resume a paused job by unpausing all its subjobs
         """
         for subjob_id in self.job.kwargs.get('subjobs',[]):
-            subjob = Job.fetch(subjob_id)
+            subjob = Job.fetch(subjob_id, connection=self.connection)
             if subjob and subjob.meta.get('paused', None):
                 subjob.meta['paused'] = False
                 subjob.save_meta() # type: ignore
-                Queue(subjob.origin).canceled_job_registry.requeue(subjob_id)
+                Queue(subjob.origin, connection=self.connection).canceled_job_registry.requeue(subjob_id)
         self.job.get_meta()
         self.job.meta['paused'] = False
         self.job.save_meta() # type: ignore
@@ -415,8 +417,8 @@ def retry(self) -> None:
             logger.info(f"Retrying {subjob} ({status}) {meta}")
             if status == "failed" and (job_path:=Path(subjob.kwargs['path'])).exists():
                 shutil.rmtree(job_path) # Clean up after a failed job
-            Queue(subjob.origin).enqueue_job(subjob)
-        Queue(self.job.origin).enqueue_job(self.job)
+            Queue(subjob.origin, connection=self.connection).enqueue_job(subjob)
+        Queue(self.job.origin, connection=self.connection).enqueue_job(self.job)
 
     @classmethod
     def update_all_offpeak(cls) -> None:
@@ -445,7 +447,7 @@ def update_offpeak(self, is_offpeak) -> None:
             self.pause()
 
     def get_subjobs(self) -> Generator[Job, None, None]:
-        return (j for j in (Queue(self.job.origin).fetch_job(job) for job in self.job.kwargs.get('subjobs', [])) if j is not None)
+        return (j for j in (Queue(self.job.origin, connection=self.connection).fetch_job(job) for job in self.job.kwargs.get('subjobs', [])) if j is not None)
 
     def get_status(self) -> JobStatus:
         return cast(JobStatus,self.job.get_status())
diff --git a/webinterface/dashboards/query_routes.py b/webinterface/dashboards/query_routes.py
index 0c29f6c..332ac6f 100644
--- a/webinterface/dashboards/query_routes.py
+++ b/webinterface/dashboards/query_routes.py
@@ -25,81 +25,77 @@
 @router.post("/query/retry_job")
 @requires(["authenticated", "admin"], redirect="login")
 async def post_retry_job(request):
-    with Connection(redis):
-        job = QueryPipeline(request.query_params['id'])
-
-        if not job:
-            return JSONErrorResponse(f"Job with id {request.query_params['id']} not found.", status_code=404)
-
-        try:
-            job.retry()
-        except Exception as e:
-            logger.exception("Failed to retry job", exc_info=True)
-            return JSONErrorResponse("Failed to retry job",status_code=500)
+    job = QueryPipeline(request.query_params['id'])
+
+    if not job:
+        return JSONErrorResponse(f"Job with id {request.query_params['id']} not found.", status_code=404)
+
+    try:
+        job.retry()
+    except Exception as e:
+        logger.exception("Failed to retry job", exc_info=True)
+        return JSONErrorResponse("Failed to retry job",status_code=500)
     return JSONResponse({})
 
 @router.post("/query/pause_job")
 @requires(["authenticated", "admin"], redirect="login")
 async def post_pause_job(request):
-    with Connection(redis):
-        job = QueryPipeline(request.query_params['id'])
+    job = QueryPipeline(request.query_params['id'])
 
-        if not job:
-            return JSONErrorResponse('Job not found', status_code=404)
-        if job.is_finished or job.is_failed:
-            return JSONErrorResponse('Job is already finished', status_code=400)
+    if not job:
+        return JSONErrorResponse('Job not found', status_code=404)
+    if job.is_finished or job.is_failed:
+        return JSONErrorResponse('Job is already finished', status_code=400)
 
-        try:
-            job.pause()
-        except Exception as e:
-            logger.exception(f"Failed to pause job {request.query_params['id']}")
-            return JSONErrorResponse('Failed to pause job', status_code=500)
+    try:
+        job.pause()
+    except Exception as e:
+        logger.exception(f"Failed to pause job {request.query_params['id']}")
+        return JSONErrorResponse('Failed to pause job', status_code=500)
     return JSONResponse({'status': 'success'}, status_code=200)
 
 @router.post("/query/resume_job")
 @requires(["authenticated", "admin"], redirect="login")
 async def post_resume_job(request):
-    with Connection(redis):
-        job = QueryPipeline(request.query_params['id'])
-        if not job:
-            return JSONErrorResponse('Job not found', status_code=404)
-        if job.is_finished or job.is_failed:
-            return JSONErrorResponse('Job is already finished', status_code=400)
+    job = QueryPipeline(request.query_params['id'])
+    if not job:
+        return JSONErrorResponse('Job not found', status_code=404)
+    if job.is_finished or job.is_failed:
+        return JSONErrorResponse('Job is already finished', status_code=400)
 
-        try:
-            job.resume()
-        except Exception as e:
-            logger.exception(f"Failed to resume job {request.query_params['id']}")
-            return JSONErrorResponse('Failed to resume job', status_code=500)
+    try:
+        job.resume()
+    except Exception as e:
+        logger.exception(f"Failed to resume job {request.query_params['id']}")
+        return JSONErrorResponse('Failed to resume job', status_code=500)
     return JSONResponse({'status': 'success'}, status_code=200)
 
 @router.get("/query/job_info")
 @requires(["authenticated", "admin"], redirect="login")
 async def get_job_info(request):
     job_id = request.query_params['id']
-    with Connection(redis):
-        job = QueryPipeline(job_id)
-        if not job:
-            return JSONErrorResponse('Job not found', status_code=404)
-
-        subjob_info:List[Dict[str,Any]] = []
-        for subjob in job.get_subjobs():
-            if not subjob:
-                continue
-            if subjob.meta.get('type') != 'get_accession':
-                continue
-            info = {
-                'id': subjob.get_id(),
-                'ended_at': subjob.ended_at.isoformat().split('.')[0] if subjob.ended_at else "",
-                'created_at_dt':subjob.created_at,
-                'accession': subjob.kwargs['accession'],
-                'progress': subjob.meta.get('progress'),
-                'paused': subjob.meta.get('paused',False),
-                'status': subjob.get_status()
-            }
-            if info['status'] == 'canceled' and info['paused']:
-                info['status'] = 'paused'
-            subjob_info.append(info)
+    job = QueryPipeline(job_id)
+    if not job:
+        return JSONErrorResponse('Job not found', status_code=404)
+
+    subjob_info:List[Dict[str,Any]] = []
+    for subjob in job.get_subjobs():
+        if not subjob:
+            continue
+        if subjob.meta.get('type') != 'get_accession':
+            continue
+        info = {
+            'id': subjob.get_id(),
+            'ended_at': subjob.ended_at.isoformat().split('.')[0] if subjob.ended_at else "",
+            'created_at_dt':subjob.created_at,
+            'accession': subjob.kwargs['accession'],
+            'progress': subjob.meta.get('progress'),
+            'paused': subjob.meta.get('paused',False),
+            'status': subjob.get_status()
+        }
+        if info['status'] == 'canceled' and info['paused']:
+            info['status'] = 'paused'
+        subjob_info.append(info)
 
     subjob_info = sorted(subjob_info, key=lambda x:x['created_at_dt'])
 
@@ -164,8 +160,7 @@ async def query_jobs(request):
     """
     tasks_info = []
     try:
-        with Connection(redis):
-            query_tasks = list(QueryPipeline.get_all())
+        query_tasks = list(QueryPipeline.get_all())
     except Exception as e:
         logger.exception("Error retrieving query tasks.")
         return JSONErrorResponse("Error retrieving query tasks.", status_code=500)
@@ -272,9 +267,8 @@ async def check_accessions(request):
         return JSONErrorResponse(f"Invalid DICOM node '{node_name}'.", status_code=400)
 
     try:
-        with Connection(redis):
-            job = CheckAccessionsTask().create_job(accessions=accessions, node=node, search_filters=search_filters)
-            CheckAccessionsTask.queue().enqueue_job(job)
+        job = CheckAccessionsTask().create_job(connection=redis,accessions=accessions, node=node, search_filters=search_filters)
+        CheckAccessionsTask.queue(redis).enqueue_job(job)
     except Exception as e:
         logger.exception("Error during accessions check task creation")
         return JSONErrorResponse(str(e), status_code=500)
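
For reference, a minimal sketch (not part of the patch) of the explicit connection-passing pattern the changes above adopt, using fakeredis and rq directly; the queue name "demo_queue" and the print job are illustrative placeholders:

    # Sketch only: FakeStrictRedis stands in for a real Redis server, and every
    # Queue and Job receives its connection explicitly instead of relying on the
    # removed `with Connection(redis):` context manager.
    from fakeredis import FakeStrictRedis
    from rq import Queue, SimpleWorker
    from rq.job import Job

    def demo_explicit_connection() -> None:
        my_redis = FakeStrictRedis()  # in-memory substitute; no Redis server needed

        queue = Queue("demo_queue", connection=my_redis)
        job = Job.create(print, args=("hello",), connection=my_redis)
        queue.enqueue_job(job)

        # SimpleWorker executes jobs in-process, which is what makes fakeredis
        # usable inside a test run.
        SimpleWorker([queue], connection=my_redis).work(burst=True)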