Skip to content

Commit

Permalink
refactor for leaner SAA
Browse files Browse the repository at this point in the history
  • Loading branch information
bobkatla committed Oct 14, 2024
1 parent c933ebd commit b55a96d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
22 changes: 10 additions & 12 deletions PopSynthesis/Methods/IPSF/SAA/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import pandas as pd

from PopSynthesis.Methods.IPSF.const import POOL_SIZE, output_dir
from PopSynthesis.Methods.IPSF.const import output_dir
from PopSynthesis.Methods.IPSF.SAA.operations.general import (
process_raw_ipu_init,
process_raw_ipu_marg,
adjust_atts_state_match_census,
)
from typing import List, Dict
Expand All @@ -19,33 +19,31 @@ class SAA:
def __init__(
self,
marginal_raw: pd.DataFrame,
seed_raw: pd.DataFrame,
considered_atts: List[str],
ordered_to_adjust_atts: List[str],
att_states: Dict[str, List[str]],
pool: pd.DataFrame,
) -> None:
self.ordered_atts = ordered_to_adjust_atts
self.ordered_atts_to_adjust = ordered_to_adjust_atts
self.considered_atts = considered_atts
self.known_att_states = att_states
self.pool = pool
self.init_required_inputs(marginal_raw, seed_raw)
self.init_required_inputs(marginal_raw)

def init_required_inputs(self, marginal_raw: pd.DataFrame, seed_raw: pd.DataFrame):
converted_segment_marg, converted_seed = process_raw_ipu_init(
marginal_raw, seed_raw
)
self.seed = converted_seed
def init_required_inputs(self, marginal_raw: pd.DataFrame):
converted_segment_marg = process_raw_ipu_marg(marginal_raw, atts=self.considered_atts)
self.segmented_marg = converted_segment_marg

def run(self, output_each_step:bool =True) -> pd.DataFrame:
# Output the synthetic population, the main point
curr_syn_pop = None
adjusted_atts = []
for att in self.ordered_atts:
for att in self.ordered_atts_to_adjust:
sub_census = self.segmented_marg[att].reset_index()
curr_syn_pop = adjust_atts_state_match_census(
att, curr_syn_pop, sub_census, adjusted_atts, self.pool
)
adjusted_atts.append(att)
if output_each_step:
curr_syn_pop.to_csv(output_dir / f"syn_pop_adjusted_{att}.csv")
curr_syn_pop.to_csv(output_dir / f"syn_pop_adjusted_{att}_2.csv")
return curr_syn_pop
15 changes: 9 additions & 6 deletions PopSynthesis/Methods/IPSF/SAA/operations/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,8 @@
calculate_states_diff,
)
from PopSynthesis.Methods.IPSF.SAA.operations.zone_adjustment import zone_adjustment
import multiprocessing as mp


def process_raw_ipu_init(
marg: pd.DataFrame, seed: pd.DataFrame
) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]:
atts = [x for x in seed.columns if x not in ["serialno", "sample_geog"]]
def process_raw_ipu_marg(marg: pd.DataFrame, atts: List[str]) -> pd.DataFrame:
segmented_marg = {}
zones = marg[marg.columns[marg.columns.get_level_values(0) == zone_field]].values
zones = [z[0] for z in zones]
Expand All @@ -29,6 +24,14 @@ def process_raw_ipu_init(
sub_marg.loc[:, [zone_field]] = zones
sub_marg = sub_marg.set_index(zone_field)
segmented_marg[att] = sub_marg
return segmented_marg


def process_raw_ipu_init(
marg: pd.DataFrame, seed: pd.DataFrame
) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]:
atts = [x for x in seed.columns if x not in ["serialno", "sample_geog"]]
segmented_marg = process_raw_ipu_marg(marg, atts)
new_seed = seed.drop(columns=["sample_geog", "serialno"], errors="ignore")
return segmented_marg, new_seed

Expand Down
3 changes: 3 additions & 0 deletions PopSynthesis/Methods/IPSF/SAA/run_hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from PopSynthesis.Methods.IPSF.const import (
data_dir,
processed_dir,
output_dir,
POOL_SIZE,
)
from PopSynthesis.Methods.IPSF.utils.pool_utils import create_pool
Expand All @@ -27,6 +28,8 @@ def run_main() -> None:
]
hh_seed = hh_seed[order_adjustment]
pool = create_pool(seed=hh_seed, state_names=hh_att_state, pool_sz=POOL_SIZE)


saa = SAA(hh_marg, hh_seed, order_adjustment, hh_att_state, pool)

start_time = time.time()
Expand Down

0 comments on commit b55a96d

Please sign in to comment.