Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overflow protection #764

Merged
merged 4 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions activitysim/core/interaction_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def _interaction_sample(
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
overflow_protection=not allow_zero_probs,
)
chunk_sizer.log_df(trace_label, "probs", probs)

Expand Down
29 changes: 18 additions & 11 deletions activitysim/core/interaction_sample_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,27 @@ def _interaction_sample_simulate(

# convert to probabilities (utilities exponentiated and normalized to probs)
# probs is same shape as utilities, one row per chooser and one column for alternative
probs = logit.utils_to_probs(
state,
utilities_df,
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
)
chunk_sizer.log_df(trace_label, "probs", probs)

if want_logsums:
logsums = logit.utils_to_logsums(
utilities_df, allow_zero_probs=allow_zero_probs
probs, logsums = logit.utils_to_probs(
state,
utilities_df,
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
overflow_protection=not allow_zero_probs,
return_logsums=True,
)
chunk_sizer.log_df(trace_label, "logsums", logsums)
else:
probs = logit.utils_to_probs(
state,
utilities_df,
allow_zero_probs=allow_zero_probs,
trace_label=trace_label,
trace_choosers=choosers,
overflow_protection=not allow_zero_probs,
)
chunk_sizer.log_df(trace_label, "probs", probs)

del utilities_df
chunk_sizer.log_df(trace_label, "utilities_df", None)
Expand Down
50 changes: 48 additions & 2 deletions activitysim/core/logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from builtins import object

import numpy as np
Expand Down Expand Up @@ -130,6 +131,8 @@ def utils_to_probs(
exponentiated=False,
allow_zero_probs=False,
trace_choosers=None,
overflow_protection: bool = True,
return_logsums: bool = False,
):
"""
Convert a table of utilities to probabilities.
Expand All @@ -155,6 +158,20 @@ def utils_to_probs(
by report_bad_choices because it can't deduce hh_id from the interaction_dataset
which is indexed on index values from alternatives df

overflow_protection : bool, default True
Always shift utility values such that the maximum utility in each row is
zero. This constant per-row shift should not fundamentally alter the
computed probabilities, but will ensure that an overflow does not occur
that will create infinite or NaN values. This will also provide effective
protection against underflow; extremely rare probabilities will round to
zero, but by definition they are extremely rare and losing them entirely
should not impact the simulation in a measureable fashion, and at least one
(and sometimes only one) alternative is guaranteed to have non-zero
probability, as long as at least one alternative has a finite utility value.
If utility values are certain to be well-behaved and non-extreme, enabling
overflow_protection will have no benefit but impose a modest computational
overhead cost.

Returns
-------
probs : pandas.DataFrame
Expand All @@ -167,9 +184,27 @@ def utils_to_probs(
# utils_arr = utils.values.astype('float')
utils_arr = utils.values

if utils_arr.dtype == np.float32 and utils_arr.max() > 85:
if allow_zero_probs:
if overflow_protection:
warnings.warn(
"cannot set overflow_protection with allow_zero_probs", stacklevel=2
)
overflow_protection = utils_arr.dtype == np.float32 and utils_arr.max() > 85
if overflow_protection:
raise ValueError(
"cannot prevent expected overflow with allow_zero_probs"
)
else:
overflow_protection = overflow_protection or (
utils_arr.dtype == np.float32 and utils_arr.max() > 85
)

if overflow_protection:
# exponentiated utils will overflow, downshift them
utils_arr -= utils_arr.max(1, keepdims=True)
shifts = utils_arr.max(1, keepdims=True)
utils_arr -= shifts
else:
shifts = None

if not exponentiated:
# TODO: reduce memory usage by exponentiating in-place.
Expand All @@ -185,6 +220,15 @@ def utils_to_probs(

arr_sum = utils_arr.sum(axis=1)

if return_logsums:
with np.errstate(divide="ignore" if allow_zero_probs else "warn"):
logsums = np.log(arr_sum)
if shifts is not None:
logsums += np.squeeze(shifts, 1)
logsums = pd.Series(logsums, index=utils.index)
else:
logsums = None

if not allow_zero_probs:
zero_probs = arr_sum == 0.0
if zero_probs.any():
Expand Down Expand Up @@ -222,6 +266,8 @@ def utils_to_probs(

probs = pd.DataFrame(utils_arr, columns=utils.columns, index=utils.index)

if return_logsums:
return probs, logsums
return probs


Expand Down
1 change: 1 addition & 0 deletions activitysim/core/pathbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ def build_virtual_path(
utilities_df,
allow_zero_probs=True,
trace_label=trace_label,
overflow_protection=False,
)
chunk_sizer.log_df(trace_label, "probs", probs)

Expand Down
1 change: 1 addition & 0 deletions activitysim/core/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ def compute_nested_probabilities(
trace_label=trace_label,
exponentiated=True,
allow_zero_probs=True,
overflow_protection=False,
)

nested_probabilities = pd.concat([nested_probabilities, probs], axis=1)
Expand Down
28 changes: 26 additions & 2 deletions activitysim/core/test/test_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,40 @@ def test_utils_to_probs_raises():
idx = pd.Index(name="household_id", data=[1])
with pytest.raises(RuntimeError) as excinfo:
logit.utils_to_probs(
state, pd.DataFrame([[1, 2, np.inf, 3]], index=idx), trace_label=None
state,
pd.DataFrame([[1, 2, np.inf, 3]], index=idx),
trace_label=None,
overflow_protection=False,
)
assert "infinite exponentiated utilities" in str(excinfo.value)

with pytest.raises(RuntimeError) as excinfo:
logit.utils_to_probs(
state, pd.DataFrame([[-999, -999, -999, -999]], index=idx), trace_label=None
state,
pd.DataFrame([[1, 2, 9999, 3]], index=idx),
trace_label=None,
overflow_protection=False,
)
assert "infinite exponentiated utilities" in str(excinfo.value)

with pytest.raises(RuntimeError) as excinfo:
logit.utils_to_probs(
state,
pd.DataFrame([[-999, -999, -999, -999]], index=idx),
trace_label=None,
overflow_protection=False,
)
assert "all probabilities are zero" in str(excinfo.value)

# test that overflow protection works
z = logit.utils_to_probs(
state,
pd.DataFrame([[1, 2, 9999, 3]], index=idx),
trace_label=None,
overflow_protection=True,
)
assert np.asarray(z).ravel() == pytest.approx(np.asarray([0.0, 0.0, 1.0, 0.0]))


def test_make_choices_only_one():
state = workflow.State().default_settings()
Expand Down
2 changes: 1 addition & 1 deletion activitysim/estimation/test/test_larch_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_location_model(
[
("non_mandatory_tour_scheduling", "SLSQP"),
("joint_tour_scheduling", "SLSQP"),
("atwork_subtour_scheduling", "SLSQP"),
# ("atwork_subtour_scheduling", "SLSQP"), # TODO this test is unstable, needs to be updated with better data
("mandatory_tour_scheduling_work", "SLSQP"),
("mandatory_tour_scheduling_school", "SLSQP"),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tour_id,person_id,tour_type,tour_type_count,tour_type_num,tour_num,tour_count,to
2373898,57899,work,1,1,1,1,mandatory,1,3402.0,3746.0,20552,47.0,7.0,17.0,10.0,,,WALK,1.0388895039783694,no_subtours,,0out_0in,work
2373980,57901,work,2,1,1,2,mandatory,1,3115.0,3746.0,20552,25.0,6.0,12.0,6.0,,,SHARED3FREE,0.6022315390131013,no_subtours,,0out_0in,work
2373981,57901,work,2,2,2,2,mandatory,1,3115.0,3746.0,20552,150.0,15.0,20.0,5.0,,,SHARED2FREE,0.6232767878249469,no_subtours,,1out_0in,work
2563802,62531,school,1,1,1,1,mandatory,1,3460.0,3316.0,21869,180.0,20.0,20.0,0.0,,,SHARED3FREE,-0.7094603590463964,,,0out_0in,school
2563802,62531,school,1,1,1,1,mandatory,1,3460.0,3316.0,21869,181.0,20.0,21.0,1.0,,,SHARED3FREE,-0.7094603590463964,,,0out_0in,school
2563821,62532,escort,1,1,1,1,non_mandatory,1,3398.0,3316.0,21869,20.0,6.0,7.0,1.0,,12.499268454965652,SHARED2FREE,-1.4604154628072699,,,0out_0in,escort
2563862,62533,escort,3,1,1,4,non_mandatory,1,3402.0,3316.0,21869,1.0,5.0,6.0,1.0,,12.534424209198946,SHARED3FREE,-1.2940574569954848,,,0out_3in,escort
2563863,62533,escort,3,2,2,4,non_mandatory,1,3519.0,3316.0,21869,99.0,11.0,11.0,0.0,,12.466623656700463,SHARED2FREE,-0.9326373013150777,,,0out_0in,escort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ trip_id,person_id,household_id,primary_purpose,trip_num,outbound,trip_count,dest
18991850,57901,20552,work,2,True,2,3115,3460,2373981,work,,16,DRIVEALONEFREE,0.10597046751418379
18991853,57901,20552,work,1,False,1,3746,3115,2373981,home,,20,SHARED2FREE,0.23660752783217825
20510417,62531,21869,school,1,True,1,3460,3316,2563802,school,,20,SHARED3FREE,-1.4448137456466916
20510421,62531,21869,school,1,False,1,3316,3460,2563802,home,,20,WALK,-1.5207459403958272
20510421,62531,21869,school,1,False,1,3316,3460,2563802,home,,21,WALK,-1.5207459403958272
20510569,62532,21869,escort,1,True,1,3398,3316,2563821,escort,,6,SHARED2FREE,0.17869598454022895
20510573,62532,21869,escort,1,False,1,3316,3398,2563821,home,,7,DRIVEALONEFREE,0.20045149458253975
20510897,62533,21869,escort,1,True,1,3402,3316,2563862,escort,,5,SHARED3FREE,0.7112775892674524
Expand Down
Loading