Skip to content

Commit

Permalink
Probabilities in the spec now apply to values generated from anonymis…
Browse files Browse the repository at this point in the history
…ing set SQL
  • Loading branch information
gherka committed Nov 27, 2023
1 parent e637252 commit 7513305
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 8 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ disable=C0303, # trailing whitespace
E1130, # invalid unary operand (numpy)
W0622, # redefine builtins - __package__
W0640, # variable defined in loop
W0632, # unbalanced tuple unpacking - false positives


[BASIC]
Expand Down
28 changes: 21 additions & 7 deletions exhibit/core/generate/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _generate_using_external_table(self, col_name, anon_set):
sql_tables = parser.tables
aliased_columns = parser.columns_aliases_names
source_table_id = self.spec_dict["metadata"]["id"]

if len(aliased_columns) != 1 or aliased_columns[0] != col_name:
raise RuntimeError(
f"Please make sure the SQL SELECT statement in {col_name}'s "
Expand Down Expand Up @@ -537,13 +537,20 @@ def _generate_using_external_table(self, col_name, anon_set):

# get the probabilities for the selected column in the external table
# at the level of the join key - use a hash for the combination of columns!
# TODO add a check for any existing probabilities in case user want to override
# the probabilities coming from the external table.

# Rather than use existing probabilities from the spec, treat them as a weight
# and apply them to the conditional, per-join key probabilities from external
# table. TEST THIS A LOT.
# table.
probas = {}
orig_vals = None

try:
orig_vals = self.spec_dict["columns"][col_name]["original_values"]
if isinstance(orig_vals, pd.DataFrame):
orig_vals = orig_vals.set_index(col_name)
# if we don't have original_values in the column spec, it's a date
except KeyError:
pass

groups = sql_df.groupby(join_columns)
for i, group in groups:
Expand All @@ -556,11 +563,18 @@ def _generate_using_external_table(self, col_name, anon_set):
.values
)
a, p = np.split(proba_arr, 2, axis=1)
# ensure p sums up to 1
a = a.flatten()
p = p.flatten().astype(float)
p = p * (1 / sum(p))

probas[i[0]] = (a.flatten(), p.flatten().astype(float))
if orig_vals is not None:
for j, val in enumerate(a):
if val in orig_vals.index:
p_weight = float(orig_vals.loc[val, "probability_vector"])
p[j] = p[j] * p_weight

# ensure p sums up to 1
p = p * (1 / sum(p))
probas[i[0]] = (a, p)

# take the data generated so far and generate appropriate values based on key
groups = existing_data.groupby(join_columns).groups
Expand Down
76 changes: 75 additions & 1 deletion exhibit/core/generate/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_conditional_sql_anonymising_set_has_aliased_column(self):
gen = tm.CategoricalDataGenerator(spec_dict=test_dict, core_rows=10)
self.assertRaises(RuntimeError, gen.generate)

def test_external_tables_used_inconditonal_sql_anonymising_set_exist(self):
def test_external_tables_used_in_conditonal_sql_anonymising_set_exist(self):
'''
Users can provide a custom SQL as anonymising set which can reference
columns in the spec as well as any table in the Exhibit DB.
Expand Down Expand Up @@ -426,6 +426,80 @@ def test_column_with_using_case_statement_in_conditonal_sql(self):
self.assertTrue((result.query("age > 18")["smoker"] == "yes").all())
self.assertTrue((result.query("age <= 18")["smoker"] == "no").all())

def test_column_with_external_sql_values_and_probablities(self):
'''
Verify that spec-level probability vectors are applied as weights on top
of the per-join-key values produced by a custom SQL anonymising set.

Users can provide a custom SQL as anonymising set which can reference
columns in the spec as well as any table in the Exhibit DB. When the
generated column also has original_values with a probability_vector,
those probabilities should re-weight the conditional values coming from
the external table (here: conditions joined to gender).
'''

# SQL anonymising set: conditions are drawn from temp_linked, joined to
# the main table on gender, so each gender only yields its own conditions.
set_sql = '''
SELECT temp_main.gender, temp_linked.condition as condition
FROM temp_main JOIN temp_linked ON temp_main.gender = temp_linked.gender
'''

# external lookup table: M maps to A/B/C, F maps to C/D/E
linked_data = pd.DataFrame(data={
"gender" : ["M", "M", "M", "F", "F", "F"],
"condition": ["A", "B", "C", "C", "D", "E"]
})

# spec probabilities that should act as weights on the SQL-derived values
original_vals = pd.DataFrame(data={
"condition" : ["A", "B", "C", "D", "E", "Missing Data"],
"probability_vector" : [0.1, 0.1, 0.5, 0.1, 0.2, 0.0],
})

db_util.insert_table(linked_data, "temp_linked")

# minimal spec: gender is generated randomly with equal probability;
# condition is generated from the custom SQL plus original_vals weights
test_dict = {
"_rng" : np.random.default_rng(seed=1),
"metadata": {
"categorical_columns": ["gender", "condition"],
"date_columns" : [],
"inline_limit" : 5,
"id" : "main"
},
"columns": {
"gender": {
"type": "categorical",
"uniques" : 2,
"original_values" : pd.DataFrame(data={
"gender" : ["M", "F", "Missing Data"],
"probability_vector" : [0.5, 0.5, 0]
}),
"paired_columns": None,
"anonymising_set" : "random",
"cross_join_all_unique_values" : False,
},
"condition": {
"type": "categorical",
"uniques" : 5,
"original_values" : original_vals,
"paired_columns": None,
"anonymising_set" : set_sql,
"cross_join_all_unique_values" : False,
}
},
}

# 1000 rows so the observed frequencies are close to the expected ones
gen = tm.CategoricalDataGenerator(spec_dict=test_dict, core_rows=1000)
result = gen.generate()

# basic check to see that we get gender-specific conditions
self.assertTrue(
(result.query("gender == 'M'")["condition"].isin(["A", "B", "C"]).all())
)

# main check to see if probabilities are taken into account:
# observed relative frequencies of F-only conditions, rounded to 1 d.p.
gen_probs = (
result.query("gender == 'F'")["condition"].value_counts() /
result.query("gender == 'F'")["condition"].value_counts().sum()
).round(1).values.tolist()

# the probs would sum up to 1, and rounding / RNG will mean that we should be close
# to the specified C=0.5, E=0.2 and D=0.1 probabilities.
# (C=0.5 renormalised over {C, D, E} gives ~0.6 after rounding)
expected_probs = [0.6, 0.2, 0.1]

self.assertListEqual(gen_probs, expected_probs)

def test_date_column_with_impossible_combination_of_from_to_and_period(self):
'''
By default, the spec is generated with date_from, date_to, unique periods and
Expand Down

0 comments on commit 7513305

Please sign in to comment.