diff --git a/exhibit/core/constraints.py b/exhibit/core/constraints.py index 38ee1dc..42ebe6f 100644 --- a/exhibit/core/constraints.py +++ b/exhibit/core/constraints.py @@ -139,8 +139,12 @@ def process_custom_constraints(self, custom_constraints): source = self.input if self.output is None else self.output output_df = source.copy() - # preserve category dtypes at the end of each custom action - output_dtypes = output_df.dtypes[output_df.dtypes == "category"] + # change categorical dtype to normal object to simplify the logic, particularly for + # assigning new values to series affected by custom constraints. You can still set + # the dtype to Categorical on a case by case basis where it improves performance. + cat_dtype_cols = output_df.select_dtypes(include=["category"]).columns + for col in cat_dtype_cols: + output_df[col] = output_df[col].astype("object") dispatch_dict = { "make_outlier" : self.make_outlier, @@ -209,7 +213,8 @@ def process_custom_constraints(self, custom_constraints): # overwrite the original DF row IDs with the adjusted ones output_df.loc[cc_filter_idx] = action_func( output_df, cc_filter_idx, target_str, - cc_partitions, **_kwargs).astype(output_dtypes) + cc_partitions, **_kwargs) + return output_df def adjust_dataframe_to_fit_constraint(self, anon_df, basic_constraint): @@ -636,9 +641,6 @@ def _make_distinct_within_group(group): if not group.duplicated().any(): return group - # make sure blank string is available as a category before replacing - if group.dtype.name == "category": #pragma: no cover - group = group.cat.add_categories([""]) new_group = group.where(~group.duplicated(), "").tolist() return pd.Series(new_group, index=group.index, name=target_col) diff --git a/exhibit/core/exhibit.py b/exhibit/core/exhibit.py index 91f9a98..72f4012 100644 --- a/exhibit/core/exhibit.py +++ b/exhibit/core/exhibit.py @@ -379,13 +379,16 @@ def execute_spec(self): miss_gen = MissingDataGenerator(self.spec_dict, anon_df) anon_df = miss_gen.add_missing_data() - #6.5) GENERATE DERIVED COLUMNS IF ANY ARE SPECIFIED + #6.1) GENERATE DERIVED COLUMNS IF ANY ARE SPECIFIED if self.derived_columns_first: #pragma: no cover for name, calc in self.spec_dict["derived_columns"].items(): if "Example" not in name: anon_df[name] = generate_derived_column(anon_df, calc) #7) PROCESS BASIC AND CUSTOM CONSTRAINTS (IF ANY) + # note that the ConstraintHandler changes the dtype of categorical columns from + # Categorical to object to reduce dtype-related bugs / unexpected behaviour, like + # when users supply a filter expression or want to add a new value to a column. ch = ConstraintHandler(self.spec_dict, anon_df) anon_df = ch.process_constraints() # if there are any constraints that affect categorical columns, we need to @@ -406,6 +409,7 @@ def execute_spec(self): # check if there are common columns between constraint targets and cat_cols if cat_cols_set & set(constraint_targets): + # fill NAs with a placeholder - so that correct weights could be used anon_df.loc[:, cat_cols] = ( anon_df.loc[:, cat_cols].fillna(MISSING_DATA_STR)) @@ -439,7 +443,7 @@ def execute_spec(self): if num_col in derived_def: anon_df[derived_col] = generate_derived_column(anon_df, derived_def) break - + # change the missing data placeholder back to NAs anon_df.loc[:, cat_cols] = anon_df.loc[:, cat_cols].applymap( lambda x: np.nan if x == MISSING_DATA_STR else x) diff --git a/exhibit/core/generate/missing.py b/exhibit/core/generate/missing.py index 4cd048f..76eb76c 100644 --- a/exhibit/core/generate/missing.py +++ b/exhibit/core/generate/missing.py @@ -165,7 +165,7 @@ def add_missing_data(self): set(self.spec_dict.get("derived_columns", {}).keys())) if not (any(self.nan_data[cat_cols].isna()) and num_cols): - return self.nan_data + return self.nan_data.astype(self.dtypes) cat_mask = self.nan_data[cat_cols].isna().any(axis=1) self.nan_data[cat_cols] = self.nan_data[cat_cols].fillna(MISSING_DATA_STR) diff --git a/exhibit/core/generate/tests/test_missing.py b/exhibit/core/generate/tests/test_missing.py index b3edd66..002b3aa 100644 --- a/exhibit/core/generate/tests/test_missing.py +++ b/exhibit/core/generate/tests/test_missing.py @@ -571,6 +571,46 @@ def test_user_linked_columns_having_missing_data(self): test_spec_dict=test_dict, return_spec=False) self.assertTrue(df.query("A == 'eggs'")["B"].isna().any()) + + def test_categorical_numerical_missing_data_with_make_null_cc(self): + ''' + Typing issues (categorical vs object) can cause bugs when we have categorical columns, + a make_null custom constraint, a filter casting categorical column to integers (which + assumes object, not categorical - because you can't cast categorical to int if there + is a Missing data categorical value - without removing unused categories first) AND + a numerical column. Commenting out the numerical column used to pass the test, and + uncommenting it used to fail it - which is wrong. + + Without extra checks, AGE.astype('int') will fail if AGE is dtype="category" because + it'll have numbers as strings (which can be cast to int) and "invisible" Missing data + which can't. + ''' + + test_df = pd.DataFrame(data={ + "AGE": ["1", "2", "3", "4", "4"], + "NULLED" : list("ABCAB"), + "NUMS": range(5) + }) + + test_dict = { + "metadata" : { + "number_of_rows" : 10, + "categorical_columns": ["AGE", "NULLED"], + "numerical_columns" : ["NUMS"] + }, + "constraints" : { + "custom_constraints" : { + "test_nulls" : { + "filter" : "AGE.astype('int') > 1", + "targets" : {"NULLED" : "make_null"} + } + } + } + } + + _, df = temp_exhibit(filename=test_df, test_spec_dict=test_dict, return_spec=False) + + self.assertTrue(df.NULLED.isna().any()) if __name__ == "__main__" and __package__ is None: #overwrite __package__ builtin as per PEP 366 diff --git a/exhibit/core/tests/test_constraints.py b/exhibit/core/tests/test_constraints.py index e04ab53..8e3ff83 100644 --- a/exhibit/core/tests/test_constraints.py +++ b/exhibit/core/tests/test_constraints.py @@ -1717,6 +1717,10 @@ def test_custom_constraints_assign_value(self): "C" : ["spam", "spam", "ham", "ham", "ham"] * 4, }) + # remember that all categorical data comes in as "Category" dtype which imposes + # limits on what you can and can't do with those columns, like adding new values. + test_data["A"] = test_data["A"].astype("category") + test_gen = tm.ConstraintHandler(test_dict, test_data) result = test_gen.process_constraints()