diff --git a/activitysim/abm/models/disaggregate_accessibility.py b/activitysim/abm/models/disaggregate_accessibility.py index fe79d3fcdf..2bb78bf60e 100644 --- a/activitysim/abm/models/disaggregate_accessibility.py +++ b/activitysim/abm/models/disaggregate_accessibility.py @@ -569,14 +569,13 @@ def merge_persons(self): inject.add_table("proto_persons_merged", persons_merged) -def get_disaggregate_logsums(network_los, chunk_size, trace_hh_id): +def get_disaggregate_logsums( + network_los, chunk_size, trace_hh_id, disagg_model_settings +): logsums = {} persons_merged = pipeline.get_table("proto_persons_merged").sort_index( inplace=False ) - disagg_model_settings = read_disaggregate_accessibility_yaml( - "disaggregate_accessibility.yaml" - ) for model_name in [ "workplace_location", @@ -696,8 +695,14 @@ def compute_disaggregate_accessibility(network_los, chunk_size, trace_hh_id): tracing.register_traceable_table(tablename, df) del df + disagg_model_settings = read_disaggregate_accessibility_yaml( + "disaggregate_accessibility.yaml" + ) + # Run location choice - logsums = get_disaggregate_logsums(network_los, chunk_size, trace_hh_id) + logsums = get_disaggregate_logsums( + network_los, chunk_size, trace_hh_id, disagg_model_settings + ) logsums = {k + "_accessibility": v for k, v in logsums.items()} # Combined accessibility table @@ -736,20 +741,20 @@ def compute_disaggregate_accessibility(network_los, chunk_size, trace_hh_id): logsums["proto_disaggregate_accessibility"] = access_df # Drop any tables prematurely created - for tablename in [ - "school_destination_size", - "workplace_destination_size", - ]: - pipeline.drop_table(tablename) + # FIXME: dropping size tables breaks restart functionality for location choice models. + # hopefully this pipeline mess just goes away with move away from orca.... + # for tablename in [ + # "school_destination_size", + # "workplace_destination_size", + # ]: + # pipeline.drop_table(tablename) for ch in list(pipeline.get_rn_generator().channels.keys()): pipeline.get_rn_generator().drop_channel(ch) - # Drop any prematurely added traceables - for trace in [ - x for x in inject.get_injectable("traceable_tables") if "proto_" not in x - ]: - tracing.deregister_traceable_table(trace) + # Dropping all traceable tables + for table in inject.get_injectable("traceable_tables"): + tracing.deregister_traceable_table(table) # need to clear any premature tables that were added during the previous run orca._TABLES.clear() @@ -760,4 +765,22 @@ def compute_disaggregate_accessibility(network_los, chunk_size, trace_hh_id): # Inject accessibility results into pipeline [inject.add_table(k, df) for k, df in logsums.items()] + # available post-processing + for annotations in disagg_model_settings.get("postprocess_proto_tables", []): + tablename = annotations["tablename"] + df = pipeline.get_table(tablename) + assert df is not None + assert annotations is not None + assign_columns( + df=df, + model_settings={ + **annotations["annotate"], + **disagg_model_settings["suffixes"], + }, + trace_label=tracing.extend_trace_label( + "disaggregate_accessibility.postprocess", tablename + ), + ) + pipeline.replace_table(tablename, df) + return diff --git a/activitysim/abm/tables/disaggregate_accessibility.py b/activitysim/abm/tables/disaggregate_accessibility.py index 4c4eb9ad40..db65652f5f 100644 --- a/activitysim/abm/tables/disaggregate_accessibility.py +++ b/activitysim/abm/tables/disaggregate_accessibility.py @@ -151,14 +151,13 @@ def disaggregate_accessibility(persons, households, land_use, accessibility): accessibility_cols = [ x for x in proto_accessibility_df.columns if "accessibility" in x ] + keep_cols = model_settings.get("KEEP_COLS", accessibility_cols) # Parse the merging parameters assert merging_params is not None # Check if already assigned! - if set(accessibility_cols).intersection(persons_merged_df.columns) == set( - accessibility_cols - ): + if set(keep_cols).intersection(persons_merged_df.columns) == set(keep_cols): return # Find the nearest zone (spatially) with accessibilities calculated @@ -190,7 +189,7 @@ def disaggregate_accessibility(persons, households, land_use, accessibility): # because it will get slightly different logsums for households in the same zone. # This is because different destination zones were selected. To resolve, get mean by cols. right_df = ( - proto_accessibility_df.groupby(merge_cols)[accessibility_cols] + proto_accessibility_df.groupby(merge_cols)[keep_cols] .mean() .sort_values(nearest_cols) .reset_index() @@ -223,9 +222,9 @@ def disaggregate_accessibility(persons, households, land_use, accessibility): ) # Predict the nearest person ID and pull the logsums - matched_logsums_df = right_df.loc[clf.predict(x_pop)][ - accessibility_cols - ].reset_index(drop=True) + matched_logsums_df = right_df.loc[clf.predict(x_pop)][keep_cols].reset_index( + drop=True + ) merge_df = pd.concat( [left_df.reset_index(drop=False), matched_logsums_df], axis=1 ).set_index("person_id") @@ -257,12 +256,12 @@ def disaggregate_accessibility(persons, households, land_use, accessibility): # Check that it was correctly left-joined assert all(persons_merged_df[merge_cols] == merge_df[merge_cols]) - assert any(merge_df[accessibility_cols].isnull()) + assert any(merge_df[keep_cols].isnull()) # Inject merged accessibilities so that it can be included in persons_merged function - inject.add_table("disaggregate_accessibility", merge_df[accessibility_cols]) + inject.add_table("disaggregate_accessibility", merge_df[keep_cols]) - return merge_df[accessibility_cols] + return merge_df[keep_cols] inject.broadcast(