
use fakeredis in tests
remove deprecated Connection(redis)
Roy Wiggins authored and Roy Wiggins committed Nov 25, 2024
1 parent 154f358 commit 3806b5d
Showing 5 changed files with 127 additions and 128 deletions.
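
Context for the change: newer RQ releases removed the Connection context manager, so queues, jobs, and workers now take an explicit connection argument. The sketch below shows that pattern with fakeredis standing in for a live Redis server; the queue name "example" and the add function are illustrative, not from the repository.

from fakeredis import FakeStrictRedis
from rq import Queue, SimpleWorker

def add(x, y):
    return x + y

# In-memory stand-in for Redis; no redis-server process is needed.
connection = FakeStrictRedis()

# Formerly: `with Connection(redis): queue = Queue("example")`.
# Now every Queue, Job, and Worker is bound to an explicit connection.
queue = Queue("example", connection=connection)
job = queue.enqueue(add, 2, 3)

# SimpleWorker runs jobs in the current process (no forking), which suits tests.
SimpleWorker(["example"], connection=connection).work(burst=True)
assert job.get_status(refresh=True) == "finished"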
2 changes: 2 additions & 0 deletions pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_default_fixture_loop_scope = function
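
The new pytest.ini entry targets pytest-asyncio, which warns (as of v0.24) when the default event-loop scope for async fixtures is left unconfigured; pinning it to "function" makes the existing default explicit. A minimal illustration of what the option governs, with a hypothetical fixture and test:

import asyncio

import pytest
import pytest_asyncio

@pytest_asyncio.fixture  # gets a fresh event loop per test under "function" scope
async def client():
    await asyncio.sleep(0)  # stand-in for async setup
    return "connected"

@pytest.mark.asyncio
async def test_client(client):
    assert client == "connected"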
23 changes: 12 additions & 11 deletions tests/test_query.py
@@ -32,8 +32,9 @@
 
 @pytest.fixture(scope="module", autouse=True)
 def rq_connection():
-    with Connection(redis):
-        yield redis
+    my_redis = FakeStrictRedis()
+    # with Connection(my_redis):
+    yield my_redis
 
 @pytest.fixture(scope="module")
 def mock_node(receiver_port):
@@ -200,7 +201,7 @@ def test_query_job(dicom_server, tempdir, rq_connection,fs):
     except subprocess.CalledProcessError:
         pass
     fs.resume()
-    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, str(tempdir))
+    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, str(tempdir), redis_server=rq_connection)
     w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
 
     w.work(burst=True)
@@ -227,7 +228,7 @@ def test_query_job_to_mercure(dicom_server, tempdir, rq_connection, fs, mercure_
             ).dict(),
         }
     })
-    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, None, False, "rule_to_force")
+    job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, None, False, "rule_to_force", rq_connection)
     w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
 
     w.work(burst=True)
@@ -264,12 +265,12 @@ def tree(path, prefix='', level=0) -> None:
         if entry.is_dir():
             tree(entry.path, prefix + ('    ' if i == len(entries) - 1 else '│   '), level+1)
 
-def test_query_dicomweb(dicomweb_server, tempdir, dummy_datasets, fs):
+def test_query_dicomweb(dicomweb_server, tempdir, dummy_datasets, fs, rq_connection):
     (tempdir / "outdir").mkdir()
     ds = list(dummy_datasets.values())[0]
-    task = QueryPipeline.create([ds.AccessionNumber], {}, dicomweb_server, (tempdir / "outdir"))
+    task = QueryPipeline.create([ds.AccessionNumber], {}, dicomweb_server, (tempdir / "outdir"), redis_server=rq_connection)
     assert task
-    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=redis)
+    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
     # tree(tempdir / "outdir")
     outfile = (tempdir / "outdir" / task.id / ds.AccessionNumber / f"{ds.SOPInstanceUID}.dcm")
@@ -280,7 +281,7 @@ def test_query_dicomweb(dicomweb_server, tempdir, dummy_datasets, fs):
 
 def test_query_operations(dicomweb_server, tempdir, dummy_datasets, fs, rq_connection):
     (tempdir / "outdir").mkdir()
-    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, dicomweb_server, (tempdir / "outdir"))
+    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, dicomweb_server, (tempdir / "outdir"), redis_server=rq_connection)
     assert task
     assert task.meta['total'] == len(dummy_datasets)
     assert task.meta['completed'] == 0
@@ -290,7 +291,7 @@ def test_query_operations(dicomweb_server, tempdir, dummy_datasets, fs, rq_connection):
         assert job.get_status() == "canceled"
     assert jobs
 
-    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=redis)
+    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
     outfile = (tempdir / "outdir" / task.id)
     task.get_meta()
@@ -313,10 +314,10 @@ def test_query_operations(dicomweb_server, tempdir, dummy_datasets, fs, rq_connection):
 def test_query_retry(dicom_server_2: Tuple[DicomTarget,DummyDICOMServer], tempdir, dummy_datasets, fs, rq_connection):
     (tempdir / "outdir").mkdir()
     target, server = dicom_server_2
-    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, target, (tempdir / "outdir"))
+    task = QueryPipeline.create([ds.AccessionNumber for ds in dummy_datasets.values()], {}, target, (tempdir / "outdir"), redis_server=rq_connection)
 
     server.remaining_allowed_accessions = 1 # Only one accession is allowed to be retrieved
-    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=redis)
+    w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection)
     w.work(burst=True)
     task.get_meta()
     assert task.meta['completed'] == 1
2 changes: 1 addition & 1 deletion tests/test_router.py
@@ -156,7 +156,7 @@ def test_route_series_fail4(fs: FakeFilesystem, mercure_config, mocked):
         task_id,
         0,
         "",
-        f"Problem while pushing file to outgoing {series_uid}#baz\nSource folder {config.incoming_folder}/{series_uid}\nTarget folder {config.outgoing_folder}/{task_id}",
+        f"Problem while pushing file to outgoing [{series_uid}#baz]\nSource folder {config.incoming_folder}/{series_uid}\nTarget folder {config.outgoing_folder}/{task_id}",
     )
     assert list(Path(config.outgoing_folder).glob("**/*.dcm")) == []
 
112 changes: 57 additions & 55 deletions webinterface/dashboards/query/jobs.py
@@ -93,13 +93,13 @@ class ClassBasedRQTask():
     _queue: str = ''
 
     @classmethod
-    def queue(cls) -> Queue:
-        return Queue(cls._queue, connection=redis)
+    def queue(cls, connection=None) -> Queue:
+        return Queue(cls._queue, connection=(connection or redis))
 
-    def create_job(self, rq_options={}, **kwargs) -> Job:
+    def create_job(self, connection, rq_options={}, **kwargs) -> Job:
         fields = dataclasses.fields(self)
         meta = {field.name: getattr(self, field.name) for field in fields}
-        return Job.create(self._execute, kwargs=kwargs, meta=meta, **rq_options)
+        return Job.create(self._execute, connection=connection, kwargs=kwargs, meta=meta, **rq_options)
 
     @classmethod
     def _execute(cls, **kwargs) -> Any:
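
With these signatures, the Redis connection flows in from the caller instead of being picked up from an ambient Connection block. A rough usage sketch of the new calling convention; the ExampleTask dataclass, its field, and the argument values are hypothetical, not from the repository:

import dataclasses

from fakeredis import FakeStrictRedis

@dataclasses.dataclass
class ExampleTask(ClassBasedRQTask):
    _queue: str = "mercure_fast"
    offpeak: bool = False  # dataclass fields are copied into job.meta by create_job

connection = FakeStrictRedis()  # hypothetical test connection
job = ExampleTask(offpeak=True).create_job(connection, rq_options=dict(timeout=60), accession="ACC123")
ExampleTask.queue(connection).enqueue_job(job)  # Queue bound to the same connection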
@@ -305,65 +305,67 @@ def execute(self, *, accessions, subjobs, path:str, destination: Optional[str],
 
 class QueryPipeline():
     job: Job
-    def __init__(self, job: Union[Job,str]):
+    connection: Connection
+    def __init__(self, job: Union[Job,str], connection=redis):
+        self.connection = connection
         if isinstance(job, str):
-            if not (result:=Job.fetch(job)):
+            if not (result:=Job.fetch(job,connection=self.connection)):
                 raise Exception("Invalid Job ID")
             self.job = result
         else:
             self.job = job
 
         assert self.job.meta.get('type') == 'batch', f"Job type must be batch, got {self.job.meta['type']}"
 
     @classmethod
-    def create(cls, accessions: List[str], search_filters:Dict[str, List[str]], dicom_node: Union[DicomWebTarget, DicomTarget], destination_path: Optional[str], offpeak:bool=False, force_rule:Optional[str]=None) -> 'QueryPipeline':
+    def create(cls, accessions: List[str], search_filters:Dict[str, List[str]], dicom_node: Union[DicomWebTarget, DicomTarget], destination_path: Optional[str], offpeak:bool=False, force_rule:Optional[str]=None, redis_server=None) -> 'QueryPipeline':
         """
         Create a job to process the given accessions and store them in the specified destination path.
         """
-        with Connection(redis):
-            get_accession_jobs: List[Job] = []
-            check_job = CheckAccessionsTask().create_job(accessions=accessions, search_filters=search_filters, node=dicom_node)
-            for accession in accessions:
-                get_accession_task = GetAccessionTask(offpeak=offpeak).create_job(
-                    accession=str(accession),
-                    node=dicom_node,
-                    force_rule=force_rule,
-                    search_filters=search_filters,
-                    rq_options=dict(
-                        depends_on=cast(List[Union[Dependency, Job]],[check_job]),
-                        timeout=30*60,
-                        result_ttl=-1
-                    )
-                )
-                get_accession_jobs.append(get_accession_task)
-            depends = Dependency(
-                jobs=cast(List[Union[Job,str]],get_accession_jobs),
-                allow_failure=True,    # allow_failure defaults to False
-            )
-            main_job = MainTask(total=len(get_accession_jobs), offpeak=offpeak).create_job(
-                accessions = accessions,
-                subjobs = [check_job.id]+[j.id for j in get_accession_jobs],
-                destination = destination_path,
-                node=dicom_node,
-                move_promptly = True,
-                rq_options = dict(depends_on=depends, timeout=-1, result_ttl=-1),
-                force_rule=force_rule
-            )
-            check_job.meta["parent"] = main_job.id
-            for j in get_accession_jobs:
-                j.meta["parent"] = main_job.id
-                j.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id) / j.kwargs['accession']
-                j.kwargs["path"].mkdir(parents=True)
-
-            main_job.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id)
-
-            CheckAccessionsTask.queue().enqueue_job(check_job)
-            for j in get_accession_jobs:
-                GetAccessionTask.queue().enqueue_job(j)
-            MainTask.queue().enqueue_job(main_job)
-
-        wrapped_job = cls(main_job)
+        connection = redis_server or redis
+        get_accession_jobs: List[Job] = []
+        check_job = CheckAccessionsTask().create_job(connection,accessions=accessions, search_filters=search_filters, node=dicom_node)
+        for accession in accessions:
+            get_accession_task = GetAccessionTask(offpeak=offpeak).create_job(
+                connection,
+                accession=str(accession),
+                node=dicom_node,
+                force_rule=force_rule,
+                search_filters=search_filters,
+                rq_options=dict(
+                    depends_on=cast(List[Union[Dependency, Job]],[check_job]),
+                    timeout=30*60,
+                    result_ttl=-1
+                )
+            )
+            get_accession_jobs.append(get_accession_task)
+        depends = Dependency(
+            jobs=cast(List[Union[Job,str]],get_accession_jobs),
+            allow_failure=True,    # allow_failure defaults to False
+        )
+        main_job = MainTask(total=len(get_accession_jobs), offpeak=offpeak).create_job(
+            connection,
+            accessions = accessions,
+            subjobs = [check_job.id]+[j.id for j in get_accession_jobs],
+            destination = destination_path,
+            node=dicom_node,
+            move_promptly = True,
+            rq_options = dict(depends_on=depends, timeout=-1, result_ttl=-1),
+            force_rule=force_rule
+        )
+        check_job.meta["parent"] = main_job.id
+        for j in get_accession_jobs:
+            j.meta["parent"] = main_job.id
+            j.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id) / j.kwargs['accession']
+            j.kwargs["path"].mkdir(parents=True)
+
+        main_job.kwargs["path"] = Path(config.mercure.jobs_folder) / str(main_job.id)
+
+        CheckAccessionsTask.queue(connection).enqueue_job(check_job)
+        for j in get_accession_jobs:
+            GetAccessionTask.queue(connection).enqueue_job(j)
+        MainTask.queue(connection).enqueue_job(main_job)
+
+        wrapped_job = cls(main_job, connection)
         if offpeak and not helper._is_offpeak(config.mercure.offpeak_start, config.mercure.offpeak_end, datetime.now().time()):
             wrapped_job.pause()
 
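With the connection threaded through like this, the whole pipeline can run end to end against fakeredis, roughly as the updated tests above do; the accession number, node, and destination path here are placeholders:

from fakeredis import FakeStrictRedis
from rq import SimpleWorker

fake = FakeStrictRedis()
pipeline = QueryPipeline.create(
    ["ACC123"],         # accessions (placeholder)
    {},                 # no extra search filters
    dicom_node,         # a DicomTarget or DicomWebTarget instance
    "/tmp/outdir",      # destination path (placeholder)
    redis_server=fake,  # every Job and Queue created above binds to this
)
SimpleWorker(["mercure_fast", "mercure_slow"], connection=fake).work(burst=True)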
@@ -377,7 +379,7 @@ def pause(self) -> None:
         Pause the current job, including all its subjobs.
         """
         for job_id in self.job.kwargs.get('subjobs',[]):
-            subjob = Job.fetch(job_id)
+            subjob = Job.fetch(job_id, connection=self.connection)
             if subjob and (subjob.is_deferred or subjob.is_queued):
                 logger.debug(f"Pausing {subjob}")
                 subjob.meta['paused'] = True
@@ -392,11 +394,11 @@ def resume(self) -> None:
         Resume a paused job by unpausing all its subjobs
         """
         for subjob_id in self.job.kwargs.get('subjobs',[]):
-            subjob = Job.fetch(subjob_id)
+            subjob = Job.fetch(subjob_id, connection=self.connection)
             if subjob and subjob.meta.get('paused', None):
                 subjob.meta['paused'] = False
                 subjob.save_meta() # type: ignore
-                Queue(subjob.origin).canceled_job_registry.requeue(subjob_id)
+                Queue(subjob.origin, connection=self.connection).canceled_job_registry.requeue(subjob_id)
         self.job.get_meta()
         self.job.meta['paused'] = False
         self.job.save_meta() # type: ignore
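
For reference, pause/resume rides on RQ's cancel machinery: a canceled job sits in its queue's CanceledJobRegistry until requeue puts it back, which is what resume does above. A minimal round trip, again assuming a fakeredis connection; the queue name and noop payload are illustrative:

from fakeredis import FakeStrictRedis
from rq import Queue

def noop():
    return None

conn = FakeStrictRedis()
q = Queue("example", connection=conn)
job = q.enqueue(noop)

job.cancel()  # moves the job into q.canceled_job_registry
assert job.get_status(refresh=True) == "canceled"

q.canceled_job_registry.requeue(job.id)  # back onto the queue
assert job.get_status(refresh=True) == "queued"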
@@ -415,8 +417,8 @@ def retry(self) -> None:
             logger.info(f"Retrying {subjob} ({status}) {meta}")
             if status == "failed" and (job_path:=Path(subjob.kwargs['path'])).exists():
                 shutil.rmtree(job_path) # Clean up after a failed job
-            Queue(subjob.origin).enqueue_job(subjob)
-        Queue(self.job.origin).enqueue_job(self.job)
+            Queue(subjob.origin, connection=self.connection).enqueue_job(subjob)
+        Queue(self.job.origin, connection=self.connection).enqueue_job(self.job)
 
     @classmethod
     def update_all_offpeak(cls) -> None:
@@ -445,7 +447,7 @@ def update_offpeak(self, is_offpeak) -> None:
             self.pause()
 
     def get_subjobs(self) -> Generator[Job, None, None]:
-        return (j for j in (Queue(self.job.origin).fetch_job(job) for job in self.job.kwargs.get('subjobs', [])) if j is not None)
+        return (j for j in (Queue(self.job.origin, connection=self.connection).fetch_job(job) for job in self.job.kwargs.get('subjobs', [])) if j is not None)
 
     def get_status(self) -> JobStatus:
         return cast(JobStatus,self.job.get_status())