131 changes: 120 additions & 11 deletions scripts/cromwell_status_parser.py
@@ -7,11 +7,100 @@
from google.cloud import storage

from cpg_utils import to_path
from cpg_utils.config import config_retrieve
from cpg_utils.constants import CROMWELL_URL
from cpg_utils.cromwell import get_cromwell_oauth_token
from metamist.apis import AnalysisApi, ParticipantApi
from metamist.models import Analysis, AnalysisStatus

batch_url = (
'https://batch.hail.populationgenomics.org.au/api/v1alpha/batches/{batch_id}'
)
job_url = 'https://batch.hail.populationgenomics.org.au/api/v1alpha/batches/{batch_id}/jobs/{job_id}'
job_log_url = 'https://batch.hail.populationgenomics.org.au/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log'


def get_hail_batch_from_id(batch_id: int, headers: dict | None = None):
"""
Get a Hail Batch from its ID.
"""
url = batch_url.format(batch_id=batch_id)
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
print(f'Error fetching Hail Batch: {e!r}')
sys.exit(1)


def get_watch_jobs(batch_id: int, n_jobs: int, headers: dict | None = None):
"""
Get the Cromwell watch jobs from a Batch.
"""
watch_jobs = []
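    # Hail Batch job IDs are 1-based within a batch, so iterate from 1 to n_jobs inclusive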
for job_id in range(1, n_jobs + 1):
url = job_url.format(batch_id=batch_id, job_id=job_id)
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
job = response.json()
if not isinstance(job, dict):
print(f'Expected a job dict, got {type(job)}')
sys.exit(1)
if job.get('name', '').endswith('_watch') and job['state'] != 'Success':
# Only add jobs that are watch jobs and not successful
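                # Assumes the job is named like 'dataset/sequencing_group_id: ..._watch',
                # so the text before ':' splits into the dataset and SG ID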
dataset, sg_id = job['name'].split(':')[0].split('/')
watch_jobs.append((job_id, dataset, sg_id))
except requests.RequestException as e:
print(f'Error fetching job {job_id}: {e!r}')
sys.exit(1)

return watch_jobs


def get_workflow_id_from_watch_job(
batch_id: int,
job_id: int,
headers: dict | None = None,
):
"""
Get the Cromwell workflow ID from a watch job log.
"""
url = job_log_url.format(batch_id=batch_id, job_id=job_id)
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
job_log = response.json()
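        # The log is returned as a dict keyed by job container; the workflow ID is expected in the 'main' container's output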
workflow_id_start = job_log['main'].find('Received workflow ID: ')
if workflow_id_start != -1:
workflow_id_start += len('Received workflow ID: ')
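            # Cromwell workflow IDs are UUIDs, so the ID spans the next 36 characters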
workflow_id = job_log['main'][workflow_id_start : workflow_id_start + 36]
print(f'Job {job_id} Workflow ID: {workflow_id}')
return workflow_id
return None

except requests.RequestException as e:
print(f'Error fetching job log: {e!r}')
sys.exit(1)


def abort_workflow(workflow_id: str):
"""
Abort a Cromwell workflow by its ID.
"""
url = f'{CROMWELL_URL}/api/workflows/v1/{workflow_id}/abort'
headers = {
'accept': 'application/json',
'Authorization': f'Bearer {get_cromwell_oauth_token()}',
}
try:
response = requests.post(url, headers=headers, timeout=10)
response.raise_for_status()
print(f'Workflow {workflow_id} aborted successfully.')
except requests.RequestException as e:
print(f'Error aborting workflow {workflow_id}: {e!r}')


def get_workflow_metadata_from_file(workflow_metadata_file_path: str):
try:
@@ -306,29 +395,49 @@ async def create_sv_analyses_async(sg_datasets: dict, sg_analyses: dict):


@click.command()
@click.option('--dataset', required=True, help='Dataset name(s)', multiple=True)
@click.option(
'--workflow-id',
required=True,
help='Cromwell workflow ID',
multiple=True,
)
@click.option('--batch-id', required=True, type=int, help='ID of the Batch containing GATK SV jobs')
@click.option('--dry-run', is_flag=True, help='Dry run mode')
def main(dataset: list[str], workflow_id: list[str], dry_run: bool = False):
def main(batch_id: int, dry_run: bool = False):
"""
A script to parse Cromwell workflow metadata and collect any successful SV outputs.
Copies the outputs to the dataset's main bucket and creates analyses for the successful sub-workflows.

The sub-workflows are identified by the GatherSampleEvidence prefix in their names:
GatherSampleEvidence.scramble, GatherSampleEvidence.wham, GatherSampleEvidence.manta.
"""
sg_peid_map = get_sgid_peid_map(dataset)
batch_auth_token = config_retrieve(['batch', 'auth_token'])
batch_headers = {'Authorization': f'Bearer {batch_auth_token}'}
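    # Reuse this bearer-token header for every Hail Batch API call below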
batch_data = get_hail_batch_from_id(batch_id, headers=batch_headers)
print(f'Batch {batch_id} found: {batch_data["name"]}, {batch_data["n_jobs"]} jobs')
watch_jobs = get_watch_jobs(
batch_id,
batch_data['n_jobs'],
headers=batch_headers,
)
print(f'Found {len(watch_jobs)} unsuccessful watch jobs.')
print('Unsuccessful watch jobs:', [job_id for job_id, _, _ in watch_jobs])

wf_id_info = []
for job_id, dataset, sg_id in watch_jobs:
wf_id = get_workflow_id_from_watch_job(batch_id, job_id, headers=batch_headers)
if wf_id:
wf_id_info.append((dataset, sg_id, wf_id))
else:
print(f'No workflow ID found for job {job_id} in batch {batch_id}')

for dataset, sg_id, workflow_id in wf_id_info:
print(f'Dataset: {dataset}, SG ID: {sg_id}, Workflow ID: {workflow_id}')

wf_ids = [wf_id for _, _, wf_id in wf_id_info]
datasets = {dataset for dataset, _, _ in wf_id_info}
sg_peid_map = get_sgid_peid_map(list(datasets))

sg_analyses_sizes = {}
sg_datasets = {}
sg_analyses = {}
print(f'Parsing {len(workflow_id)} workflows...')
for wf_id in workflow_id:
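    # Abort each still-running workflow, then collect whatever sub-workflow outputs already completed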
print(f'Aborting and parsing {len(wf_ids)} workflows...')
for wf_id in wf_ids:
abort_workflow(wf_id)
json_data = get_workflow_metadata_from_api(wf_id)
workflow_results = parse_workflow_status_and_outputs(wf_id, json_data)
