
black fix and also clean for better logging
bobkatla committed Oct 15, 2024
1 parent a3a8613 commit ed93759
Showing 9 changed files with 106 additions and 47 deletions.
3 changes: 2 additions & 1 deletion PopSynthesis/Methods/IPSF/CSP/CSP.py
@@ -10,9 +10,10 @@
 We will add the loop here as well (the longest is the first SAA, which runs separately)
 """
 
+
 class CSP:
     def __init__(self) -> None:
         NotImplemented
 
     def run():
-        NotImplemented
\ No newline at end of file
+        NotImplemented
58 changes: 45 additions & 13 deletions PopSynthesis/Methods/IPSF/CSP/operations/convert_seeds.py
@@ -8,52 +8,78 @@
 import pandas as pd
 from typing import Dict, List
 
-def convert_seeds_to_pairs(hh_seed: pd.DataFrame, pp_seed: pd.DataFrame, id_col:str, pp_segment_col: str, main_state: str) -> Dict[str, pd.DataFrame]:
+
+def convert_seeds_to_pairs(
+    hh_seed: pd.DataFrame,
+    pp_seed: pd.DataFrame,
+    id_col: str,
+    pp_segment_col: str,
+    main_state: str,
+) -> Dict[str, pd.DataFrame]:
     pp_states = pp_seed[pp_segment_col].unique()
     assert main_state in pp_states
     assert id_col in hh_seed.columns
     assert id_col in pp_seed.columns
 
     hh_seed[id_col] = hh_seed[id_col].astype(str)
     pp_seed[id_col] = pp_seed[id_col].astype(str)
-    hh_name = "HH" # simply for naming convention
+    hh_name = "HH"  # simply for naming convention
     hh_seed = add_pp_seg_count(hh_seed, pp_seed, pp_segment_col, id_col)
 
     segmented_pp = segment_pp_seed(pp_seed, pp_segment_col)
     assert len(segmented_pp[main_state]) == len(hh_seed)
 
     # pair up HH - Main first
-    result_pairs = {f"{hh_name}-{main_state}": pair_by_id(hh_seed, segmented_pp[main_state], id_col, hh_name, main_state)}
+    result_pairs = {
+        f"{hh_name}-{main_state}": pair_by_id(
+            hh_seed, segmented_pp[main_state], id_col, hh_name, main_state
+        )
+    }
     assert len(result_pairs[f"{hh_name}-{main_state}"]) == len(hh_seed)
     for pp_state in pp_states:
         if pp_state != main_state:
-            result_pairs[f"{main_state}-{pp_state}"] = pair_by_id(segmented_pp[main_state], segmented_pp[pp_state], id_col, main_state, pp_state)
+            result_pairs[f"{main_state}-{pp_state}"] = pair_by_id(
+                segmented_pp[main_state],
+                segmented_pp[pp_state],
+                id_col,
+                main_state,
+                pp_state,
+            )
     return result_pairs
 
 
-def pair_by_id(df1: pd.DataFrame, df2: pd.DataFrame, id: str, name1:str="x", name2:str="y") -> pd.DataFrame:
+def pair_by_id(
+    df1: pd.DataFrame, df2: pd.DataFrame, id: str, name1: str = "x", name2: str = "y"
+) -> pd.DataFrame:
     # join by the id col with inner (so only matched one got accepted)
     # Likely id will not be unique for df2 as the normal rela may have multiple in 1 hh
-    join_result = df1.merge(df2, on=id, how="inner", suffixes=[f"_{name1}", f"_{name2}"])
+    join_result = df1.merge(
+        df2, on=id, how="inner", suffixes=[f"_{name1}", f"_{name2}"]
+    )
     assert len(join_result) == min(len(df1), len(df2))
     return join_result
 
 
 def segment_pp_seed(pp_seed: pd.DataFrame, segment_col: str) -> Dict[str, pd.DataFrame]:
     result_seg_pp = {}
     for state in pp_seed[segment_col].unique():
-        result_seg_pp[state] = pp_seed[pp_seed[segment_col]==state]
+        result_seg_pp[state] = pp_seed[pp_seed[segment_col] == state]
     return result_seg_pp
 
 
-def pair_states_dict(states1: Dict[str, List[str]], states2: Dict[str, List[str]], name1:str="x", name2:str="y") -> Dict[str, List[str]]:
+def pair_states_dict(
+    states1: Dict[str, List[str]],
+    states2: Dict[str, List[str]],
+    name1: str = "x",
+    name2: str = "y",
+) -> Dict[str, List[str]]:
     # This is to create the states list for pool creation using BN (also as ref if needed)
     states_in_1 = set(states1.keys())
     states_in_2 = set(states2.keys())
     states_unique_1 = states_in_1 - states_in_2
     states_unique_2 = states_in_2 - states_in_1
     states_common = states_in_1 & states_in_2
 
     results = {s: states1[s] for s in states_unique_1}
     for s in states_unique_2:
         results[s] = states2[s]
@@ -63,12 +89,18 @@ def pair_states_dict(states1: Dict[str, List[str]], states2: Dict[str, List[str]
     return results
 
 
-def add_pp_seg_count(hh_seed: pd.DataFrame, pp_seed: pd.DataFrame, segment_col: str, id_col: str) -> pd.DataFrame:
+def add_pp_seg_count(
+    hh_seed: pd.DataFrame, pp_seed: pd.DataFrame, segment_col: str, id_col: str
+) -> pd.DataFrame:
    """Add the count for each segment (e.g. relationship) into hh_seed for CSP"""
     possible_seg_states = list(pp_seed[segment_col].unique())
     filtered_pp_seed = pp_seed.groupby(id_col)[segment_col].apply(lambda x: list(x))
 
     def process_seg_count(r):
         seg_count = filtered_pp_seed[r[id_col]]
         return [seg_count.count(x) for x in possible_seg_states]
-    hh_seed[possible_seg_states] = hh_seed.apply(process_seg_count, axis=1, result_type="expand")
+
+    hh_seed[possible_seg_states] = hh_seed.apply(
+        process_seg_count, axis=1, result_type="expand"
+    )
     return hh_seed
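
As context for the pairing above, here is a minimal self-contained sketch of the inner join that pair_by_id wraps, on made-up toy seeds (the hhid/hhsize/age names are illustrative, not the repo's schema). For the HH-Main pair, the assert in convert_seeds_to_pairs guarantees exactly one main-state person per household, so this join is one-to-one:

import pandas as pd

# Toy household seed: one row per household (ids as strings, as the code enforces).
hh_seed = pd.DataFrame({"hhid": ["1", "2"], "hhsize": [2, 3]})
# Toy person seed for the main segment: exactly one row per household.
pp_main = pd.DataFrame({"hhid": ["1", "2"], "age": [40, 55]})

# Inner join on the id column, mirroring pair_by_id; only matched ids survive,
# and the suffixes would disambiguate any column name the two frames share.
pair = hh_seed.merge(pp_main, on="hhid", how="inner", suffixes=["_HH", "_Main"])
print(pair)
#   hhid  hhsize  age
# 0    1       2   40
# 1    2       3   55

For the Main-to-other-relationship pairs the id can repeat on the right-hand side (as the comment in pair_by_id notes), and each household then fans out to one row per matched person.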
2 changes: 1 addition & 1 deletion PopSynthesis/Methods/IPSF/CSP/operations/rela_syn.py
@@ -6,4 +6,4 @@
 Also a wrapper to take in all pools and the original (only HH)
 Then we will combine all again and output the kept and removed
 This can go through SAA again (with the help of the loop check census)
-"""
\ No newline at end of file
+"""
9 changes: 6 additions & 3 deletions PopSynthesis/Methods/IPSF/CSP/run.py
@@ -4,7 +4,10 @@
 import pandas as pd
 import pickle
 from PopSynthesis.Methods.IPSF.const import data_dir, POOL_SIZE, processed_dir
-from PopSynthesis.Methods.IPSF.CSP.operations.convert_seeds import convert_seeds_to_pairs, pair_states_dict
+from PopSynthesis.Methods.IPSF.CSP.operations.convert_seeds import (
+    convert_seeds_to_pairs,
+    pair_states_dict,
+)
 from PopSynthesis.Methods.IPSF.utils.pool_utils import create_pool
 
 
@@ -14,7 +17,7 @@ def main():
     with open(processed_dir / "dict_pool_pairs.pickle", "rb") as handle:
         pools_ref = pickle.load(handle)
     print(pools_ref)
-
+
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
10 changes: 7 additions & 3 deletions PopSynthesis/Methods/IPSF/SAA/SAA.py
@@ -29,10 +29,12 @@ def __init__(
         self.init_required_inputs(marginal_raw)
 
     def init_required_inputs(self, marginal_raw: pd.DataFrame):
-        converted_segment_marg = process_raw_ipu_marg(marginal_raw, atts=self.considered_atts)
+        converted_segment_marg = process_raw_ipu_marg(
+            marginal_raw, atts=self.considered_atts
+        )
         self.segmented_marg = converted_segment_marg
 
-    def run(self, output_each_step:bool =True, extra_name:str="") -> pd.DataFrame:
+    def run(self, output_each_step: bool = True, extra_name: str = "") -> pd.DataFrame:
         # Output the synthetic population, the main point
         curr_syn_pop = None
         adjusted_atts = []
@@ -43,5 +45,7 @@ def run(self, output_each_step:bool =True, extra_name:str="") -> pd.DataFrame:
             )
             adjusted_atts.append(att)
             if output_each_step:
-                curr_syn_pop.to_csv(output_dir / f"syn_pop_adjusted_{att}{extra_name}.csv")
+                curr_syn_pop.to_csv(
+                    output_dir / f"syn_pop_adjusted_{att}{extra_name}.csv"
+                )
         return curr_syn_pop
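
For orientation, this is roughly how run_hh.py (later in this commit) drives SAA; a condensed sketch, assuming the package and its data files are available locally, and omitting the marginal clean-up that run_hh.py performs first:

import pandas as pd
from PopSynthesis.Methods.IPSF.const import data_dir, processed_dir
from PopSynthesis.Methods.IPSF.SAA.SAA import SAA

hh_marg = pd.read_csv(data_dir / "hh_marginals_ipu.csv", header=[0, 1])
pool = pd.read_csv(processed_dir / "HH_pool.csv")
atts = ["hhsize", "hhinc", "totalvehs", "dwelltype", "owndwell"]

# marginals, considered atts, adjustment order, then the pool
saa = SAA(hh_marg, atts, atts, pool)
# with output_each_step=True, each adjusted attribute is also written to
# output_dir / f"syn_pop_adjusted_{att}{extra_name}.csv"
final_syn_pop = saa.run(output_each_step=True, extra_name="_demo")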
1 change: 1 addition & 0 deletions PopSynthesis/Methods/IPSF/SAA/operations/general.py
@@ -11,6 +11,7 @@
 )
 from PopSynthesis.Methods.IPSF.SAA.operations.zone_adjustment import zone_adjustment
 
+
 def process_raw_ipu_marg(marg: pd.DataFrame, atts: List[str]) -> pd.DataFrame:
     segmented_marg = {}
     zones = marg[marg.columns[marg.columns.get_level_values(0) == zone_field]].values
8 changes: 6 additions & 2 deletions PopSynthesis/Methods/IPSF/SAA/operations/zone_adjustment.py
@@ -124,7 +124,9 @@ def zone_adjustment(
 
     neg_states = diff_census[diff_census < 0].index.tolist()
     pos_states = diff_census[diff_census > 0].index.tolist()
-    zeros_states = diff_census[diff_census == 0].index.tolist() # Only for later processing
+    zeros_states = diff_census[
+        diff_census == 0
+    ].index.tolist()  # Only for later processing
     pairs_adjust = list(itertools.product(neg_states, pos_states))
     random.shuffle(pairs_adjust)
 
@@ -215,6 +217,8 @@ def zone_adjustment(
     )
     final_resulted_syn[zone_field] = zone
     if len(final_resulted_syn) != ori_num_syn:
-        raise ValueError(f"Error processing at zone {zone}: expected {ori_num_syn} records, got {len(final_resulted_syn)}")
+        raise ValueError(
+            f"Error processing at zone {zone}: expected {ori_num_syn} records, got {len(final_resulted_syn)}"
+        )
     print(f"Finished zone {zone}")
     return final_resulted_syn
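
The split-and-pair step above is easy to see on a toy diff_census; a self-contained sketch (state labels and numbers made up):

import itertools
import random

import pandas as pd

diff_census = pd.Series({"1": -2, "2": 3, "3": 0, "4+": -1})

neg_states = diff_census[diff_census < 0].index.tolist()  # ["1", "4+"]
pos_states = diff_census[diff_census > 0].index.tolist()  # ["2"]
zeros_states = diff_census[diff_census == 0].index.tolist()  # ["3"], kept for later processing

# every (negative, positive) combination is a candidate adjustment, tried in random order
pairs_adjust = list(itertools.product(neg_states, pos_states))
random.shuffle(pairs_adjust)
print(pairs_adjust)  # e.g. [("4+", "2"), ("1", "2")]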
41 changes: 23 additions & 18 deletions PopSynthesis/Methods/IPSF/SAA/run_hh.py
@@ -2,68 +2,74 @@
 
 
 import pandas as pd
-import numpy as np
 from PopSynthesis.Methods.IPSF.const import (
     data_dir,
     output_dir,
     processed_dir,
     zone_field,
 )
-from PopSynthesis.Methods.IPSF.utils.synthetic_checked_census import adjust_kept_rec_match_census, get_diff_marg, convert_full_to_marg_count
+from PopSynthesis.Methods.IPSF.utils.synthetic_checked_census import (
+    adjust_kept_rec_match_census,
+    get_diff_marg,
+    convert_full_to_marg_count,
+)
 from PopSynthesis.Methods.IPSF.SAA.SAA import SAA
 import random
 import time
 
 
 def run_main() -> None:
     hh_marg = pd.read_csv(data_dir / "hh_marginals_ipu.csv", header=[0, 1])
-    hh_marg = hh_marg.drop(columns=hh_marg.columns[hh_marg.columns.get_level_values(0)=="sample_geog"][0])
-
+    hh_marg = hh_marg.drop(
+        columns=hh_marg.columns[hh_marg.columns.get_level_values(0) == "sample_geog"][0]
+    )
 
     order_adjustment = [
         "hhsize",
         "hhinc",
         "totalvehs",
         "dwelltype",
         "owndwell",
-    ]
+    ]  # these must exist in both marg and syn
     considered_atts = [
         "hhsize",
         "hhinc",
         "totalvehs",
         "dwelltype",
         "owndwell",
-    ]
+    ]  # exist in final syn
 
-    hh_marg = hh_marg.head(2)
-    pool = pd.read_csv(processed_dir / "HH_pool_small_test.csv")
+    pool = pd.read_csv(processed_dir / "HH_pool.csv")
     start_time = time.time()
 
     n_run_time = 0
-    n_removed_err_hh = np.inf
+    # init with the total HH we want
+    n_removed_err_hh = hh_marg.sum().sum() / len(order_adjustment)
     MAX_RUN_TIME = 30
     chosen_hhs = []
+    err_rm_hh = []
     while n_run_time < MAX_RUN_TIME and n_removed_err_hh > 0:
         # randomly shuffle for each adjustment
         random.shuffle(order_adjustment)
-        print(f"For run {n_run_time}, order is: {order_adjustment}, aim for {n_removed_err_hh} HHs")
+        err_rm_hh.append(n_removed_err_hh)
+        print(
+            f"For run {n_run_time}, order is: {order_adjustment}, aim for {n_removed_err_hh} HHs"
+        )
         saa = SAA(hh_marg, considered_atts, order_adjustment, pool)
         ###
         final_syn_pop = saa.run(extra_name=f"_{n_run_time}")
         ###
         # error check
-        marg_from_kept_hh = convert_full_to_marg_count(
-            final_syn_pop, [zone_field]
-        )
-        converted_hh_marg = hh_marg.set_index(hh_marg.columns[hh_marg.columns.get_level_values(0)==zone_field][0])
+        marg_from_kept_hh = convert_full_to_marg_count(final_syn_pop, [zone_field])
+        converted_hh_marg = hh_marg.set_index(
+            hh_marg.columns[hh_marg.columns.get_level_values(0) == zone_field][0]
+        )
         diff_marg = get_diff_marg(converted_hh_marg, marg_from_kept_hh)
 
         kept_hh = adjust_kept_rec_match_census(final_syn_pop, diff_marg)
 
         # checking
-        kept_marg = convert_full_to_marg_count(
-            kept_hh, [zone_field]
-        )
+        kept_marg = convert_full_to_marg_count(kept_hh, [zone_field])
         new_diff_marg = get_diff_marg(converted_hh_marg, kept_marg)
         # check it is no neg indeed
         checking_not_neg = new_diff_marg < 0
@@ -75,7 +81,6 @@ def run_main() -> None:
 
         n_run_time += 1
         n_removed_err_hh = len(final_syn_pop) - len(kept_hh)
-        err_rm_hh.append(n_removed_err_hh)
         if n_run_time == MAX_RUN_TIME:
             # not adjusting anymore
             chosen_hhs.append(final_syn_pop)
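
The new starting value for n_removed_err_hh has a simple arithmetic reading: in an IPU-style marginal table every attribute's categories sum to the same per-zone household count, so summing every cell counts each household once per attribute, and dividing by the number of attributes recovers the household total. A toy check with made-up attributes and counts (omitting the zone id column for clarity):

import pandas as pd

cols = pd.MultiIndex.from_tuples(
    [("hhsize", "1"), ("hhsize", "2"), ("totalvehs", "0"), ("totalvehs", "1+")]
)
# Two zones of 10 households each; each attribute's categories sum to 10 per zone.
hh_marg = pd.DataFrame([[6, 4, 3, 7], [5, 5, 2, 8]], columns=cols)

order_adjustment = ["hhsize", "totalvehs"]
n_removed_err_hh = hh_marg.sum().sum() / len(order_adjustment)
print(n_removed_err_hh)  # 20.0, the household total across both zones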
21 changes: 15 additions & 6 deletions PopSynthesis/Methods/IPSF/utils/synthetic_checked_census.py
@@ -15,6 +15,7 @@
 from PopSynthesis.Methods.IPSF.const import count_field, zone_field
 from typing import List, Literal, Tuple, Dict
 
+
 def segment_df(df: pd.DataFrame, chunk_sz: int) -> List[pd.DataFrame]:
     start = 0
     ls_df = []
@@ -25,7 +26,9 @@ def segment_df(df: pd.DataFrame, chunk_sz: int) -> List[pd.DataFrame]:
     return ls_df
 
 
-def convert_count_to_full(count_df: pd.DataFrame, count_col: str = count_field) -> pd.DataFrame:
+def convert_count_to_full(
+    count_df: pd.DataFrame, count_col: str = count_field
+) -> pd.DataFrame:
     assert count_col in count_df.columns
     repeated_idx = list(count_df.index.repeat(count_df[count_col]))
     fin = count_df.loc[repeated_idx]
@@ -35,10 +38,10 @@ def convert_count_to_full(count_df: pd.DataFrame, count_col: str = count_field)
 
 
 def convert_full_to_marg_count(
-    full_pop: pd.DataFrame, filter_ls: list[str]=[]
+    full_pop: pd.DataFrame, filter_ls: list[str] = []
 ) -> pd.DataFrame:
     assert zone_field in full_pop.columns
-    cols = [x for x in full_pop.columns if x not in filter_ls+[zone_field]]
+    cols = [x for x in full_pop.columns if x not in filter_ls + [zone_field]]
     ls_temp_hold = []
     for att in cols:
         full_pop[att] = full_pop[att].astype(str)
@@ -59,7 +62,9 @@ def convert_full_to_marg_count(
     return new_marg_hh
 
 
-def add_0_to_missing(df: pd.DataFrame, ls_missing: List[str], axis: Literal[0, 1]) -> pd.DataFrame:
+def add_0_to_missing(
+    df: pd.DataFrame, ls_missing: List[str], axis: Literal[0, 1]
+) -> pd.DataFrame:
     for missing in ls_missing:
         if axis == 1:  # by row
             df.loc[missing] = 0
@@ -68,7 +73,9 @@ def add_0_to_missing(df: pd.DataFrame, ls_missing: List[str], axis: Literal[0, 1
             df[missing] = 0
     return df
 
-def get_diff_marg(converted_census_marg: pd.DataFrame, converted_new_hh_marg: pd.DataFrame) -> pd.DataFrame:
+def get_diff_marg(
+    converted_census_marg: pd.DataFrame, converted_new_hh_marg: pd.DataFrame
+) -> pd.DataFrame:
     print("getting the diff marg df")
     converted_census_marg.index = converted_census_marg.index.astype(str)
     converted_new_hh_marg.index = converted_new_hh_marg.index.astype(str)
@@ -104,7 +111,9 @@ def convert_to_dict_ls(tup: Tuple[Tuple[str, str]]) -> Dict[str, str]:
     return di
 
 
-def adjust_kept_rec_match_census(syn_records: pd.DataFrame, diff_census: pd.DataFrame) -> pd.DataFrame:
+def adjust_kept_rec_match_census(
+    syn_records: pd.DataFrame, diff_census: pd.DataFrame
+) -> pd.DataFrame:
     # The point is to remove the chosen in
     syn_records = syn_records.astype(str)
     count_kept = syn_records.value_counts()
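
The index.repeat trick in convert_count_to_full above expands a count table into one row per unit; a minimal sketch on a made-up count table, with "count" standing in for count_field:

import pandas as pd

count_df = pd.DataFrame({"hhsize": ["1", "2"], "count": [2, 3]})

# Repeat each row's index label by its count, then select those rows.
repeated_idx = list(count_df.index.repeat(count_df["count"]))
fin = count_df.loc[repeated_idx].reset_index(drop=True)
print(fin)
#   hhsize  count
# 0      1      2
# 1      1      2
# 2      2      3
# 3      2      3
# 4      2      3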
