Skip to content

Commit 3c49690

Browse files
committed
black refactoring
1 parent 1b26274 commit 3c49690

23 files changed: +384 additions, -147 deletions

PopSynthesis/DataProcessor/DataProcessor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def process_all_seed(self) -> None:
5858
filtered_hh = hh_df[hh_df["hhid"].isin(hhid_in_pp)]
5959
hhid_in_hh = list(filtered_hh["hhid"].unique())
6060
filtered_pp = pp_df[pp_df["hhid"].isin(hhid_in_hh)]
61-
print(f"Removed {len(hh_df) - len(filtered_hh)} households due to mismatch with pp")
61+
print(
62+
f"Removed {len(hh_df) - len(filtered_hh)} households due to mismatch with pp"
63+
)
6264
print(f"Removed {len(pp_df) - len(filtered_pp)} people due to mismatch with hh")
6365

6466
# households size equal number of persons
@@ -81,7 +83,6 @@ def check_match_hhsz(r):
8183

8284
self.hh_seed_data = filtered_hh
8385
self.pp_seed_data = filtered_pp
84-
8586

8687
def process_households_seed(self) -> pd.DataFrame:
8788
# Import the hh seed data
@@ -91,7 +92,9 @@ def process_households_seed(self) -> pd.DataFrame:
9192
# Next we add weights, we combine weights of both wd and we
9293
hh_df = hh_df.with_columns(pl.col("wdhhwgt_sa3").fill_null(strategy="zero"))
9394
hh_df = hh_df.with_columns(pl.col("wehhwgt_sa3").fill_null(strategy="zero"))
94-
hh_df = hh_df.with_columns(_weight = pl.col("wdhhwgt_sa3") + pl.col("wehhwgt_sa3"))
95+
hh_df = hh_df.with_columns(
96+
_weight=pl.col("wdhhwgt_sa3") + pl.col("wehhwgt_sa3")
97+
)
9598
hh_df = hh_df.drop(["wdhhwgt_sa3", "wehhwgt_sa3"])
9699
hh_df = hh_df.drop_nulls()
97100

@@ -109,15 +112,19 @@ def process_persons_seed(self) -> pd.DataFrame:
109112
# Next we add weights, we combine weights of both wd and we
110113
pp_df = pp_df.with_columns(pl.col("wdperswgt_sa3").fill_null(strategy="zero"))
111114
pp_df = pp_df.with_columns(pl.col("weperswgt_sa3").fill_null(strategy="zero"))
112-
pp_df = pp_df.with_columns(_weight = pl.col("wdperswgt_sa3") + pl.col("weperswgt_sa3"))
115+
pp_df = pp_df.with_columns(
116+
_weight=pl.col("wdperswgt_sa3") + pl.col("weperswgt_sa3")
117+
)
113118
pp_df = pp_df.drop(["wdperswgt_sa3", "weperswgt_sa3"])
114119

115120
pp_df = process_not_accept_values(pp_df)
116121
pp_df = process_rela(pp_df)
117122
pp_df = convert_pp_age_gr(pp_df)
118123
return pp_df.to_pandas()
119-
120-
def output_seed(self, name_pp_seed:str = "pp_seed", name_hh_seed:str = "hh_seed") -> None:
124+
125+
def output_seed(
126+
self, name_pp_seed: str = "pp_seed", name_hh_seed: str = "hh_seed"
127+
) -> None:
121128
pp_loc = self.output_data_path / f"{name_pp_seed}.csv"
122129
hh_loc = self.output_data_path / f"{name_hh_seed}.csv"
123130
self.pp_seed_data.to_csv(pp_loc, index=False)
@@ -131,7 +138,7 @@ def process_households_census(self):
131138

132139
def process_persons_census(self):
133140
NotImplemented
134-
141+
135142
def output_all_files(self):
136143
NotImplemented
137144

PopSynthesis/DataProcessor/utils/const_process.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,16 @@
4343
"$8,000 or more ($416,000 or more)",
4444
]
4545

46-
HH_ATTS = ["hhid", "dwelltype", "owndwell", "hhinc", "totalvehs", "hhsize", "wdhhwgt_sa3", "wehhwgt_sa3"]
46+
HH_ATTS = [
47+
"hhid",
48+
"dwelltype",
49+
"owndwell",
50+
"hhinc",
51+
"totalvehs",
52+
"hhsize",
53+
"wdhhwgt_sa3",
54+
"wehhwgt_sa3",
55+
]
4756

