Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Jun 18, 2024
1 parent e85839a commit ea08244
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 0 deletions.
179 changes: 179 additions & 0 deletions openqdc/utils/download_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import gzip
import os
import shutil
import socket
import tarfile
import urllib.error
import urllib.request
import zipfile

import fsspec
import gdown
import requests
import tqdm
from loguru import logger
from sklearn.utils import Bunch

from openqdc.utils.io import get_local_cache


def download_url(url, local_filename):
    """
    Download a file from a url to a local file.

    Google Drive links are delegated to ``gdown``, ``raw.github`` links are
    fetched in a single request, and every other URL is streamed to disk in
    16 KiB chunks behind a ``tqdm`` progress counter.

    Parameters
    ----------
    url : str
        URL to download from.
    local_filename : str
        Local path for destination.

    Raises
    ------
    requests.HTTPError
        If the server answers a ``requests`` download with a 4xx/5xx status.
    """
    logger.info(f"Url: {url} File: {local_filename}")
    if "drive.google.com" in url:
        gdown.download(url, local_filename, quiet=False)
    elif "raw.github" in url:
        r = requests.get(url, allow_redirects=True)
        # Fail before touching the disk: otherwise the body of an error page
        # would be saved under the dataset's file name.
        r.raise_for_status()
        with open(local_filename, "wb") as f:
            f.write(r.content)
    else:
        # The context manager releases the pooled connection even when the
        # transfer is interrupted half-way.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with fsspec.open(local_filename, "wb") as f:
                for chunk in tqdm.tqdm(r.iter_content(chunk_size=16384)):
                    if chunk:
                        f.write(chunk)


def decompress_tar_gz(local_filename):
    """
    Decompress a tar.gz file into the directory that contains it.

    Extraction is skipped when every member name listed in the archive
    already exists next to it.

    Parameters
    ----------
    local_filename : str
        Path to local file to decompress.

    Raises
    ------
    tarfile.TarError
        If the archive is unreadable, or (on interpreters that provide
        extraction filters) if a member would be written outside the
        destination directory.
    """
    parent = os.path.dirname(local_filename)
    with tarfile.open(local_filename) as tar:
        logger.info(f"Verifying archive extraction states: {local_filename}")
        all_extracted = all(os.path.exists(os.path.join(parent, name)) for name in tar.getnames())
        if all_extracted:
            logger.info(f"Archive already extracted: {local_filename}")
            return

        logger.info(f"Extracting archive: {local_filename}")
        if hasattr(tarfile, "data_filter"):
            # The archive was downloaded, so treat it as untrusted: the "data"
            # filter rejects members that escape ``parent`` (``../``, absolute
            # paths, links pointing outside) and ignores ownership/special bits.
            tar.extractall(path=parent, filter="data")
        else:
            # Older interpreters without extraction filters: previous behaviour.
            tar.extractall(path=parent)


def decompress_zip(local_filename):
    """
    Unpack a zip archive into the directory that contains it.

    Nothing is extracted when every name listed in the archive already
    exists in that directory.

    Parameters
    ----------
    local_filename : str
        Path to local file to decompress.
    """
    target_dir = os.path.dirname(local_filename)

    logger.info(f"Verifying archive extraction states: {local_filename}")
    with zipfile.ZipFile(local_filename, "r") as archive:
        # Any member without a counterpart on disk triggers a full extraction.
        missing = [
            member
            for member in archive.namelist()
            if not os.path.exists(os.path.join(target_dir, member))
        ]
        if missing:
            logger.info(f"Extracting archive: {local_filename}")
            archive.extractall(target_dir)
        else:
            logger.info(f"Archive already extracted: {local_filename}")


def decompress_gz(local_filename):
    """
    Gunzip a single file next to the compressed original.

    The output name is ``local_filename`` with every ``".gz"`` removed. When
    that name ends in ``hdf5`` the output is instead ``local_filename`` with
    ``hdf5`` replaced by ``h5``. Nothing is written if the output exists.

    Parameters
    ----------
    local_filename : str
        Path to local file to decompress.
    """
    logger.info(f"Verifying archive extraction states: {local_filename}")
    target = local_filename.replace(".gz", "")
    if target.endswith("hdf5"):
        # NOTE(review): this is derived from ``local_filename``, so the ".gz"
        # is kept ("x.hdf5.gz" -> "x.h5.gz"). Confirm readers expect that name
        # before changing it.
        target = local_filename.replace("hdf5", "h5")

    if os.path.exists(target):
        logger.info(f"Archive already extracted: {local_filename}")
        return

    logger.info(f"Extracting archive: {local_filename}")
    with gzip.open(local_filename, "rb") as compressed, open(target, "wb") as plain:
        shutil.copyfileobj(compressed, plain)


def _decompress_xz(local_filename):
    """Decompress one ``.xz`` file next to the archive with the stdlib ``lzma`` module."""
    import lzma  # function-scope import: the module header does not import lzma

    out_filename = local_filename[: -len(".xz")]
    if os.path.exists(out_filename):
        logger.info(f"Archive already extracted: {local_filename}")
        return

    logger.info(f"Extracting archive: {local_filename}")
    # Write to a side file first so an interrupted run never leaves a truncated
    # output that a later call would mistake for a finished extraction.
    partial = out_filename + ".part"
    try:
        with lzma.open(local_filename, "rb") as f_in, open(partial, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial, out_filename)
    finally:
        if os.path.exists(partial):
            os.remove(partial)


