From 7513305268480a04ba6e3199e81f2804c8d08457 Mon Sep 17 00:00:00 2001 From: gherka Date: Sun, 26 Nov 2023 22:11:11 +0000 Subject: [PATCH] Probabilities in the spec now apply to values generated from anonymising set SQL --- .pylintrc | 1 + exhibit/core/generate/categorical.py | 28 +++++-- .../core/generate/tests/test_categorical.py | 76 ++++++++++++++++++- 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/.pylintrc b/.pylintrc index eafaa17..8a04803 100644 --- a/.pylintrc +++ b/.pylintrc @@ -31,6 +31,7 @@ disable=C0303, # trailing whitespace E1130, # invalid unary operand (numpy) W0622, # redefine builtins - __package__ W0640, # variable defined in loop + W0632, # unbalanced tuple unpacking - false positives [BASIC] diff --git a/exhibit/core/generate/categorical.py b/exhibit/core/generate/categorical.py index 79b6ec5..77663f3 100644 --- a/exhibit/core/generate/categorical.py +++ b/exhibit/core/generate/categorical.py @@ -467,7 +467,7 @@ def _generate_using_external_table(self, col_name, anon_set): sql_tables = parser.tables aliased_columns = parser.columns_aliases_names source_table_id = self.spec_dict["metadata"]["id"] - + if len(aliased_columns) != 1 or aliased_columns[0] != col_name: raise RuntimeError( f"Please make sure the SQL SELECT statement in {col_name}'s " @@ -537,13 +537,20 @@ def _generate_using_external_table(self, col_name, anon_set): # get the probabilities for the selected column in the external table # at the level of the join key - use a hash for the combination of columns! - # TODO add a check for any existing probabilities in case user want to override - # the probabilities coming from the external table. # Rather than use existing probabilities from the spec, treat them as a weight # and apply them to the conditional, per-join key probabilities from external - # table. TEST THIS A LOT. + # table. 
probas = {}
+        orig_vals = None
+
+        try:
+            orig_vals = self.spec_dict["columns"][col_name]["original_values"]
+            if isinstance(orig_vals, pd.DataFrame):
+                orig_vals = orig_vals.set_index(col_name)
+        # if we don't have original_values in the column spec, it's a date
+        except KeyError:
+            pass
 
         groups = sql_df.groupby(join_columns)
         for i, group in groups:
@@ -556,11 +563,18 @@
                 .values
             )
             a, p = np.split(proba_arr, 2, axis=1)
-            # enusre p sums up to 1
+            a = a.flatten()
             p = p.flatten().astype(float)
-            p = p * (1 / sum(p))
-            probas[i[0]] = (a.flatten(), p.flatten().astype(float))
+            if orig_vals is not None:
+                for j, val in enumerate(a):
+                    if val in orig_vals.index:
+                        p_weight = float(orig_vals.loc[val, "probability_vector"])
+                        p[j] = p[j] * p_weight
+
+            # ensure p sums up to 1
+            p = p * (1 / sum(p))
+            probas[i[0]] = (a, p)
 
         # take the data generated so far and generate appropriate values based on key
         groups = existing_data.groupby(join_columns).groups
diff --git a/exhibit/core/generate/tests/test_categorical.py b/exhibit/core/generate/tests/test_categorical.py
index e9a0439..e838d1f 100644
--- a/exhibit/core/generate/tests/test_categorical.py
+++ b/exhibit/core/generate/tests/test_categorical.py
@@ -184,7 +184,7 @@ def test_conditional_sql_anonymising_set_has_aliased_column(self):
         gen = tm.CategoricalDataGenerator(spec_dict=test_dict, core_rows=10)
         self.assertRaises(RuntimeError, gen.generate)
 
-    def test_external_tables_used_inconditonal_sql_anonymising_set_exist(self):
+    def test_external_tables_used_in_conditonal_sql_anonymising_set_exist(self):
         '''
         Users can provide a custom SQL as anonymising set which can reference
         columns in the spec as well as any table in the Exhibit DB. 
@@ -426,6 +426,80 @@ def test_column_with_using_case_statement_in_conditonal_sql(self): self.assertTrue((result.query("age > 18")["smoker"] == "yes").all()) self.assertTrue((result.query("age <= 18")["smoker"] == "no").all()) + def test_column_with_external_sql_values_and_probablities(self): + ''' + Users can provide a custom SQL as anonymising set which can reference + columns in the spec as well as any table in the Exhibit DB. + ''' + + set_sql = ''' + SELECT temp_main.gender, temp_linked.condition as condition + FROM temp_main JOIN temp_linked ON temp_main.gender = temp_linked.gender + ''' + + linked_data = pd.DataFrame(data={ + "gender" : ["M", "M", "M", "F", "F", "F"], + "condition": ["A", "B", "C", "C", "D", "E"] + }) + + original_vals = pd.DataFrame(data={ + "condition" : ["A", "B", "C", "D", "E", "Missing Data"], + "probability_vector" : [0.1, 0.1, 0.5, 0.1, 0.2, 0.0], + }) + + db_util.insert_table(linked_data, "temp_linked") + + test_dict = { + "_rng" : np.random.default_rng(seed=1), + "metadata": { + "categorical_columns": ["gender", "condition"], + "date_columns" : [], + "inline_limit" : 5, + "id" : "main" + }, + "columns": { + "gender": { + "type": "categorical", + "uniques" : 2, + "original_values" : pd.DataFrame(data={ + "gender" : ["M", "F", "Missing Data"], + "probability_vector" : [0.5, 0.5, 0] + }), + "paired_columns": None, + "anonymising_set" : "random", + "cross_join_all_unique_values" : False, + }, + "condition": { + "type": "categorical", + "uniques" : 5, + "original_values" : original_vals, + "paired_columns": None, + "anonymising_set" : set_sql, + "cross_join_all_unique_values" : False, + } + }, + } + + gen = tm.CategoricalDataGenerator(spec_dict=test_dict, core_rows=1000) + result = gen.generate() + + # basic check to see that we get gender-specific conditions + self.assertTrue( + (result.query("gender == 'M'")["condition"].isin(["A", "B", "C"]).all()) + ) + + # main check to see if probabilities are taken into account + gen_probs = ( + 
result.query("gender == 'F'")["condition"].value_counts() / + result.query("gender == 'F'")["condition"].value_counts().sum() + ).round(1).values.tolist() + + # the probs would sum up to 1, and rounding / RNG will mean that we should be close + # to the specified C=0.5, E=0.2 and D=0.1 probabilities. + expected_probs = [0.6, 0.2, 0.1] + + self.assertListEqual(gen_probs, expected_probs) + def test_date_column_with_impossible_combination_of_from_to_and_period(self): ''' By default, the spec is generated with date_from, date_to, unique periods and