Commit c427c9c

Merge pull request #22 from robomics/feature/stripepy-download
Implement stripepy download
2 parents 336fc1a + 205e6db commit c427c9c

File tree

  src/stripepy/cli/download.py
  src/stripepy/cli/setup.py
  src/stripepy/main.py

3 files changed: 250 additions & 1 deletion
src/stripepy/cli/download.py (new file: 172 additions & 0 deletions)

# Copyright (C) 2024 Roberto Rossini <roberros@uio.no>
#
# SPDX-License-Identifier: MIT

import functools
import hashlib
import json
import logging
import math
import pathlib
import random
import sys
import tempfile
import time
import urllib.request
from typing import Any, Dict, Tuple, Union


@functools.cache
def _get_datasets(max_size: float) -> Dict[str, Dict[str, str]]:
    assert not math.isnan(max_size)

    datasets = {
        "4DNFIOTPSS3L": {
            "url": "https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/7386f953-8da9-47b0-acb2-931cba810544/4DNFIOTPSS3L.hic",
            "md5": "d8b030bec6918bfbb8581c700990f49d",
            "assembly": "dm6",
            "format": "hic",
            "size_mb": 248.10,
        },
        "4DNFIC1CLPK7": {
            "url": "https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/0dc0b1ba-5509-4464-9814-dfe103ff09a0/4DNFIC1CLPK7.hic",
            "md5": "9648c38d52fb467846cce42a4a072f57",
            "assembly": "galGal5",
            "format": "hic",
            "size_mb": 583.57,
        },
    }

    valid_dsets = {k: v for k, v in datasets.items() if v.get("size_mb", math.inf) < max_size}

    if len(valid_dsets) > 0:
        return valid_dsets

    raise RuntimeError(f"unable to find any dataset smaller than {max_size:.2f} MB")


def _list_datasets():
    json.dump(_get_datasets(math.inf), fp=sys.stdout, indent=2)
    sys.stdout.write("\n")


def _get_random_dataset(max_size: float) -> Tuple[str, Dict[str, str]]:
    dsets = _get_datasets(max_size)
    assert len(dsets) > 0

    key = random.sample(list(dsets.keys()), 1)[0]
    return key, dsets[key]


def _lookup_dataset(name: Union[str, None], assembly: Union[str, None], max_size: float) -> Tuple[str, Dict[str, str]]:
    if name is not None:
        max_size = math.inf
        try:
            return name, _get_datasets(max_size)[name]
        except KeyError as e:
            raise RuntimeError(
                f'unable to find dataset "{name}". Please make sure the provided dataset is present in the list produced by stripepy download --list-only.'
            ) from e

    assert assembly is not None
    assert max_size >= 0

    dsets = {k: v for k, v in _get_datasets(max_size).items() if v["assembly"] == assembly}
    if len(dsets) == 0:
        raise RuntimeError(
            f'unable to find a dataset using "{assembly}" as reference genome. Please make sure such dataset exists in the list produced by stripepy download --list-only.'
        )

    key = random.sample(list(dsets.keys()), 1)[0]
    return key, dsets[key]


def _hash_file(path: pathlib.Path, chunk_size=16 << 20) -> str:
    logging.info('computing MD5 digest for file "%s"...', path)
    with path.open("rb") as f:
        hasher = hashlib.md5()
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                return hasher.hexdigest()
            hasher.update(chunk)


def _download_progress_reporter(chunk_no, max_chunk_size, download_size):
    if download_size == -1:
        if not _download_progress_reporter.skip_progress_report:
            _download_progress_reporter.skip_progress_report = True
            logging.warning("unable to report download progress: remote file size is not known!")
        return

    timepoint = _download_progress_reporter.timepoint

    if time.time() - timepoint >= 15:
        mb_downloaded = (chunk_no * max_chunk_size) / (1024 << 10)
        download_size_mb = download_size / (1024 << 10)
        progress_pct = (mb_downloaded / download_size_mb) * 100
        logging.info("downloaded %.2f/%.2f MB (%.2f%%)", mb_downloaded, download_size_mb, progress_pct)
        _download_progress_reporter.timepoint = time.time()


# this is Python's way of defining static variables inside functions
_download_progress_reporter.skip_progress_report = False
_download_progress_reporter.timepoint = 0.0


def _download_and_checksum(name: str, dset: Dict[str, Any], dest: pathlib.Path):
    with tempfile.NamedTemporaryFile(dir=dest.parent, prefix=f"{dest.stem}.") as tmpfile:
        tmpfile = pathlib.Path(tmpfile.name)

        url = dset["url"]
        md5sum = dset["md5"]
        assembly = dset.get("assembly", "unknown")

        logging.info('downloading dataset "%s" (assembly=%s)...', name, assembly)
        t0 = time.time()
        urllib.request.urlretrieve(url, tmpfile, reporthook=_download_progress_reporter)
        t1 = time.time()
        logging.info('DONE! Downloading dataset "%s" took %.2fs.', name, t1 - t0)

        digest = _hash_file(tmpfile)
        if digest == md5sum:
            logging.info("MD5 checksum match!")
            return tmpfile.rename(dest)

        raise RuntimeError(
            f'MD5 checksum for file downloaded from "{url}" does not match: expected {md5sum}, found {digest}.'
        )


def run(
    name: Union[str, None],
    output_path: Union[pathlib.Path, None],
    assembly: Union[str, None],
    max_size: float,
    list_only: bool,
    force: bool,
):
    t0 = time.time()
    if list_only:
        _list_datasets()
        return

    do_random_sample = name is None and assembly is None

    if do_random_sample:
        dset_name, config = _get_random_dataset(max_size)
    else:
        dset_name, config = _lookup_dataset(name, assembly, max_size)

    if output_path is None:
        output_path = pathlib.Path(f"{dset_name}." + config["format"])

    if output_path.exists() and not force:
        raise RuntimeError(f"refusing to overwrite file {output_path}. Pass --force to overwrite.")
    output_path.unlink(missing_ok=True)

    dest = _download_and_checksum(dset_name, config, output_path)
    t1 = time.time()

    logging.info('successfully downloaded dataset "%s" to file "%s"', config["url"], dest)
    logging.info(f"file size: %.2fMB. Elapsed time: %.2fs", dest.stat().st_size / (1024 << 10), t1 - t0)

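For reference, the new module can also be driven without going through the CLI. The sketch below is not part of the commit: it assumes the stripepy package is installed and importable, and calls run() directly with the same keyword arguments the download subcommand collects, using one of the dataset names from _get_datasets().

# Sketch only (not part of this commit): drive the downloader directly,
# bypassing the CLI. The keyword arguments mirror run()'s signature above.
import logging
import pathlib

from stripepy.cli import download

# Mirror the log format configured by stripepy.main._setup_logger().
logging.basicConfig(level="INFO", format="[%(asctime)s] %(levelname)s: %(message)s")

download.run(
    name="4DNFIOTPSS3L",                           # pick a specific dataset...
    assembly=None,                                 # ...so no assembly filter is needed
    output_path=pathlib.Path("4DNFIOTPSS3L.hic"),
    max_size=512.0,                                # ignored when a name is provided
    list_only=False,
    force=True,                                    # overwrite an existing file, if any
)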
src/stripepy/cli/setup.py (66 additions & 0 deletions)

@@ -1,4 +1,5 @@
 import argparse
+import math
 import pathlib
 from importlib.metadata import version
 from typing import Any, Dict, Tuple
@@ -32,6 +33,13 @@ def _probability(arg) -> float:
     raise ValueError("Not a valid probability")


+def _non_zero_positive_float(arg) -> float:
+    if (n := float(arg)) > 0:
+        return n
+
+    raise ValueError("Not a non-zero, positive float")
+
+
 def _make_stripepy_call_subcommand(main_parser) -> argparse.ArgumentParser:
     sc: argparse.ArgumentParser = main_parser.add_parser(
         "call",
@@ -137,6 +145,61 @@ def _make_stripepy_call_subcommand(main_parser) -> argparse.ArgumentParser:
     return sc


+def _make_stripepy_download_subcommand(main_parser) -> argparse.ArgumentParser:
+    sc: argparse.ArgumentParser = main_parser.add_parser(
+        "download",
+        help="Helper command to simplify downloading datasets that can be used to test StripePy.",
+    )
+
+    def get_avail_ref_genomes():
+        from .download import _get_datasets
+
+        return {record["assembly"] for record in _get_datasets(math.inf).values() if "assembly" in record}
+
+    grp = sc.add_mutually_exclusive_group(required=False)
+    grp.add_argument(
+        "--assembly",
+        type=str,
+        choices=get_avail_ref_genomes(),
+        help="Restrict downloads to the given reference genome assembly.",
+    )
+    grp.add_argument(
+        "--name",
+        type=str,
+        help="Name of the dataset to be downloaded.\n"
+        "When not provided, randomly select and download a dataset based on the provided CLI options (if any).",
+    )
+    grp.add_argument(
+        "--list-only",
+        action="store_true",
+        default=False,
+        help="Print the list of available datasets and return.",
+    )
+
+    sc.add_argument(
+        "--max-size",
+        type=_non_zero_positive_float,
+        default=512.0,
+        help="Upper bound for the size of the files to be considered when --name is not provided.",
+    )
+    sc.add_argument(
+        "-o",
+        "--output",
+        type=pathlib.Path,
+        dest="output_path",
+        help="Path where to store the downloaded file.",
+    )
+    sc.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        default=False,
+        help="Overwrite existing file(s).",
+    )
+
+    return sc
+
+
 def _make_cli() -> argparse.ArgumentParser:
     cli = argparse.ArgumentParser(
         description="stripepy is designed to recognize linear patterns in contact maps (.hic, .mcool, .cool) "
@@ -149,6 +212,7 @@ def _make_cli() -> argparse.ArgumentParser:
     )

     _make_stripepy_call_subcommand(sub_parser)
+    _make_stripepy_download_subcommand(sub_parser)

     cli.add_argument(
         "-v",
@@ -207,5 +271,7 @@ def parse_args() -> Tuple[str, Any]:
     subcommand = args.pop("subcommand")
     if subcommand == "call":
         return subcommand, _process_stripepy_call_args(args)
+    if subcommand == "download":
+        return subcommand, args

     raise NotImplementedError

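The --max-size option leans on argparse's type= contract: the callable either returns the converted value or raises, and argparse turns the exception into a usage error. The standalone sketch below is not part of the commit and uses a hypothetical stand-in parser, but it reproduces the same pattern with the validator added here.

# Sketch only (not part of this commit): the type= validator pattern used
# for --max-size. The callable converts the string or raises, and argparse
# reports the failure as a CLI usage error.
import argparse


def _non_zero_positive_float(arg) -> float:
    if (n := float(arg)) > 0:
        return n

    raise ValueError("Not a non-zero, positive float")


parser = argparse.ArgumentParser()  # hypothetical stand-in for the "download" subparser
parser.add_argument("--max-size", type=_non_zero_positive_float, default=512.0)

print(parser.parse_args(["--max-size", "300"]).max_size)  # -> 300.0
print(parser.parse_args([]).max_size)                     # -> 512.0 (default)
# parser.parse_args(["--max-size", "-5"]) exits with a usage error along the
# lines of: argument --max-size: invalid _non_zero_positive_float value: '-5'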
src/stripepy/main.py (12 additions & 1 deletion)

@@ -1,11 +1,22 @@
-from .cli import call, setup
+import logging
+
+from .cli import call, download, setup
+
+
+def _setup_logger(level: str):
+    fmt = "[%(asctime)s] %(levelname)s: %(message)s"
+    logging.basicConfig(level=level, format=fmt)
+    logging.getLogger().setLevel(level)


 def main():
     subcommand, args = setup.parse_args()
+    _setup_logger("INFO")  # TODO make tunable

     if subcommand == "call":
         return call.run(**args)
+    if subcommand == "download":
+        return download.run(**args)

     raise NotImplementedError

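Taken together, parse_args() hands main() the subcommand name plus a plain dict of options, and main() simply forwards that dict to download.run(**args). The sketch below is not part of the commit; it assumes the installed package exposes stripepy.main and that parse_args() reads sys.argv (argparse's default), and exercises the new code path without going through the console script.

# Sketch only (not part of this commit): exercise the new dispatch path
# programmatically instead of via the `stripepy` console script.
import sys

from stripepy.main import main

# Equivalent to running `stripepy download --list-only` in a shell:
# parse_args() returns ("download", {...}) and main() calls download.run(**args),
# which dumps the dataset table as JSON to stdout and returns.
sys.argv = ["stripepy", "download", "--list-only"]
main()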