Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stripepy download #22

Merged
merged 11 commits into from
Nov 9, 2024
172 changes: 172 additions & 0 deletions src/stripepy/cli/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (C) 2024 Roberto Rossini <roberros@uio.no>
#
# SPDX-License-Identifier: MIT

import functools
import hashlib
import json
import logging
import math
import pathlib
import random
import sys
import tempfile
import time
import urllib.request
from typing import Any, Dict, Tuple, Union


@functools.cache
def _get_datasets(max_size: float) -> Dict[str, Dict[str, str]]:
    """Return the known test datasets whose size is strictly below ``max_size``.

    Args:
        max_size: exclusive upper bound for the dataset size, in MB.
            Must not be NaN.

    Returns:
        A mapping from dataset name to its record (url, md5, assembly,
        format, and size_mb). Results are memoized per ``max_size`` value,
        so the returned dictionary should be treated as read-only.

    Raises:
        RuntimeError: when no dataset is smaller than ``max_size``.
    """
    assert not math.isnan(max_size)

    catalog = {
        "4DNFIOTPSS3L": {
            "url": "https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/7386f953-8da9-47b0-acb2-931cba810544/4DNFIOTPSS3L.hic",
            "md5": "d8b030bec6918bfbb8581c700990f49d",
            "assembly": "dm6",
            "format": "hic",
            "size_mb": 248.10,
        },
        "4DNFIC1CLPK7": {
            "url": "https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/0dc0b1ba-5509-4464-9814-dfe103ff09a0/4DNFIC1CLPK7.hic",
            "md5": "9648c38d52fb467846cce42a4a072f57",
            "assembly": "galGal5",
            "format": "hic",
            "size_mb": 583.57,
        },
    }

    # Records without a declared size are treated as infinitely large
    selected = {}
    for dset_name, record in catalog.items():
        if record.get("size_mb", math.inf) < max_size:
            selected[dset_name] = record

    if not selected:
        raise RuntimeError(f"unable to find any dataset smaller than {max_size:.2f} MB")

    return selected


def _list_datasets():
    """Print every known dataset to stdout as indented JSON, followed by a newline."""
    listing = json.dumps(_get_datasets(math.inf), indent=2)
    sys.stdout.write(listing + "\n")


def _get_random_dataset(max_size: float) -> Tuple[str, Dict[str, str]]:
    """Pick one dataset at random among those smaller than ``max_size`` MB.

    Returns:
        A ``(name, record)`` tuple for the selected dataset.

    Raises:
        RuntimeError: propagated from ``_get_datasets`` when no dataset
            satisfies the size constraint.
    """
    candidates = _get_datasets(max_size)
    assert len(candidates) > 0

    (picked,) = random.sample(list(candidates), 1)
    return picked, candidates[picked]


def _lookup_dataset(name: Union[str, None], assembly: Union[str, None], max_size: float) -> Tuple[str, Dict[str, str]]:
    """Find a dataset either by name or by reference genome assembly.

    Args:
        name: name of the dataset to look up. When given, ``assembly`` and
            ``max_size`` are ignored.
        assembly: reference genome assembly used to filter datasets.
            Must not be None when ``name`` is None.
        max_size: exclusive upper bound for the dataset size in MB.
            Only used when looking up by assembly.

    Returns:
        A ``(name, record)`` tuple. When looking up by assembly, the dataset
        is picked at random among the matching ones.

    Raises:
        RuntimeError: when no dataset matches the given criteria.
    """
    if name is not None:
        # A dataset requested explicitly by name is never filtered out by size
        try:
            return name, _get_datasets(math.inf)[name]
        except KeyError as e:
            raise RuntimeError(
                f'unable to find dataset "{name}". Please make sure the provided dataset is present in the list produced by stripepy download --list-only.'
            ) from e

    assert assembly is not None
    assert max_size >= 0

    # Use .get(): _download_and_checksum also treats "assembly" as an optional
    # field, so a record lacking it must simply not match instead of raising
    # a bare KeyError.
    dsets = {k: v for k, v in _get_datasets(max_size).items() if v.get("assembly") == assembly}
    if not dsets:
        raise RuntimeError(
            f'unable to find a dataset using "{assembly}" as reference genome. Please make sure such dataset exists in the list produced by stripepy download --list-only.'
        )

    key = random.sample(list(dsets.keys()), 1)[0]
    return key, dsets[key]


def _hash_file(path: pathlib.Path, chunk_size=16 << 20) -> str:
logging.info('computing MD5 digest for file "%s"...', path)
with path.open("rb") as f:
hasher = hashlib.md5()
while True:
chunk = f.read(chunk_size)
if not chunk:
return hasher.hexdigest()
hasher.update(chunk)