4857
PP_ATTS = [
4958
"persid",
@@ -55,6 +64,6 @@
5564
"nolicence",
5665
"anywork",
5766
"wdperswgt_sa3",
58-
"weperswgt_sa3"
67+
"weperswgt_sa3",
5968
]
6069
NOT_INCLUDED_IN_BN_LEARN = ["hhid", "persid", "relationship"]

PopSynthesis/DataProcessor/utils/seed/hh/process_general_hh.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def convert_hh_inc(hh_df: pl.DataFrame, check_states: str) -> pl.DataFrame:
4747
return hh_df
4848

4949

50-
def convert_hh_dwell(hh_df: pl.DataFrame) -> pl.DataFrame: # Removing the occupied rent free
50+
def convert_hh_dwell(
51+
hh_df: pl.DataFrame,
52+
) -> pl.DataFrame: # Removing the occupied rent free
5153
col_owndwell = pl.col("owndwell")
5254
expr = (
5355
pl.when(col_owndwell == "Occupied Rent-Free")

PopSynthesis/DataProcessor/utils/seed/hh/process_hh_main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pandas as pd
55

6+
67
def process_hh_main_person(
78
hh_df, main_pp_df, to_csv=False, name_file="connect_hh_main", include_weights=True
89
):

PopSynthesis/DataProcessor/utils/seed/pp/process_relationships.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66
MIN_PARENT_CHILD_GAP = 15
77
MIN_GRANDPARENT_GRANDCHILD_GAP = 33
88
# This only apply when we do the conversion for Child and Grandchild
9-
MAX_COUPLE_GAP = 20
9+
MAX_COUPLE_GAP = 20
1010
MIN_PERMITTED_AGE_MARRIED = 16
1111
AVAILABLE_RELATIONSHIPS = [
12-
"Main",
13-
"Spouse",
14-
"Child",
15-
"Grandchild",
16-
"Sibling",
17-
"Others",
18-
"Parent",
19-
"Grandparent",
20-
]
12+
"Main",
13+
"Spouse",
14+
"Child",
15+
"Grandchild",
16+
"Sibling",
17+
"Others",
18+
"Parent",
19+
"Grandparent",
20+
]
2121

2222

2323
class Person:
@@ -385,6 +385,6 @@ def process_rela(pp_df: pl.DataFrame) -> pl.DataFrame:
385385
pp_df["relationship"] = pp_df["persid"].map(result_mapping)
386386

387387
# The households with implausible combinations will have None value for relationship
388-
pp_df =pp_df[~pp_df["relationship"].isna()]
388+
pp_df = pp_df[~pp_df["relationship"].isna()]
389389

390390
return pl.from_pandas(pp_df)

PopSynthesis/Methods/IPSF/SAA/main.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,42 @@
88
import pandas as pd
99

1010
from PopSynthesis.Methods.IPSF.const import POOL_SIZE
11-
from PopSynthesis.Methods.IPSF.SAA.operations.general import process_raw_ipu_init, adjust_atts_state_match_census
11+
from PopSynthesis.Methods.IPSF.SAA.operations.general import (
12+
process_raw_ipu_init,
13+
adjust_atts_state_match_census,
14+
)
1215
from typing import List, Dict
1316

17+
1418
class SAA:
15-
def __init__(self, marginal_raw: pd.DataFrame, seed_raw: pd.DataFrame, ordered_to_adjust_atts:List[str], att_states: Dict[str, List[str]], pool_sz: int = POOL_SIZE) -> None:
19+
def __init__(
20+
self,
21+
marginal_raw: pd.DataFrame,
22+
seed_raw: pd.DataFrame,
23+
ordered_to_adjust_atts: List[str],
24+
att_states: Dict[str, List[str]],
25+
pool_sz: int = POOL_SIZE,
26+
) -> None:
1627
self.ordered_atts = ordered_to_adjust_atts
1728
self.known_att_states = att_states
1829
self.init_required_inputs(marginal_raw, seed_raw)
1930

2031
def init_required_inputs(self, marginal_raw: pd.DataFrame, seed_raw: pd.DataFrame):
21-
converted_segment_marg, converted_seed = process_raw_ipu_init(marginal_raw, seed_raw)
32+
converted_segment_marg, converted_seed = process_raw_ipu_init(
33+
marginal_raw, seed_raw
34+
)
2235
self.seed = converted_seed
2336
self.segmented_marg = converted_segment_marg
2437