def fetch_file(url, local_filename, overwrite=False):
    """
    Download a file from a url to a local file. Useful for big files.

    After the (possibly skipped) download, the file is decompressed next to
    itself according to its suffix: ``tar.gz``, ``zip``, ``.gz`` or ``.xz``.
    Other suffixes are left untouched.

    Parameters
    ----------
    url : str
        URL to download from.
    local_filename : str
        Local file to save to.
    overwrite : bool
        Whether to overwrite existing files.

    Returns
    -------
    local_filename : str
        Local file.

    Raises
    ------
    ConnectionError
        If the download fails with ``socket.gaierror`` or
        ``urllib.error.URLError``; the original error is chained as the cause.
    """
    if os.path.exists(local_filename) and not overwrite:
        logger.info("File already exists, skipping download")
    else:
        try:
            download_url(url, local_filename)
        except (socket.gaierror, urllib.error.URLError) as err:
            # NOTE(review): download_url fetches with ``requests``, whose own
            # exception types are not in this tuple and propagate unchanged.
            # Confirm whether they should be wrapped as well.
            raise ConnectionError("Could not download {} due to {}".format(url, err)) from err

    # decompress archive if necessary
    if local_filename.endswith("tar.gz"):
        decompress_tar_gz(local_filename)
    elif local_filename.endswith("zip"):
        decompress_zip(local_filename)
    elif local_filename.endswith(".gz"):
        decompress_gz(local_filename)
    elif local_filename.endswith(".xz"):
        # Pure-Python extraction of this one file: no shell string, no
        # dependency on an ``xz`` binary, no effect on sibling ``.xz`` files.
        # Unlike ``xz -d``, the compressed original is kept, as for ``.gz``.
        _decompress_xz(local_filename)

    return local_filename


class DataDownloader:
    """Download data from a remote source.

    Parameters
    ----------
    cache_path : str
        Path to the cache directory. Defaults to ``get_local_cache()``.
    overwrite : bool
        Whether to overwrite existing files.
    """

    def __init__(self, cache_path=None, overwrite=False):
        if cache_path is None:
            cache_path = get_local_cache()

        self.cache_path = cache_path
        self.overwrite = overwrite

    def from_config(self, config: dict):
        """Fetch every link of a dataset configuration into the cache.

        Parameters
        ----------
        config : dict
            Must provide ``dataset_name`` (sub-directory created under the
            cache path) and ``links`` (mapping of local file name to URL).
        """
        b_config = Bunch(**config)
        data_path = os.path.join(self.cache_path, b_config.dataset_name)
        os.makedirs(data_path, exist_ok=True)

        logger.info(f"Downloading the {b_config.dataset_name} dataset")
        for local, link in b_config.links.items():
            outfile = os.path.join(data_path, local)
            # Forward the instance setting; without it ``overwrite=True`` was
            # accepted by the constructor but never had any effect.
            fetch_file(link, outfile, overwrite=self.overwrite)
44 changes: 44 additions & 0 deletions tests/test_energies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest

from openqdc.datasets.energies import AtomEnergies, AtomEnergy
from openqdc.methods import PotentialMethod


class Container:
    """Minimal stand-in object handed to ``AtomEnergies`` by the fixture in this module."""

    __name__ = "container"
    __energy_methods__ = [PotentialMethod.WB97M_D3BJ_DEF2_TZVPPD]
    # String form of the same methods, derived so the two lists cannot drift.
    energy_methods = list(map(str, __energy_methods__))
    refit_e0s = True

    def __init__(self, energy_type="formation"):
        self.energy_type = energy_type


@pytest.fixture
def physical_energies():
    """Provide an ``AtomEnergies`` built from a default-constructed ``Container``."""
    return AtomEnergies(Container())


def test_atom_energies_object(physical_energies):
    """The fixture hands back an ``AtomEnergies`` instance."""
    energies = physical_energies
    assert isinstance(energies, AtomEnergies)


def test_indexing(physical_energies):
    """Check the key forms accepted by ``AtomEnergies`` indexing.

    Integer keys, ``(int, int)`` keys and ``(str, int)`` keys, with or without
    explicit parentheses, must all return an ``AtomEnergy``. The symbol and
    atomic-number forms of the same key must compare equal, a different key
    must not, and an unknown key must raise ``KeyError``.
    """
    energies = physical_energies
    # The second tuple element is presumably the charge -- TODO confirm.
    lookups = (
        energies[6],
        energies[(6, 1)],
        energies[6, 1],
        energies[("C", 1)],
        energies["C", 1],
    )
    for entry in lookups:
        assert isinstance(entry, AtomEnergy)

    assert energies[("C", 1)] == energies[(6, 1)]
    assert not energies[("Cl", -2)] == energies[(6, 1)]
    with pytest.raises(KeyError):
        energies[("Cl", -6)]


def test_matrix(physical_energies):
    """``e0s_matrix`` is a length-1 NumPy array with at least one truthy entry."""
    e0s = physical_energies.e0s_matrix
    assert len(e0s) == 1
    assert isinstance(e0s, np.ndarray)
    assert np.any(e0s)

0 comments on commit ea08244

Please sign in to comment.