Skip to content

Commit

Permalink
download from github if file is not present
Browse files Browse the repository at this point in the history
  • Loading branch information
CangyuanLi committed Nov 1, 2023
1 parent 973aa0f commit da4956e
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,4 @@ where = src

[options.package_data]
pyethnicity =
data/distributions/*.parquet
data/models/first_last.onnx
data/distributions/*.parquet
10 changes: 6 additions & 4 deletions src/pyethnicity/_bayesian_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .utils.paths import DIST_PATH
from .utils.types import Geography, GeoType, Name, Year
from .utils.utils import RACES, _assert_equal_lengths, _remove_single_chars
from .utils.utils import RACES, _assert_equal_lengths, _download, _remove_single_chars

UNWANTED_CHARS = string.digits + string.punctuation + string.whitespace

Expand All @@ -33,9 +33,11 @@ def __init__(self):

def load(self, resource: Resource) -> pl.DataFrame:
if self._resources[resource] is None:
self._resources[resource] = pl.read_parquet(
DIST_PATH / f"{resource}.parquet"
)
file = f"{resource}.parquet"
if not (DIST_PATH / file).exists():
_download(f"distributions/{file}")

self._resources[resource] = pl.read_parquet(DIST_PATH / file)

data = self._resources[resource]
assert data is not None
Expand Down
9 changes: 7 additions & 2 deletions src/pyethnicity/_ml_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .utils.utils import (
RACES,
_assert_equal_lengths,
_download,
_is_null,
_remove_single_chars,
_std_norm,
Expand All @@ -37,8 +38,12 @@ def __init__(self):

def load(self, model: Model) -> onnxruntime.InferenceSession:
if self._models[model] is None:
file = f"{model}.onnx"
if not (MODEL_PATH / file).exists():
_download(f"models/{file}")

self._models[model] = onnxruntime.InferenceSession(
MODEL_PATH / f"{model}.onnx",
MODEL_PATH / file,
providers=onnxruntime.get_available_providers(),
)

Expand Down Expand Up @@ -303,7 +308,7 @@ def predict_race(
- data/distributions/prob_race_given_last_name.parquet
- data/distributions/prob_zcta_given_race_2010.parquet
- data/distributions/prob_tract_given_race_2010.parquet
- data/distributionsprob_first_name_given_race.parquet
- data/distributions/prob_first_name_given_race.parquet
Examples
--------
Expand Down
17 changes: 17 additions & 0 deletions src/pyethnicity/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from collections.abc import Sequence
from typing import SupportsFloat, SupportsIndex, Union

import requests

from .paths import DAT_PATH
from .types import ArrayLike

RACES = ("asian", "black", "hispanic", "white")
Expand Down Expand Up @@ -36,3 +39,17 @@ def _std_norm(values: Sequence[float]) -> list[float]:

def _is_null(x: Union[SupportsFloat, SupportsIndex]):
return math.isnan(x) or x is None


def _download(path: str):
r = requests.get(
f"https://raw.githubusercontent.com/CangyuanLi/pyethnicity/master/src/pyethnicity/data/{path}"
)
if r.status_code != 200:
raise requests.exceptions.HTTPError(f"{r.status_code}: DOWNLOAD FAILED")

parent_folder = path.split("/")[0]
(DAT_PATH / parent_folder).mkdir(exist_ok=True)

with open(DAT_PATH / path, "wb") as f:
f.write(r.content)

0 comments on commit da4956e

Please sign in to comment.