diff --git a/recirq/fermi_hubbard/publication.py b/recirq/fermi_hubbard/publication.py
index 0b87ee39..9987bf1f 100644
--- a/recirq/fermi_hubbard/publication.py
+++ b/recirq/fermi_hubbard/publication.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 """Data specific to experiment published in arXiv:2010.07965."""
 
-from io import BytesIO
 from copy import deepcopy
+from io import BytesIO
 import os
-from typing import Callable, List, Optional, Tuple
+import re
+from typing import Callable, Optional, Tuple
 from urllib.request import urlopen
 from zipfile import ZipFile
 
@@ -217,7 +218,7 @@ def rainbow23_layouts(sites_count: int = 8) -> Tuple[ZigZagLayout]:
 
 def fetch_publication_data(
     base_dir: Optional[str] = None,
-    exclude: Optional[List[str]] = None,
+    exclude: Optional[Tuple[str]] = (),
 ) -> None:
     """Downloads and extracts publication data from the Dryad repository at
     https://doi.org/10.5061/dryad.crjdfn32v, saving to disk.
@@ -239,23 +240,26 @@ def fetch_publication_data(
     if base_dir is None:
         base_dir = "fermi_hubbard_data"
 
-    base_url = "https://datadryad.org/stash/downloads/file_stream/"
-    data = {
-        "gaussians_1u1d_nofloquet": "451326",
-        "gaussians_1u1d": "451327",
-        "trapping_2u2d": "451328",
-        "trapping_3u3d": "451329"
-    }
-    if exclude is not None:
-        data = {path: key for path, key in data.items() if path not in exclude}
-
-    for path, key in data.items():
+    fnames = {
+        "gaussians_1u1d_nofloquet", "gaussians_1u1d", "trapping_2u2d", "trapping_3u3d"
+    }.difference(exclude)
+
+    # Determine file IDs. Note these are not permanent on Dryad.
+    file_ids = {}
+    for line in urlopen("https://doi.org/10.5061/dryad.crjdfn32v").readlines():
+        for fname in fnames:
+            if fname + ".zip" in line.decode():
+                file_id, = re.findall(r"file_stream/[\d]*\d+", line.decode())
+                file_ids.update({fname: file_id})
+
+    # Download and extract files using IDs.
+    for path, key in file_ids.items():
         print(f"Downloading {path}...")
-        if os.path.exists(path=base_dir + os.path.sep + path):
+        if os.path.exists(path=os.path.join(base_dir, path)):
             print("Data already exists.\n")
             continue
 
-        with urlopen(base_url + key) as stream:
+        with urlopen("https://datadryad.org/stash/downloads/" + key) as stream:
             with ZipFile(BytesIO(stream.read())) as zfile:
                 zfile.extractall(base_dir)
 
diff --git a/recirq/fermi_hubbard/publication_test.py b/recirq/fermi_hubbard/publication_test.py
index 48cc481c..bd2a4c7a 100644
--- a/recirq/fermi_hubbard/publication_test.py
+++ b/recirq/fermi_hubbard/publication_test.py
@@ -18,10 +18,10 @@
 
 def test_fetch_publication_data():
     base_dir = "fermi_hubbard_data"
-    fetch_publication_data(base_dir=base_dir, exclude=["trapping_3u3d"])
+    fetch_publication_data(base_dir=base_dir, exclude=("trapping_3u3d",))
 
     for path in ("gaussians_1u1d_nofloquet", "gaussians_1u1d", "trapping_2u2d"):
-        assert os.path.exists(base_dir + os.path.sep + path)
+        assert os.path.exists(os.path.join(base_dir, path))
 
     fetch_publication_data(base_dir=base_dir)
-    assert os.path.exists(base_dir + os.path.sep + "trapping_3u3d")
+    assert os.path.exists(os.path.join(base_dir, "trapping_3u3d"))
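
The core of the change is the ID-scraping loop: it assumes the Dryad landing page embeds download links of the form file_stream/<number> next to each zip name. A standalone sketch of that matching step, using a hypothetical HTML fragment (the ID 451329 comes from the old hard-coded table; live IDs may differ, since, as the new comment notes, they are not permanent on Dryad):

    import re

    # Hypothetical fragment of the Dryad landing page source.
    line = b'<a href="/stash/downloads/file_stream/451329">trapping_3u3d.zip</a>'

    if "trapping_3u3d.zip" in line.decode():
        # The pattern keeps the whole "file_stream/<digits>" token, which is
        # later appended to "https://datadryad.org/stash/downloads/".
        file_id, = re.findall(r"file_stream/[\d]*\d+", line.decode())
        print(file_id)  # file_stream/451329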
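
After this change, exclude takes a tuple rather than a list, and file IDs are resolved at call time instead of being hard-coded. A minimal usage sketch against the updated signature, mirroring the test above (the base_dir value is just the default name):

    from recirq.fermi_hubbard.publication import fetch_publication_data

    # Fetch everything except the 3u3d trapping data. Note the trailing
    # comma: `exclude` is now a tuple, not a list.
    fetch_publication_data(base_dir="fermi_hubbard_data", exclude=("trapping_3u3d",))

    # Re-running is cheap: datasets already extracted on disk are skipped.
    fetch_publication_data(base_dir="fermi_hubbard_data")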