From da4956e04dc4324555e5734db6d5d26789b6b4b2 Mon Sep 17 00:00:00 2001
From: Cangyuan Li
Date: Wed, 1 Nov 2023 00:48:39 -0400
Subject: [PATCH] download from github if file is not present

---
 setup.cfg                           |  3 +--
 src/pyethnicity/_bayesian_models.py | 10 ++++++----
 src/pyethnicity/_ml_models.py       |  9 +++++++--
 src/pyethnicity/utils/utils.py      | 28 ++++++++++++++++++++++++++++
 4 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 85f09dc..ea343c2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -36,5 +36,4 @@ where = src
 
 [options.package_data]
 pyethnicity =
-    data/distributions/*.parquet
-    data/models/first_last.onnx
\ No newline at end of file
+    data/distributions/*.parquet
\ No newline at end of file
diff --git a/src/pyethnicity/_bayesian_models.py b/src/pyethnicity/_bayesian_models.py
index 701e9c2..5dc7adc 100644
--- a/src/pyethnicity/_bayesian_models.py
+++ b/src/pyethnicity/_bayesian_models.py
@@ -9,7 +9,7 @@
 
 from .utils.paths import DIST_PATH
 from .utils.types import Geography, GeoType, Name, Year
-from .utils.utils import RACES, _assert_equal_lengths, _remove_single_chars
+from .utils.utils import RACES, _assert_equal_lengths, _download, _remove_single_chars
 
 UNWANTED_CHARS = string.digits + string.punctuation + string.whitespace
 
@@ -33,9 +33,11 @@ def __init__(self):
 
     def load(self, resource: Resource) -> pl.DataFrame:
         if self._resources[resource] is None:
-            self._resources[resource] = pl.read_parquet(
-                DIST_PATH / f"{resource}.parquet"
-            )
+            file = f"{resource}.parquet"
+            if not (DIST_PATH / file).exists():
+                _download(f"distributions/{file}")
+
+            self._resources[resource] = pl.read_parquet(DIST_PATH / file)
 
         data = self._resources[resource]
         assert data is not None
diff --git a/src/pyethnicity/_ml_models.py b/src/pyethnicity/_ml_models.py
index 1d6f53a..212be3a 100644
--- a/src/pyethnicity/_ml_models.py
+++ b/src/pyethnicity/_ml_models.py
@@ -16,6 +16,7 @@
 from .utils.utils import (
     RACES,
     _assert_equal_lengths,
+    _download,
     _is_null,
     _remove_single_chars,
     _std_norm,
@@ -37,8 +38,12 @@ def __init__(self):
 
     def load(self, model: Model) -> onnxruntime.InferenceSession:
         if self._models[model] is None:
+            file = f"{model}.onnx"
+            if not (MODEL_PATH / file).exists():
+                _download(f"models/{file}")
+
             self._models[model] = onnxruntime.InferenceSession(
-                MODEL_PATH / f"{model}.onnx",
+                MODEL_PATH / file,
                 providers=onnxruntime.get_available_providers(),
             )
 
@@ -303,7 +308,7 @@ def predict_race(
     - data/distributions/prob_race_given_last_name.parquet
     - data/distributions/prob_zcta_given_race_2010.parquet
     - data/distributions/prob_tract_given_race_2010.parquet
-    - data/distributionsprob_first_name_given_race.parquet
+    - data/distributions/prob_first_name_given_race.parquet
 
     Examples
     --------
diff --git a/src/pyethnicity/utils/utils.py b/src/pyethnicity/utils/utils.py
index a9f21a9..28d96dd 100644
--- a/src/pyethnicity/utils/utils.py
+++ b/src/pyethnicity/utils/utils.py
@@ -3,6 +3,9 @@
 from collections.abc import Sequence
 from typing import SupportsFloat, SupportsIndex, Union
 
+import requests
+
+from .paths import DAT_PATH
 from .types import ArrayLike
 
 RACES = ("asian", "black", "hispanic", "white")
@@ -36,3 +39,28 @@ def _std_norm(values: Sequence[float]) -> list[float]:
 
 def _is_null(x: Union[SupportsFloat, SupportsIndex]):
     return math.isnan(x) or x is None
+
+
+def _download(path: str):
+    """Fetch a data file from the GitHub repository into the local data folder.
+
+    ``path`` is relative to the package data directory, e.g.
+    ``"models/first_last.onnx"``. Raises ``requests.exceptions.HTTPError`` when
+    the server does not answer with status 200.
+    """
+    r = requests.get(
+        f"https://raw.githubusercontent.com/CangyuanLi/pyethnicity/master/src/pyethnicity/data/{path}",
+        # requests never times out by default; bound the wait so a stalled
+        # connection cannot block the caller forever.
+        timeout=60,
+    )
+    if r.status_code != 200:
+        raise requests.exceptions.HTTPError(f"{r.status_code}: DOWNLOAD FAILED")
+
+    # Create every missing directory above the target, not only the first
+    # path component, so nested paths and a missing data folder both work.
+    target = DAT_PATH / path
+    target.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(target, "wb") as f:
+        f.write(r.content)