
Commit e27474d

Merge pull request #661 from EducationalTestingService/660/truncate_feature_values
660/truncate feature values
2 parents 10e33f0 + 66b98af commit e27474d

35 files changed, +374 -72 lines changed

doc/config_rsmexplain.rst.inc

Lines changed: 6 additions & 0 deletions
@@ -95,6 +95,12 @@ If this option is set to ``false``, the feature values for the responses in ``ba
 
 If ``experiment_dir`` contains the rsmtool configuration file, that file's value for ``standardize_features`` will override the value specified by the user. The reason is that if ``rsmtool`` trained the model with (or without) standardized features, then ``rsmexplain`` must do the same for the explanations to be meaningful.
 
+.. _truncate_outliers_rsmexplain:
+
+truncate_outliers *(Optional)*
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
+
 .. _use_wandb_rsmexplain:
 
 use_wandb *(Optional)*
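
The behavior this option controls is a simple clamp at mean ± 4 standard deviations. The following standalone sketch illustrates that rule; it is not rsmtool's actual implementation, and `clamp_outliers`, the simulated data, and the statistics are purely illustrative:

```python
import numpy as np

# simulate a training feature column and compute its statistics
rng = np.random.default_rng(42)
train = rng.normal(loc=0.0, scale=1.0, size=1000)
mean, sd = train.mean(), train.std()

def clamp_outliers(values, mean, sd, sd_multiplier=4):
    """Clip values lying more than `sd_multiplier` standard deviations from `mean`."""
    floor, ceiling = mean - sd_multiplier * sd, mean + sd_multiplier * sd
    return np.clip(np.asarray(values, dtype=float), floor, ceiling)

new_values = np.array([0.1, -0.5, 7.3])      # 7.3 is far outside mean +/- 4*sd here
print(clamp_outliers(new_values, mean, sd))  # the extreme value is pulled back to roughly 4.0
```

Setting ``truncate_outliers`` to ``false`` skips this clamping step entirely.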

doc/config_rsmpredict.rst.inc

Lines changed: 6 additions & 0 deletions
@@ -74,6 +74,12 @@ subgroups *(Optional)*
 ~~~~~~~~~~~~~~~~~~~~~~
 A list of column names indicating grouping variables used for generating analyses specific to each of those defined subgroups. For example, ``["prompt, gender, native_language, test_country"]``. All these columns will be included into the predictions file with the original names.
 
+.. _truncate_outliers_rsmpredict:
+
+truncate_outliers *(Optional)*
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
+
 .. _use_wandb_rsmpredict:
 
 use_wandb *(Optional)*
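
For ``rsmpredict``, the option is simply another field in the configuration passed to the tool. A minimal sketch using rsmtool's Python API entry point ``compute_and_save_predictions`` is shown below, assuming the usual rsmpredict configuration fields; all IDs, paths, and file names here are placeholders:

```python
from rsmtool import compute_and_save_predictions

# illustrative configuration dict; IDs, paths, and file names are placeholders
config = {
    "experiment_id": "my_experiment",            # ID used by the original rsmtool experiment
    "experiment_dir": "my_rsmtool_output",       # directory containing that experiment's output
    "input_features_file": "new_responses.csv",  # new responses to score
    "truncate_outliers": False,                  # keep extreme feature values instead of clamping
}

compute_and_save_predictions(config, "predictions.csv")
```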

doc/config_rsmtool.rst.inc

Lines changed: 6 additions & 0 deletions
@@ -355,6 +355,12 @@ Defaults to 0.4998.
 
 For more fine-grained control over the trimming range, you can set ``trim_tolerance`` to `0` and use ``trim_min`` and ``trim_max`` to specify the exact floor and ceiling values.
 
+.. _truncate_outliers:
+
+truncate_outliers *(Optional)*
+"""""""""""""""""""""""""""""""
+If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
+
 .. _use_scaled_predictions_rsmtool:
 
 use_scaled_predictions *(Optional)*
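
On the ``rsmtool`` side, the option sits alongside the other experiment settings. A hedged sketch using the Python API is shown below; the experiment ID, model choice, and file names are placeholders, and only ``truncate_outliers`` relates to this change:

```python
from rsmtool import run_experiment

# illustrative configuration dict; experiment ID, model, and files are placeholders
config = {
    "experiment_id": "toy_experiment",
    "model": "LinearRegression",
    "train_file": "train.csv",
    "test_file": "test.csv",
    "truncate_outliers": False,  # keep extreme feature values instead of clamping them
}

run_experiment(config, "toy_output")
```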

doc/config_rsmxval.rst.inc

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ In addition to the fields described so far, an ``rsmxval`` configuration file al
 - ``trim_max``
 - ``trim_min``
 - ``trim_tolerance``
+- ``truncate_outliers``
 - ``use_scaled_predictions``
 - ``use_thumbnails``
 - ``use_truncation_thresholds``

rsmtool/preprocessor.py

Lines changed: 54 additions & 12 deletions
@@ -1020,6 +1020,7 @@ def preprocess_feature(
         exclude_zero_sd=False,
         raise_error=True,
         truncations=None,
+        truncate_outliers=True,
     ):
         """
         Remove outliers and transform the values in given numpy array.
@@ -1050,6 +1051,9 @@
         truncations : pandas DataFrame, optional
             A set of pre-defined truncation values.
             Defaults to ``None``.
+        truncate_outliers : bool, optional
+            Whether to truncate outlier values.
+            Defaults to ``True``.
 
         Returns
         -------
@@ -1063,16 +1067,21 @@
         If the preprocessed feature values have zero standard deviation
         and ``exclude_zero_sd`` is set to ``True``.
         """
-        if truncations is not None:
-            # clamp outlier values using the truncations set
-            features_no_outliers = self.remove_outliers_using_truncations(
-                values, feature_name, truncations
-            )
+        if truncate_outliers:
+            if truncations is not None:
+                # clamp outlier values using the truncations set
+                features_no_outliers = self.remove_outliers_using_truncations(
+                    values, feature_name, truncations
+                )
 
+            else:
+                # clamp any outlier values that are 4 standard deviations
+                # away from the mean
+                features_no_outliers = self.remove_outliers(
+                    values, mean=feature_mean, sd=feature_sd
+                )
         else:
-            # clamp any outlier values that are 4 standard deviations
-            # away from the mean
-            features_no_outliers = self.remove_outliers(values, mean=feature_mean, sd=feature_sd)
+            features_no_outliers = values
 
         # apply the requested transformation to the feature
         transformed_feature = FeatureTransformer().transform_feature(
@@ -1105,6 +1114,7 @@ def preprocess_features(
         df_feature_specs,
         standardize_features=True,
         use_truncations=False,
+        truncate_outliers=True,
     ):
         """
         Preprocess features in given data using corresponding specifications.
@@ -1128,11 +1138,15 @@
         standardize_features : bool, optional
             Whether to standardize the features.
             Defaults to ``True``.
+        truncate_outliers : bool, optional
+            Truncate outlier values if set in the config file
+            Defaults to ``True``.
         use_truncations : bool, optional
             Whether we should use the truncation set
             for removing outliers.
             Defaults to ``False``.
 
+
         Returns
         -------
         df_train_preprocessed : pandas DataFrame
@@ -1178,6 +1192,7 @@
                 train_feature_sd,
                 exclude_zero_sd=True,
                 truncations=truncations,
+                truncate_outliers=truncate_outliers,
             )
 
             testing_feature_values = df_test[feature_name].values
@@ -1188,6 +1203,7 @@
                 train_feature_mean,
                 train_feature_sd,
                 truncations=truncations,
+                truncate_outliers=truncate_outliers,
             )
 
             # Standardize the features using the mean and sd computed on the
@@ -1708,6 +1724,9 @@ def process_data_rsmtool(self, config_obj, data_container_obj):
         # should we standardize the features
         standardize_features = config_obj["standardize_features"]
 
+        # should outliers be truncated?
+        truncate_outliers = config_obj.get("truncate_outliers", True)
+
         # if we are excluding zero scores but trim_min
         # is set to 0, then we need to warn the user
         if exclude_zero_scores and spec_trim_min == 0:
@@ -1973,6 +1992,7 @@ def process_data_rsmtool(self, config_obj, data_container_obj):
             feature_specs,
             standardize_features,
             use_truncations,
+            truncate_outliers,
         )
 
         # configuration options that either override previous values or are
@@ -2471,6 +2491,9 @@ def process_data_rsmpredict(self, config_obj, data_container_obj):
         # should features be standardized?
         standardize_features = config_obj.get("standardize_features", True)
 
+        # should outliers be truncated?
+        truncate_outliers = config_obj.get("truncate_outliers", True)
+
         # should we predict expected scores
         predict_expected_scores = config_obj["predict_expected_scores"]
 
@@ -2531,7 +2554,10 @@
         )
 
         (df_features_preprocessed, df_excluded) = self.preprocess_new_data(
-            df_input, df_feature_info, standardize_features
+            df_input,
+            df_feature_info,
+            standardize_features=standardize_features,
+            truncate_outliers=truncate_outliers,
         )
 
         trim_min = df_postproc_params["trim_min"].values[0]
@@ -2646,6 +2672,9 @@ def process_data_rsmexplain(self, config_obj, data_container_obj):
         # should features be standardized?
         standardize_features = config_obj.get("standardize_features", True)
 
+        # should outliers be truncated?
+        truncate_outliers = config_obj.get("truncate_outliers", True)
+
         # rename the ID columns in both frames
         df_background_preprocessed = self.rename_default_columns(
             df_background_features,
@@ -2689,10 +2718,16 @@
 
         # now pre-process all the features that go into the model
         (df_background_preprocessed, _) = self.preprocess_new_data(
-            df_background_preprocessed, df_feature_info, standardize_features
+            df_background_preprocessed,
+            df_feature_info,
+            standardize_features=standardize_features,
+            truncate_outliers=truncate_outliers,
        )
         (df_explain_preprocessed, _) = self.preprocess_new_data(
-            df_explain_preprocessed, df_feature_info, standardize_features
+            df_explain_preprocessed,
+            df_feature_info,
+            standardize_features=standardize_features,
+            truncate_outliers=truncate_outliers,
         )
 
         # set ID column as index for the background and explain feature frames
@@ -2748,7 +2783,9 @@ def process_data(self, config_obj, data_container_obj, context="rsmtool"):
             f"'rsmeval', 'rsmpredict', 'rsmexplain']. You specified `{context}`."
         )
 
-    def preprocess_new_data(self, df_input, df_feature_info, standardize_features=True):
+    def preprocess_new_data(
+        self, df_input, df_feature_info, standardize_features=True, truncate_outliers=True
+    ):
         """
         Preprocess feature values using the parameters in ``df_feature_info``.
 
@@ -2780,6 +2817,10 @@ def preprocess_new_data(self, df_input, df_feature_info, standardize_features=Tr
             Whether the features should be standardized prior to prediction.
             Defaults to ``True``.
 
+        truncate_outliers : bool, optional
+            Whether outlier should be truncated prior to prediction.
+            Defaults to ``True``.
+
         Returns
         -------
         df_features_preprocessed : pandas DataFrame
@@ -2881,6 +2922,7 @@ def preprocess_new_data(self, df_input, df_feature_info, standardize_features=Tr
                 train_feature_sd,
                 exclude_zero_sd=False,
                 raise_error=False,
+                truncate_outliers=truncate_outliers,
             )
 
             # filter the feature values once again to remove possible NaN and inf values that
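
The core of this change is the extra branch in ``preprocess_feature``: truncation (whether via a pre-defined truncations set or via the mean ± 4 SD rule) now happens only when ``truncate_outliers`` is true; otherwise the raw values pass through untouched. A simplified standalone sketch of that control flow follows; ``preprocess_feature_sketch`` and its ``(min, max)`` truncations pair are stand-ins for the corresponding rsmtool method and DataFrame, not the real implementation:

```python
import numpy as np

def preprocess_feature_sketch(values, mean, sd, truncations=None, truncate_outliers=True):
    """Simplified mirror of the branching added in this commit (illustrative only)."""
    if truncate_outliers:
        if truncations is not None:
            # clamp using pre-computed floor/ceiling values (here a simple (min, max) pair)
            floor, ceiling = truncations
            return np.clip(values, floor, ceiling)
        # otherwise clamp anything more than 4 standard deviations from the mean
        return np.clip(values, mean - 4 * sd, mean + 4 * sd)
    # truncation disabled: pass the raw values through untouched
    return np.asarray(values, dtype=float)

values = np.array([0.5, 1.2, 25.0])
print(preprocess_feature_sketch(values, mean=1.0, sd=0.5))                           # clamps 25.0 down to 3.0
print(preprocess_feature_sketch(values, mean=1.0, sd=0.5, truncate_outliers=False))  # values unchanged
```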

rsmtool/rsmexplain.py

Lines changed: 38 additions & 26 deletions
@@ -133,7 +133,11 @@ def mask(learner, featureset, feature_range=None):
 
 
 def generate_explanation(
-    config_file_or_obj_or_dict, output_dir, overwrite_output=False, logger=None, wandb_run=None
+    config_file_or_obj_or_dict,
+    output_dir,
+    overwrite_output=False,
+    logger=None,
+    wandb_run=None,
 ):
     """
     Generate a shap.Explanation object.
@@ -265,35 +269,33 @@
             f"generated during model training."
         )
 
-    # read the original rsmtool configuration file, if it exists, and figure
-    # out the value of `standardize_features` that was specified when running
-    # the original rsmtool experiment
-    rsmexplain_standardize_features = configuration["standardize_features"]
+    # read the original rsmtool configuration file, if it exists, and ensure
+    # that we use its value of `standardize_features` and `truncate_outliers`
+    # even if that means we have to override the values specified in the
+    # rsmexplain configuration file
     expected_config_file_path = join(experiment_output_dir, f"{experiment_id}_rsmtool.json")
     if exists(expected_config_file_path):
         with open(expected_config_file_path, "r") as rsmtool_configfh:
-            rsmtool_config = json.load(rsmtool_configfh)
-            rsmtool_standardize_features = rsmtool_config["standardize_features"]
-
-        # use the original rsmtool experiment's value for `standardize_features`
-        # for rsmexplain as well; raise a warning if the values were different
-        # to begin with
-        if rsmexplain_standardize_features != rsmtool_standardize_features:
-            logger.warning(
-                f"overwriting current `standardize_features` value "
-                f"({rsmexplain_standardize_features}) to match "
-                f"value specified in original rsmtool experiment "
-                f"({rsmtool_standardize_features})."
-            )
-            configuration["standardize_features"] = rsmtool_standardize_features
+            rsmtool_configuration = json.load(rsmtool_configfh)
+
+        for option in ["standardize_features", "truncate_outliers"]:
+            rsmtool_value = rsmtool_configuration[option]
+            rsmexplain_value = configuration[option]
+            if rsmexplain_value != rsmtool_value:
+                logger.warning(
+                    f"overwriting current `{option}` value "
+                    f"({rsmexplain_value}) to match "
+                    f"value specified in original rsmtool experiment "
+                    f"({rsmtool_value})."
+                )
+                configuration[option] = rsmtool_value
 
     # if the original experiment rsmtool does not exist, let the user know
     else:
         logger.warning(
-            f"cannot locate original rsmtool configuration; "
-            f"ensure that current value of "
-            f"`standardize_features` ({rsmexplain_standardize_features}) "
-            f"was the same when running rsmtool."
+            "cannot locate original rsmtool configuration; "
+            "ensure that the values of `standardize_features` "
+            "and `truncate_outliers` were the same as when running rsmtool."
         )
 
     # load the background and explain data sets
@@ -547,7 +549,12 @@
     # or one of the valid optional arguments, then assume that they
    # are arguments for the "run" sub-command. This allows the
     # old style command-line invocations to work without modification.
-    if sys.argv[1] not in VALID_PARSER_SUBCOMMANDS + ["-h", "--help", "-V", "--version"]:
+    if sys.argv[1] not in VALID_PARSER_SUBCOMMANDS + [
+        "-h",
+        "--help",
+        "-V",
+        "--version",
+    ]:
         args_to_pass = ["run"] + sys.argv[1:]
     else:
         args_to_pass = sys.argv[1:]
@@ -561,7 +568,9 @@
         logger.info(f"Output directory: {args.output_dir}")
 
         generate_explanation(
-            abspath(args.config_file), abspath(args.output_dir), overwrite_output=args.force_write
+            abspath(args.config_file),
+            abspath(args.output_dir),
+            overwrite_output=args.force_write,
         )
 
     else:
@@ -570,7 +579,10 @@
 
         # auto-generate an example configuration and print it to STDOUT
         generator = ConfigurationGenerator(
-            "rsmexplain", as_string=True, suppress_warnings=args.quiet, use_subgroups=False
+            "rsmexplain",
+            as_string=True,
+            suppress_warnings=args.quiet,
+            use_subgroups=False,
         )
         configuration = (
            generator.interact(output_file_name=args.output_file.name if args.output_file else None)
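
In ``rsmexplain``, both ``standardize_features`` and ``truncate_outliers`` are now forced to match whatever the original ``rsmtool`` experiment used, since explanations are only meaningful if the features are preprocessed the same way. Below is a stripped-down version of that override loop from the diff above, using plain dicts in place of the real configuration objects and a generic logger:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rsmexplain-sketch")

# stand-ins for the rsmexplain configuration and the original rsmtool configuration
rsmexplain_config = {"standardize_features": True, "truncate_outliers": True}
rsmtool_config = {"standardize_features": False, "truncate_outliers": True}

for option in ["standardize_features", "truncate_outliers"]:
    rsmtool_value = rsmtool_config[option]
    rsmexplain_value = rsmexplain_config[option]
    if rsmexplain_value != rsmtool_value:
        logger.warning(
            f"overwriting current `{option}` value ({rsmexplain_value}) to match "
            f"value specified in original rsmtool experiment ({rsmtool_value})."
        )
        rsmexplain_config[option] = rsmtool_value

print(rsmexplain_config)  # -> {'standardize_features': False, 'truncate_outliers': True}
```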
