diff --git a/src/pyethnicity/_bayesian_models.py b/src/pyethnicity/_bayesian_models.py index b329ae2..98ea9de 100644 --- a/src/pyethnicity/_bayesian_models.py +++ b/src/pyethnicity/_bayesian_models.py @@ -59,6 +59,18 @@ def load(self, resource: Resource) -> pl.DataFrame: RESOURCE_LOADER = ResourceLoader() +def last_name_join(left: pl.DataFrame, right: pl.DataFrame) -> pl.DataFrame: + df = left.join( + right.with_columns(is_matched=pl.lit(True)), + left_on="last_name", + right_on="name", + how="left", + ) + not_matched = df.filter(~pl.col("is_matched")).filter( + pl.col("last_name").str.contains() + ) + + def _remove_chars(expr: pl.Expr) -> pl.Expr: for char in UNWANTED_CHARS: expr = expr.str.replace_all(char, "", literal=True) @@ -74,13 +86,15 @@ def _normalize_name(name: Name, col_name: str) -> pl.DataFrame: pl.LazyFrame({col_name: name}) .with_columns( pl.col(col_name) - .pipe(_remove_chars) .str.to_uppercase() .str.replace_all(r"\s?J\.*?R\.*\s*?$", "") .str.replace_all(r"\s?S\.*?R\.*\s*?$", "") .str.replace_all(r"\s?III\s*?$", "") .str.replace_all(r"\s?IV\s*?$", "") - .apply(_remove_single_chars) + ) + .with_columns( + pl.col(col_name).alias(f"{col_name}_raw"), + pl.col(col_name).pipe(_remove_chars).map_elements(_remove_single_chars), ) .collect() ) @@ -172,24 +186,63 @@ def _bisg_internal( _assert_equal_lengths(last_name, geography) - last_name_cleaned = _normalize_name(last_name, "last_name") + raw = pl.DataFrame({"last_name": last_name, geo_type: geography}).unique() - prob_race_given_last_name = last_name_cleaned.join( - RESOURCE_LOADER.load(prob_race_given_last_name_path), + last_name_cleaned = _normalize_name(raw["last_name"], "last_name").with_row_count( + "index" + ) + + prob_race_given_last_name_full = RESOURCE_LOADER.load( + prob_race_given_last_name_path + ) + + # first, join on the last name + matched_1 = last_name_cleaned.join( + prob_race_given_last_name_full, left_on="last_name", right_on="name", - how="left", + how="inner", + ) + + # when there is a compound name, match on the first one + not_matched = ( + last_name_cleaned.filter(~pl.col("index").is_in(matched_1["index"])) + .filter(pl.col("last_name_raw").str.contains("-")) + .with_columns(last_name=pl.col("last_name_raw").str.split("-").list.get(0)) + ) + matched_2 = not_matched.join( + prob_race_given_last_name_full, + left_on="last_name", + right_on="name", + how="inner", + ) + + # finally, match on the second part of the compound name + not_matched = ( + last_name_cleaned.filter(~pl.col("index").is_in(matched_1["index"])) + .filter(~pl.col("index").is_in(matched_2["index"])) + .with_columns(last_name=pl.col("last_name_raw").str.split("-").list.get(1)) + ) + matched_3 = not_matched.join( + prob_race_given_last_name_full, + left_on="last_name", + right_on="name", + how="inner", + ) + + prob_race_given_last_name = pl.concat( + [matched_1, matched_2, matched_3], how="vertical" ).select(races) - prob_geo_given_race = _resolve_geography(geography, geo_type).select(races) + prob_geo_given_race = _resolve_geography(raw[geo_type], geo_type).select(races) bisg_numer = prob_race_given_last_name * prob_geo_given_race bisg_denom = bisg_numer.sum(axis=1) bisg_probs = bisg_numer / bisg_denom df = bisg_probs.to_pandas() - df.insert(0, "last_name", last_name) - df.insert(1, geo_type, geography) + df.insert(0, "last_name", raw["last_name"]) + df.insert(1, geo_type, raw[geo_type]) return df