From b55a96d068676d5c71b5eed0afdbaa6ed3db68c1 Mon Sep 17 00:00:00 2001 From: Duc Minh La Date: Mon, 14 Oct 2024 16:47:29 +1100 Subject: [PATCH] refactor for leaner SAA --- PopSynthesis/Methods/IPSF/SAA/main.py | 22 +++++++++---------- .../Methods/IPSF/SAA/operations/general.py | 15 ++++++++----- PopSynthesis/Methods/IPSF/SAA/run_hh.py | 3 +++ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/PopSynthesis/Methods/IPSF/SAA/main.py b/PopSynthesis/Methods/IPSF/SAA/main.py index 7f99e60..994b41f 100644 --- a/PopSynthesis/Methods/IPSF/SAA/main.py +++ b/PopSynthesis/Methods/IPSF/SAA/main.py @@ -7,9 +7,9 @@ import pandas as pd -from PopSynthesis.Methods.IPSF.const import POOL_SIZE, output_dir +from PopSynthesis.Methods.IPSF.const import output_dir from PopSynthesis.Methods.IPSF.SAA.operations.general import ( - process_raw_ipu_init, + process_raw_ipu_marg, adjust_atts_state_match_census, ) from typing import List, Dict @@ -19,33 +19,31 @@ class SAA: def __init__( self, marginal_raw: pd.DataFrame, - seed_raw: pd.DataFrame, + considered_atts: List[str], ordered_to_adjust_atts: List[str], att_states: Dict[str, List[str]], pool: pd.DataFrame, ) -> None: - self.ordered_atts = ordered_to_adjust_atts + self.ordered_atts_to_adjust = ordered_to_adjust_atts + self.considered_atts = considered_atts self.known_att_states = att_states self.pool = pool - self.init_required_inputs(marginal_raw, seed_raw) + self.init_required_inputs(marginal_raw) - def init_required_inputs(self, marginal_raw: pd.DataFrame, seed_raw: pd.DataFrame): - converted_segment_marg, converted_seed = process_raw_ipu_init( - marginal_raw, seed_raw - ) - self.seed = converted_seed + def init_required_inputs(self, marginal_raw: pd.DataFrame): + converted_segment_marg = process_raw_ipu_marg(marginal_raw, atts=self.considered_atts) self.segmented_marg = converted_segment_marg def run(self, output_each_step:bool =True) -> pd.DataFrame: # Output the synthetic population, the main point curr_syn_pop = None adjusted_atts = [] - for att in self.ordered_atts: + for att in self.ordered_atts_to_adjust: sub_census = self.segmented_marg[att].reset_index() curr_syn_pop = adjust_atts_state_match_census( att, curr_syn_pop, sub_census, adjusted_atts, self.pool ) adjusted_atts.append(att) if output_each_step: - curr_syn_pop.to_csv(output_dir / f"syn_pop_adjusted_{att}.csv") + curr_syn_pop.to_csv(output_dir / f"syn_pop_adjusted_{att}_2.csv") return curr_syn_pop diff --git a/PopSynthesis/Methods/IPSF/SAA/operations/general.py b/PopSynthesis/Methods/IPSF/SAA/operations/general.py index 74f0810..6073be8 100644 --- a/PopSynthesis/Methods/IPSF/SAA/operations/general.py +++ b/PopSynthesis/Methods/IPSF/SAA/operations/general.py @@ -10,13 +10,8 @@ calculate_states_diff, ) from PopSynthesis.Methods.IPSF.SAA.operations.zone_adjustment import zone_adjustment -import multiprocessing as mp - -def process_raw_ipu_init( - marg: pd.DataFrame, seed: pd.DataFrame -) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]: - atts = [x for x in seed.columns if x not in ["serialno", "sample_geog"]] +def process_raw_ipu_marg(marg: pd.DataFrame, atts: List[str]) -> pd.DataFrame: segmented_marg = {} zones = marg[marg.columns[marg.columns.get_level_values(0) == zone_field]].values zones = [z[0] for z in zones] @@ -29,6 +24,14 @@ def process_raw_ipu_init( sub_marg.loc[:, [zone_field]] = zones sub_marg = sub_marg.set_index(zone_field) segmented_marg[att] = sub_marg + return segmented_marg + + +def process_raw_ipu_init( + marg: pd.DataFrame, seed: pd.DataFrame +) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]: + atts = [x for x in seed.columns if x not in ["serialno", "sample_geog"]] + segmented_marg = process_raw_ipu_marg(marg, atts) new_seed = seed.drop(columns=["sample_geog", "serialno"], errors="ignore") return segmented_marg, new_seed diff --git a/PopSynthesis/Methods/IPSF/SAA/run_hh.py b/PopSynthesis/Methods/IPSF/SAA/run_hh.py index f15fabe..18335b8 100644 --- a/PopSynthesis/Methods/IPSF/SAA/run_hh.py +++ b/PopSynthesis/Methods/IPSF/SAA/run_hh.py @@ -6,6 +6,7 @@ from PopSynthesis.Methods.IPSF.const import ( data_dir, processed_dir, + output_dir, POOL_SIZE, ) from PopSynthesis.Methods.IPSF.utils.pool_utils import create_pool @@ -27,6 +28,8 @@ def run_main() -> None: ] hh_seed = hh_seed[order_adjustment] pool = create_pool(seed=hh_seed, state_names=hh_att_state, pool_sz=POOL_SIZE) + + saa = SAA(hh_marg, hh_seed, order_adjustment, hh_att_state, pool) start_time = time.time()