# Copyright (C) 2024 Roberto Rossini <roberros@uio.no>
#
# SPDX-License-Identifier: MIT

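"""
Support code for the ``stripepy download`` subcommand: maintain a small catalog of
publicly available test datasets, select one by name, by reference assembly, or at
random, and download it to disk while verifying its MD5 checksum.
"""
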
import functools
import hashlib
import json
import logging
import math
import pathlib
import random
import sys
import tempfile
import time
import urllib.request
from typing import Any, Dict, Tuple, Union


@functools.cache
def _get_datasets(max_size: float) -> Dict[str, Dict[str, Any]]:
    assert not math.isnan(max_size)

    datasets = {
        "4DNFIOTPSS3L": {
            "url": "https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/7386f953-8da9-47b0-acb2-931cba810544/4DNFIOTPSS3L.hic",
            "md5": "d8b030bec6918bfbb8581c700990f49d",
            "assembly": "dm6",
            "format": "hic",
            "size_mb": 248.10,
        },
        "4DNFIC1CLPK7": {
            "url": "https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/0dc0b1ba-5509-4464-9814-dfe103ff09a0/4DNFIC1CLPK7.hic",
            "md5": "9648c38d52fb467846cce42a4a072f57",
            "assembly": "galGal5",
            "format": "hic",
            "size_mb": 583.57,
        },
    }

    valid_dsets = {k: v for k, v in datasets.items() if v.get("size_mb", math.inf) < max_size}

    if len(valid_dsets) > 0:
        return valid_dsets

    raise RuntimeError(f"unable to find any dataset smaller than {max_size:.2f} MB")


def _list_datasets():
    json.dump(_get_datasets(math.inf), fp=sys.stdout, indent=2)
    sys.stdout.write("\n")


def _get_random_dataset(max_size: float) -> Tuple[str, Dict[str, Any]]:
    dsets = _get_datasets(max_size)
    assert len(dsets) > 0

    key = random.choice(list(dsets))
    return key, dsets[key]


def _lookup_dataset(name: Union[str, None], assembly: Union[str, None], max_size: float) -> Tuple[str, Dict[str, Any]]:
    if name is not None:
        max_size = math.inf
        try:
            return name, _get_datasets(max_size)[name]
        except KeyError as e:
            raise RuntimeError(
                f'unable to find dataset "{name}". Please make sure the provided dataset is present in the list produced by stripepy download --list-only.'
            ) from e

    assert assembly is not None
    assert max_size >= 0

    dsets = {k: v for k, v in _get_datasets(max_size).items() if v["assembly"] == assembly}
    if len(dsets) == 0:
        raise RuntimeError(
            f'unable to find a dataset using "{assembly}" as reference genome. Please make sure such a dataset exists in the list produced by stripepy download --list-only.'
        )

    key = random.choice(list(dsets))
    return key, dsets[key]


def _hash_file(path: pathlib.Path, chunk_size: int = 16 << 20) -> str:
    logging.info('computing MD5 digest for file "%s"...', path)
    with path.open("rb") as f:
        hasher = hashlib.md5()
        # read the file in chunks (16 MiB by default) to keep memory usage bounded
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                return hasher.hexdigest()
            hasher.update(chunk)


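# Illustrative check (not executed by this module): chunked hashing must produce the
# same digest as hashing the whole file in one go; the chunked version simply avoids
# loading multi-GB files into memory. "example.bin" is a hypothetical file.
#
#   p = pathlib.Path("example.bin")
#   assert _hash_file(p) == hashlib.md5(p.read_bytes()).hexdigest()

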
def _download_progress_reporter(chunk_no, max_chunk_size, download_size):
    if download_size == -1:
        if not _download_progress_reporter.skip_progress_report:
            _download_progress_reporter.skip_progress_report = True
            logging.warning("unable to report download progress: remote file size is not known!")
        return

    timepoint = _download_progress_reporter.timepoint

    # log progress at most once every 15 seconds
    if time.time() - timepoint >= 15:
        # the last chunk may be shorter than max_chunk_size: clamp to avoid reporting > 100%
        mb_downloaded = min(chunk_no * max_chunk_size, download_size) / (1024 << 10)
        download_size_mb = download_size / (1024 << 10)
        progress_pct = (mb_downloaded / download_size_mb) * 100
        logging.info("downloaded %.2f/%.2f MB (%.2f%%)", mb_downloaded, download_size_mb, progress_pct)
        _download_progress_reporter.timepoint = time.time()


# this is Python's way of defining static variables inside functions
_download_progress_reporter.skip_progress_report = False
_download_progress_reporter.timepoint = 0.0


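# A minimal sketch of the pattern used above (illustrative only): attributes attached
# to a function object persist across calls, giving the function C-style static state
# without a global variable or a class.
#
#   def counter() -> int:
#       counter.calls += 1  # "static" state lives on the function object itself
#       return counter.calls
#
#   counter.calls = 0

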
def _download_and_checksum(name: str, dset: Dict[str, Any], dest: pathlib.Path) -> pathlib.Path:
    # download to a temporary file sitting next to dest, then rename it into place only
    # once the checksum has been verified. delete=False is needed because the file is
    # renamed away on success (and, on Windows, because urlretrieve could not reopen a
    # delete-on-close temporary file by name)
    with tempfile.NamedTemporaryFile(dir=dest.parent, prefix=f"{dest.stem}.", delete=False) as handle:
        tmpfile = pathlib.Path(handle.name)

    try:
        url = dset["url"]
        md5sum = dset["md5"]
        assembly = dset.get("assembly", "unknown")

        logging.info('downloading dataset "%s" (assembly=%s)...', name, assembly)
        t0 = time.time()
        urllib.request.urlretrieve(url, tmpfile, reporthook=_download_progress_reporter)
        t1 = time.time()
        logging.info('DONE! Downloading dataset "%s" took %.2fs.', name, t1 - t0)

        digest = _hash_file(tmpfile)
        if digest == md5sum:
            logging.info("MD5 checksum match!")
            return tmpfile.rename(dest)

        raise RuntimeError(
            f'MD5 checksum for file downloaded from "{url}" does not match: expected {md5sum}, found {digest}.'
        )
    except BaseException:
        # remove the partially downloaded (or corrupted) file before propagating the error
        tmpfile.unlink(missing_ok=True)
        raise


def run(
    name: Union[str, None],
    output_path: Union[pathlib.Path, None],
    assembly: Union[str, None],
    max_size: float,
    list_only: bool,
    force: bool,
):
    t0 = time.time()
    if list_only:
        _list_datasets()
        return

    do_random_sample = name is None and assembly is None

    if do_random_sample:
        dset_name, config = _get_random_dataset(max_size)
    else:
        dset_name, config = _lookup_dataset(name, assembly, max_size)

    if output_path is None:
        output_path = pathlib.Path(f"{dset_name}.{config['format']}")

    if output_path.exists() and not force:
        raise RuntimeError(f"refusing to overwrite file {output_path}. Pass --force to overwrite.")
    output_path.unlink(missing_ok=True)

    dest = _download_and_checksum(dset_name, config, output_path)
    t1 = time.time()

    logging.info('successfully downloaded "%s" to file "%s"', config["url"], dest)
    logging.info("file size: %.2f MB. Elapsed time: %.2fs", dest.stat().st_size / (1024 << 10), t1 - t0)
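

# Usage sketch (an assumption for illustration: this module is normally driven by the
# stripepy CLI, e.g. "stripepy download"). Download a random dataset mapped to the dm6
# assembly, skipping anything larger than 512 MB:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    run(
        name=None,  # no explicit dataset: pick one matching the filters below
        output_path=None,  # default to "<dataset_name>.<format>" in the current folder
        assembly="dm6",
        max_size=512.0,
        list_only=False,  # set to True to print the dataset catalog and exit
        force=False,  # refuse to overwrite an existing output file
    )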