def _download_progress_reporter(chunk_no, max_chunk_size, download_size):
    """Log the download progress at most once every 15 seconds.

    The signature matches the ``reporthook`` callback accepted by
    ``urllib.request.urlretrieve``.

    Args:
        chunk_no: number of chunks transferred so far.
        max_chunk_size: size of a chunk in bytes.
        download_size: total size of the remote file in bytes, or a negative
            value when the size is not known.

    State is kept across calls in two function attributes:
    ``skip_progress_report`` (True once the "size unknown" warning has been
    emitted) and ``timepoint`` (time of the last progress report).
    """
    if download_size < 0:
        # Warn only once per process when the remote size is not available
        if not _download_progress_reporter.skip_progress_report:
            _download_progress_reporter.skip_progress_report = True
            logging.warning("unable to report download progress: remote file size is not known!")
        return

    if time.time() - _download_progress_reporter.timepoint < 15:
        return

    bytes_per_mb = 1024 << 10

    # The last chunk is usually smaller than max_chunk_size: clamp the
    # estimate so that the reported progress never exceeds 100%.
    bytes_downloaded = min(chunk_no * max_chunk_size, download_size)
    mb_downloaded = bytes_downloaded / bytes_per_mb
    download_size_mb = download_size / bytes_per_mb

    # An empty remote file is complete by definition (and would otherwise
    # trigger a division by zero).
    progress_pct = (mb_downloaded / download_size_mb) * 100 if download_size > 0 else 100.0

    logging.info("downloaded %.2f/%.2f MB (%.2f%%)", mb_downloaded, download_size_mb, progress_pct)
    _download_progress_reporter.timepoint = time.time()


# this is Python's way of defining static variables inside functions
_download_progress_reporter.skip_progress_report = False
_download_progress_reporter.timepoint = 0.0


def _download_and_checksum(name: str, dset: Dict[str, Any], dest: pathlib.Path):
    """Download a dataset to ``dest`` after validating its MD5 checksum.

    The file is first downloaded to a temporary file located in the same
    folder as ``dest`` and is moved to its final location only when the
    checksum matches.

    Args:
        name: name of the dataset (used for logging only).
        dset: dataset record. Must have the "url" and "md5" keys;
            "assembly" is optional.
        dest: path where the downloaded file should be stored.

    Returns:
        The path to the downloaded file.

    Raises:
        RuntimeError: when the MD5 digest of the downloaded file does not
            match the expected one.
    """
    url = dset["url"]
    md5sum = dset["md5"]
    assembly = dset.get("assembly", "unknown")

    # Reserve a unique file name next to dest, then close the handle right
    # away: urlretrieve re-opens the file by name (which fails on Windows
    # while the handle is open). delete=False is required because the file
    # is moved on success, and auto-deletion of a moved file raises
    # FileNotFoundError on Python < 3.12.
    with tempfile.NamedTemporaryFile(dir=dest.parent, prefix=f"{dest.stem}.", delete=False) as f:
        tmpfile = pathlib.Path(f.name)

    try:
        logging.info('downloading dataset "%s" (assembly=%s)...', name, assembly)
        t0 = time.time()
        urllib.request.urlretrieve(url, tmpfile, reporthook=_download_progress_reporter)
        t1 = time.time()
        logging.info('DONE! Downloading dataset "%s" took %.2fs.', name, t1 - t0)

        digest = _hash_file(tmpfile)
        if digest != md5sum:
            raise RuntimeError(
                f'MD5 checksum for file downloaded from "{url}" does not match: expected {md5sum}, found {digest}.'
            )

        logging.info("MD5 checksum match!")
        return tmpfile.replace(dest)
    finally:
        # No-op after a successful move; removes partial or corrupted
        # downloads otherwise.
        tmpfile.unlink(missing_ok=True)


def run(
    name: Union[str, None],
    output_path: Union[pathlib.Path, None],
    assembly: Union[str, None],
    max_size: float,
    list_only: bool,
    force: bool,
):
    """Entry point for the ``stripepy download`` subcommand.

    Args:
        name: name of the dataset to download. When None, a dataset is
            picked at random among those matching the other criteria.
        output_path: destination path. When None, the path is derived from
            the dataset name and format.
        assembly: restrict the random choice to this reference genome.
        max_size: exclusive upper bound for the dataset size in MB
            (ignored when ``name`` is given).
        list_only: when True, print the available datasets and return.
        force: when True, overwrite ``output_path`` if it already exists.

    Raises:
        RuntimeError: when no dataset matches the request, when the output
            file exists and ``force`` is False, or when the checksum of
            the downloaded file does not match.
    """
    t0 = time.time()
    if list_only:
        _list_datasets()
        return

    do_random_sample = name is None and assembly is None

    if do_random_sample:
        dset_name, config = _get_random_dataset(max_size)
    else:
        dset_name, config = _lookup_dataset(name, assembly, max_size)

    if output_path is None:
        output_path = pathlib.Path(f"{dset_name}." + config["format"])

    if output_path.exists() and not force:
        raise RuntimeError(f"refusing to overwrite file {output_path}. Pass --force to overwrite.")
    output_path.unlink(missing_ok=True)

    dest = _download_and_checksum(dset_name, config, output_path)
    t1 = time.time()

    logging.info('successfully downloaded dataset "%s" to file "%s"', config["url"], dest)
    # Plain string: this is a %-style logging template, not an f-string
    logging.info("file size: %.2fMB. Elapsed time: %.2fs", dest.stat().st_size / (1024 << 10), t1 - t0)