2538
def run(self) -> pd.DataFrame:
2639
# Output the synthetic population, the main point
2740
curr_syn_pop = None
2841
adjusted_atts = []
29-
pool = self.seed # change later
42+
pool = self.seed # change later
3043
for att in self.ordered_atts:
3144
sub_census = self.segmented_marg[att].reset_index()
32-
curr_syn_pop = adjust_atts_state_match_census(att, curr_syn_pop, sub_census, adjusted_atts, pool)
45+
curr_syn_pop = adjust_atts_state_match_census(
46+
att, curr_syn_pop, sub_census, adjusted_atts, pool
47+
)
3348
adjusted_atts.append(att)
3449
return curr_syn_pop

PopSynthesis/Methods/IPSF/SAA/operations/compare_census.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88
from PopSynthesis.Methods.IPSF.const import zone_field, count_field
99

1010

11-
def calculate_states_diff(att:str, syn_pop: pd.DataFrame, sub_census: pd.DataFrame) -> pd.DataFrame:
11+
def calculate_states_diff(
12+
att: str, syn_pop: pd.DataFrame, sub_census: pd.DataFrame
13+
) -> pd.DataFrame:
1214
""" This calculate the differences between current syn_pop and the census at a specific geo_lev """
1315
sub_syn_pop_count = syn_pop[[zone_field, att]].value_counts().reset_index()
14-
tranformed_sub_syn_count = sub_syn_pop_count.pivot(index=zone_field, columns=att, values=count_field).fillna(0)
16+
tranformed_sub_syn_count = sub_syn_pop_count.pivot(
17+
index=zone_field, columns=att, values=count_field
18+
).fillna(0)
1519
sub_census = sub_census.set_index(zone_field)
1620
# Always census is the ground truth, check for missing and fill
1721
missing_zones = set(sub_census.index) - set(tranformed_sub_syn_count.index)
@@ -26,10 +30,3 @@ def calculate_states_diff(att:str, syn_pop: pd.DataFrame, sub_census: pd.DataFra
2630
# no nan values
2731
assert not results.isna().any().any()
2832
return results
29-
30-
31-
32-
33-
34-
35-

PopSynthesis/Methods/IPSF/SAA/operations/general.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@
66

77
from typing import List, Union, Tuple, Dict
88
from PopSynthesis.Methods.IPSF.const import count_field, zone_field, data_dir
9-
from PopSynthesis.Methods.IPSF.SAA.operations.compare_census import calculate_states_diff
9+
from PopSynthesis.Methods.IPSF.SAA.operations.compare_census import (
10+
calculate_states_diff,
11+
)
1012
from PopSynthesis.Methods.IPSF.SAA.operations.zone_adjustment import zone_adjustment
11-
from PopSynthesis.Methods.IPSF.utils.condensed_tools import CondensedDF, sample_from_condensed
13+
from PopSynthesis.Methods.IPSF.utils.condensed_tools import (
14+
CondensedDF,
15+
sample_from_condensed,
16+
)
1217

1318

14-
def process_raw_ipu_init(marg: pd.DataFrame, seed: pd.DataFrame) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]:
15-
atts = [x for x in seed.columns if x not in ["serialno", "sample_geog"] ]
19+
def process_raw_ipu_init(
20+
marg: pd.DataFrame, seed: pd.DataFrame
21+
) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]:
22+
atts = [x for x in seed.columns if x not in ["serialno", "sample_geog"]]
1623
segmented_marg = {}
17-
zones = marg[marg.columns[marg.columns.get_level_values(0)==zone_field]].values
24+
zones = marg[marg.columns[marg.columns.get_level_values(0) == zone_field]].values
1825
zones = [z[0] for z in zones]
1926
for att in atts:
20-
sub_marg = marg[marg.columns[marg.columns.get_level_values(0)==att]]
27+
sub_marg = marg[marg.columns[marg.columns.get_level_values(0) == att]]
2128
if sub_marg.empty:
2229
print(f"Don't have this att {att} in census")
2330
continue
@@ -29,15 +36,21 @@ def process_raw_ipu_init(marg: pd.DataFrame, seed: pd.DataFrame) -> Tuple[Dict[s
2936
return segmented_marg, new_seed
3037

3138

32-
def sample_from_pl(df: pl.DataFrame, n: int, count_field:str = count_field, with_replacement=True) -> pl.DataFrame:
39+
def sample_from_pl(
40+
df: pl.DataFrame, n: int, count_field: str = count_field, with_replacement=True
41+
) -> pl.DataFrame:
3342
# Normalize weights to sum to 1
3443
weights = df[count_field].to_numpy()
35-
weights = weights/weights.sum()
36-
sample_indices = np.random.choice(df.height, size=n, replace=with_replacement, p=weights)
44+
weights = weights / weights.sum()
45+
sample_indices = np.random.choice(
46+
df.height, size=n, replace=with_replacement, p=weights
47+
)
3748
return df[sample_indices.tolist()]
3849

3950

40-
def init_syn_pop_saa(att:str, marginal_data: pd.DataFrame, pool: pd.DataFrame) -> pl.DataFrame:
51+
def init_syn_pop_saa(
52+
att: str, marginal_data: pd.DataFrame, pool: pd.DataFrame
53+
) -> pl.DataFrame:
4154
pool = pl.from_pandas(pool)
4255
marginal_data = pl.from_pandas(marginal_data)
4356
assert zone_field in marginal_data
@@ -51,13 +64,11 @@ def init_syn_pop_saa(att:str, marginal_data: pd.DataFrame, pool: pd.DataFrame) -
5164
for state in states:
5265
sub_pool = pool.filter(pl.col(att) == state)
5366
if len(sub_pool) == 0:
54-
print(
55-
f"WARNING: cannot see {att}_{state} in the pool, sample by the rest"
56-
)
67+
print(f"WARNING: cannot see {att}_{state} in the pool, sample by the rest")
5768
sub_pool = pool # if there are none, we take all
5869
for zone in marginal_data[zone_field]:
5970
condition = marginal_data.filter(pl.col(zone_field) == zone)
60-
census_val = condition.select(state).to_numpy()[0,0]
71+
census_val = condition.select(state).to_numpy()[0, 0]
6172

6273
sub_syn_pop = sample_from_pl(sub_pool, census_val)
6374

@@ -68,23 +79,33 @@ def init_syn_pop_saa(att:str, marginal_data: pd.DataFrame, pool: pd.DataFrame) -
6879
return pl.concat(sub_pops)
6980

7081

71-
def adjust_atts_state_match_census(att: str, curr_syn_pop: Union[None, pd.DataFrame], census_data_by_att: pd.DataFrame, adjusted_atts: List[str], pool: pd.DataFrame) -> pd.DataFrame:
82+
def adjust_atts_state_match_census(
83+
att: str,
84+
curr_syn_pop: Union[None, pd.DataFrame],
85+
census_data_by_att: pd.DataFrame,
86+
adjusted_atts: List[str],
87+
pool: pd.DataFrame,
88+
) -> pd.DataFrame:
7289
if curr_syn_pop is None:
7390
updated_syn_pop = init_syn_pop_saa(att, census_data_by_att, pool).to_pandas()
7491
else:
7592
updated_syn_pop = curr_syn_pop
7693

77-
states_diff_census = calculate_states_diff(att, curr_syn_pop, census_data_by_att)
94+
states_diff_census = calculate_states_diff(
95+
att, curr_syn_pop, census_data_by_att
96+
)
7897
assert (states_diff_census.sum(axis=1) == 0).all()
7998
# With state diff we can now do adjustment for each zone, can parallel it?
8099
pop_syn_across_zones = []
81100
for zid, zone_states_diff in states_diff_census.iterrows():
82101
print(f"Processing zone {zid}")
83-
sub_syn_pop = updated_syn_pop[updated_syn_pop[zone_field]==zid]
84-
zone_adjusted_syn_pop = zone_adjustment(att, sub_syn_pop, zone_states_diff, pool, adjusted_atts)
102+
sub_syn_pop = updated_syn_pop[updated_syn_pop[zone_field] == zid]
103+
zone_adjusted_syn_pop = zone_adjustment(
104+
att, sub_syn_pop, zone_states_diff, pool, adjusted_atts
105+
)
85106
if zone_adjusted_syn_pop is not None:
86107
pop_syn_across_zones.append(zone_adjusted_syn_pop)
87-
108+
88109
updated_syn_pop = pd.concat(pop_syn_across_zones)
89110

90-
return updated_syn_pop
111+
return updated_syn_pop

0 commit comments

Comments (0)