Skip to content

Commit

Permalink
make predict_race_fl unique by its inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
CangyuanLi committed Jun 5, 2024
1 parent 61002a4 commit 8774f95
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/pyethnicity/_ml_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,11 @@ def predict_race_fl(

input_name = _model.get_inputs()[0].name

y_pred = []
for input_ in tqdm.tqdm(cutils.chunk_seq(X, chunksize)):
y_pred.extend(_model.run(None, input_feed={input_name: input_})[0])
with tqdm.tqdm(total=len(X)) as pbar:
y_pred = []
for input_ in cutils.chunk_seq(X, chunksize):
y_pred.extend(_model.run(None, input_feed={input_name: input_})[0])
pbar.update(len(input_))

preds: dict[str, list] = {r: [] for r in RACES}
for row in y_pred:
Expand All @@ -211,10 +213,14 @@ def predict_race_fl(
preds[first_name_col] = first_name
preds[last_name_col] = last_name

df = pl.DataFrame(preds).select(
first_name_col,
last_name_col,
cs.all().exclude(first_name_col, last_name_col),
df = (
pl.DataFrame(preds)
.select(
first_name_col,
last_name_col,
cs.all().exclude(first_name_col, last_name_col),
)
.unique([first_name_col, last_name_col]) # TODO: Push this unique up
)

return df
Expand Down

0 comments on commit 8774f95

Please sign in to comment.