Skip to content

Commit

Permalink
preprocessor and annotation in auto ownership
Browse files Browse the repository at this point in the history
  • Loading branch information
dhensle committed Dec 16, 2023
1 parent a8e755f commit 1c98505
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
22 changes: 20 additions & 2 deletions activitysim/abm/models/auto_ownership.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# See full license in LICENSE.txt.
import logging

from activitysim.core import config, inject, pipeline, simulate, tracing
from activitysim.core import config, expressions, inject, pipeline, simulate, tracing

from .util import estimation
from .util import estimation, annotate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,6 +32,21 @@ def auto_ownership_simulate(households, households_merged, chunk_size, trace_hh_

logger.info("Running %s with %d households", trace_label, len(choosers))

# - preprocessor
preprocessor_settings = model_settings.get("preprocessor", None)
if preprocessor_settings:

locals_d = {}
if constants is not None:
locals_d.update(constants)

expressions.assign_columns(
df=choosers,
model_settings=preprocessor_settings,
locals_dict=locals_d,
trace_label=trace_label,
)

if estimator:
estimator.write_model_settings(model_settings, model_settings_file_name)
estimator.write_spec(model_settings)
Expand Down Expand Up @@ -69,5 +84,8 @@ def auto_ownership_simulate(households, households_merged, chunk_size, trace_hh_
"auto_ownership", households.auto_ownership, value_counts=True
)

if model_settings.get("annotate_households"):
annotate.annotate_households(model_settings, trace_label)

if trace_hh_id:
tracing.trace_df(households, label="auto_ownership", warn_if_empty=True)
38 changes: 38 additions & 0 deletions activitysim/abm/models/util/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,44 @@
logger = logging.getLogger(__name__)


def annotate_households(model_settings, trace_label, locals_dict={}):
"""
Add columns to the households table in the pipeline according to spec.
Parameters
----------
model_settings : dict
trace_label : str
"""
households = inject.get_table("households").to_frame()
expressions.assign_columns(
df=households,
model_settings=model_settings.get("annotate_households"),
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(trace_label, "annotate_households"),
)
pipeline.replace_table("households", households)


def annotate_persons(model_settings, trace_label, locals_dict={}):
"""
Add columns to the persons table in the pipeline according to spec.
Parameters
----------
model_settings : dict
trace_label : str
"""
persons = inject.get_table("persons").to_frame()
expressions.assign_columns(
df=persons,
model_settings=model_settings.get("annotate_persons"),
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(trace_label, "annotate_persons"),
)
pipeline.replace_table("persons", persons)


def annotate_tours(model_settings, trace_label, locals_dict={}):
"""
Add columns to the tours table in the pipeline according to spec.
Expand Down

0 comments on commit 1c98505

Please sign in to comment.