From a8d23077e96a734db19e2e368a6ca8a691f3df5c Mon Sep 17 00:00:00 2001
From: EddieLF
Date: Thu, 29 May 2025 10:42:54 +1000
Subject: [PATCH] Find jobs automatically from Hail Batch and abort them
 through Cromwell

---
 scripts/cromwell_status_parser.py | 131 +++++++++++++++++++++++++++---
 1 file changed, 120 insertions(+), 11 deletions(-)

diff --git a/scripts/cromwell_status_parser.py b/scripts/cromwell_status_parser.py
index 7739b75c..3d56d340 100644
--- a/scripts/cromwell_status_parser.py
+++ b/scripts/cromwell_status_parser.py
@@ -7,11 +7,100 @@
 from google.cloud import storage
 
 from cpg_utils import to_path
+from cpg_utils.config import config_retrieve
 from cpg_utils.constants import CROMWELL_URL
 from cpg_utils.cromwell import get_cromwell_oauth_token
 from metamist.apis import AnalysisApi, ParticipantApi
 from metamist.models import Analysis, AnalysisStatus
 
+batch_url = (
+    'https://batch.hail.populationgenomics.org.au/api/v1alpha/batches/{batch_id}'
+)
+job_url = 'https://batch.hail.populationgenomics.org.au/api/v1alpha/batches/{batch_id}/jobs/{job_id}'
+job_log_url = 'https://batch.hail.populationgenomics.org.au/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log'
+
+
+def get_hail_batch_from_id(batch_id: int, headers: dict | None = None):
+    """
+    Get a Hail Batch from its ID.
+    """
+    url = batch_url.format(batch_id=batch_id)
+    try:
+        response = requests.get(url, headers=headers, timeout=10)
+        response.raise_for_status()
+        return response.json()
+    except requests.RequestException as e:
+        print(f'Error fetching Hail Batch: {e!r}')
+        sys.exit(1)
+
+
+def get_watch_jobs(batch_id: int, n_jobs: int, headers: dict | None = None):
+    """
+    Get the Cromwell watch jobs from a Batch.
+    """
+    watch_jobs = []
+    for job_id in range(1, n_jobs + 1):
+        url = job_url.format(batch_id=batch_id, job_id=job_id)
+        try:
+            response = requests.get(url, headers=headers, timeout=10)
+            response.raise_for_status()
+            job = response.json()
+            if not isinstance(job, dict):
+                print(f'Expected a job dict, got {type(job)}')
+                sys.exit(1)
+            if job.get('name', '').endswith('_watch') and job['state'] != 'Success':
+                # Keep only unsuccessful watch jobs; names look like '<dataset>/<sg_id>:..._watch'
+                dataset, sg_id = job['name'].split(':')[0].split('/')
+                watch_jobs.append((job_id, dataset, sg_id))
+        except requests.RequestException as e:
+            print(f'Error fetching job {job_id}: {e!r}')
+            sys.exit(1)
+
+    return watch_jobs
+
+
+def get_workflow_id_from_watch_job(
+    batch_id: int,
+    job_id: int,
+    headers: dict | None = None,
+):
+    """
+    Get the Cromwell workflow ID from a watch job log.
+    """
+    url = job_log_url.format(batch_id=batch_id, job_id=job_id)
+    try:
+        response = requests.get(url, headers=headers, timeout=10)
+        response.raise_for_status()
+        job_log = response.json()
+        workflow_id_start = job_log['main'].find('Received workflow ID: ')
+        if workflow_id_start != -1:
+            workflow_id_start += len('Received workflow ID: ')
+            # Cromwell workflow IDs are hyphenated UUIDs, exactly 36 characters long
+            workflow_id = job_log['main'][workflow_id_start : workflow_id_start + 36]
+            print(f'Job {job_id} Workflow ID: {workflow_id}')
+            return workflow_id
+        return None
+
+    except requests.RequestException as e:
+        print(f'Error fetching job log: {e!r}')
+        sys.exit(1)
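+
+
+# Worked example of the log scan above, on hypothetical log content: if
+# job_log['main'] contains
+#     '... Received workflow ID: 123e4567-e89b-12d3-a456-426614174000 ...'
+# then find() locates the marker, the offset is advanced past it, and the
+# 36-character slice yields '123e4567-e89b-12d3-a456-426614174000'.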
+ """ + url = f'{CROMWELL_URL}/api/workflows/v1/{workflow_id}/abort' + headers = { + 'accept': 'application/json', + 'Authorization': f'Bearer {get_cromwell_oauth_token()}', + } + try: + response = requests.post(url, headers=headers, timeout=10) + response.raise_for_status() + print(f'Workflow {workflow_id} aborted successfully.') + except requests.RequestException as e: + print(f'Error aborting workflow {workflow_id}: {e!r}') + def get_workflow_metadata_from_file(workflow_metadata_file_path: str): try: @@ -306,15 +395,9 @@ async def create_sv_analyses_async(sg_datasets: dict, sg_analyses: dict): @click.command() -@click.option('--dataset', required=True, help='Dataset name(s)', multiple=True) -@click.option( - '--workflow-id', - required=True, - help='Cromwell workflow ID', - multiple=True, -) +@click.option('--batch-id', required=True, help='ID of the Batch containg GATK SV jobs') @click.option('--dry-run', is_flag=True, help='Dry run mode') -def main(dataset: list[str], workflow_id: list[str], dry_run: bool = False): +def main(batch_id: int, dry_run: bool = False): """ A script to parse Cromwell workflow metadata and collect any successful SV outputs. Copies the outputs to the dataset's main bucket and creates analyses for the successful sub-workflows. @@ -322,13 +405,39 @@ def main(dataset: list[str], workflow_id: list[str], dry_run: bool = False): The sub-workflows are identified by the GatherSampleEvidence prefix in their names: GatherSampleEvidence.scramble, GatherSampleEvidence.wham, GatherSampleEvidence.manta. """ - sg_peid_map = get_sgid_peid_map(dataset) + batch_auth_token = config_retrieve(['batch', 'auth_token']) + batch_headers = {'Authorization': f'Bearer {batch_auth_token}'} + batch_data = get_hail_batch_from_id(batch_id, headers=batch_headers) + print(f'Batch {batch_id} found: {batch_data["name"]}, {batch_data["n_jobs"]} jobs') + watch_jobs = get_watch_jobs( + batch_id, + batch_data['n_jobs'], + headers=batch_headers, + ) + print(f'Found {len(watch_jobs)} unsuccessful watch jobs.') + print('Unsuccessful watch jobs:', [job_id for job_id, _, _ in watch_jobs]) + + wf_id_info = [] + for job_id, dataset, sg_id in watch_jobs: + wf_id = get_workflow_id_from_watch_job(batch_id, job_id, headers=batch_headers) + if wf_id: + wf_id_info.append((dataset, sg_id, wf_id)) + else: + print(f'No workflow ID found for job {job_id} in batch {batch_id}') + + for dataset, sg_id, workflow_id in wf_id_info: + print(f'Dataset: {dataset}, SG ID: {sg_id}, Workflow ID: {workflow_id}') + + wf_ids = [wf_id for _, _, wf_id in wf_id_info] + datasets = {dataset for dataset, _, _ in wf_id_info} + sg_peid_map = get_sgid_peid_map(list(datasets)) sg_analyses_sizes = {} sg_datasets = {} sg_analyses = {} - print(f'Parsing {len(workflow_id)} workflows...') - for wf_id in workflow_id: + print(f'Aborting and parsing {len(wf_ids)} workflows...') + for wf_id in wf_ids: + abort_workflow(wf_id) json_data = get_workflow_metadata_from_api(wf_id) workflow_results = parse_workflow_status_and_outputs(wf_id, json_data)