66 changes: 66 additions & 0 deletions src/stripepy/cli/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import math
import pathlib
from importlib.metadata import version
from typing import Any, Dict, Tuple
Expand Down Expand Up @@ -32,6 +33,13 @@ def _probability(arg) -> float:
raise ValueError("Not a valid probability")


def _non_zero_positive_float(arg) -> float:
    """Convert ``arg`` to float, accepting strictly positive values only.

    Raises:
        ValueError: when ``arg`` cannot be converted to float or when the
            result is not greater than zero (NaN included).
    """
    value = float(arg)
    # "not value > 0" (rather than "value <= 0") also rejects NaN
    if not value > 0:
        raise ValueError("Not a non-zero, positive float")
    return value


def _make_stripepy_call_subcommand(main_parser) -> argparse.ArgumentParser:
sc: argparse.ArgumentParser = main_parser.add_parser(
"call",
Expand Down Expand Up @@ -137,6 +145,61 @@ def _make_stripepy_call_subcommand(main_parser) -> argparse.ArgumentParser:
return sc


def _make_stripepy_download_subcommand(main_parser) -> argparse.ArgumentParser:
    """Register the ``download`` subcommand and its options on ``main_parser``.

    Args:
        main_parser: object returned by ``ArgumentParser.add_subparsers()``.

    Returns:
        The parser for the newly added ``download`` subcommand.
    """
    sc: argparse.ArgumentParser = main_parser.add_parser(
        "download",
        help="Helper command to simplify downloading datasets that can be used to test StripePy.",
    )

    def get_avail_ref_genomes():
        # Imported lazily, only when the parser is being built
        from .download import _get_datasets

        assemblies = {record["assembly"] for record in _get_datasets(math.inf).values() if "assembly" in record}
        # Sort so that --help and error messages list the choices in a
        # stable order (set iteration order varies across runs).
        return sorted(assemblies)

    # --assembly, --name, and --list-only cannot be combined
    grp = sc.add_mutually_exclusive_group(required=False)
    grp.add_argument(
        "--assembly",
        type=str,
        choices=get_avail_ref_genomes(),
        help="Restrict downloads to the given reference genome assembly.",
    )
    grp.add_argument(
        "--name",
        type=str,
        help="Name of the dataset to be downloaded.\n"
        "When not provided, randomly select and download a dataset based on the provided CLI options (if any).",
    )
    grp.add_argument(
        "--list-only",
        action="store_true",
        default=False,
        help="Print the list of available datasets and return.",
    )

    sc.add_argument(
        "--max-size",
        type=_non_zero_positive_float,
        default=512.0,
        help="Upper bound for the size of the files to be considered when --name is not provided.",
    )
    sc.add_argument(
        "-o",
        "--output",
        type=pathlib.Path,
        dest="output_path",
        help="Path where to store the downloaded file.",
    )
    sc.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        help="Overwrite existing file(s).",
    )

    return sc


def _make_cli() -> argparse.ArgumentParser:
cli = argparse.ArgumentParser(
description="stripepy is designed to recognize linear patterns in contact maps (.hic, .mcool, .cool) "
Expand All @@ -149,6 +212,7 @@ def _make_cli() -> argparse.ArgumentParser:
)

_make_stripepy_call_subcommand(sub_parser)
_make_stripepy_download_subcommand(sub_parser)

cli.add_argument(
"-v",
Expand Down Expand Up @@ -207,5 +271,7 @@ def parse_args() -> Tuple[str, Any]:
subcommand = args.pop("subcommand")
if subcommand == "call":
return subcommand, _process_stripepy_call_args(args)
if subcommand == "download":
return subcommand, args

raise NotImplementedError
13 changes: 12 additions & 1 deletion src/stripepy/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from .cli import call, setup
import logging

from .cli import call, download, setup


def _setup_logger(level: str):
    """Configure the root logger with a timestamped format and the given level."""
    logging.basicConfig(level=level, format="[%(asctime)s] %(levelname)s: %(message)s")
    # basicConfig() does nothing when the root logger already has handlers,
    # so the level is also set explicitly.
    root_logger = logging.getLogger()
    root_logger.setLevel(level)


def main():
    """Parse the CLI arguments, set up logging, and run the requested subcommand.

    Raises:
        NotImplementedError: when the subcommand is not recognized.
    """
    subcommand, args = setup.parse_args()
    _setup_logger("INFO")  # TODO make tunable

    if subcommand == "call":
        runner = call.run
    elif subcommand == "download":
        runner = download.run
    else:
        raise NotImplementedError

    return runner(**args)

Expand Down