diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 610e7f9b..b24a2191 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,7 +49,7 @@ settings should be prescribed in your VS Code settings JSON: ```json { "autoDocstring.customTemplatePath": "", - "autoDocstring.docstringFormat": "google", + "autoDocstring.docstringFormat": "google-notypes", "autoDocstring.generateDocstringOnEnter": true, "autoDocstring.guessTypes": true, "autoDocstring.includeExtendedSummary": false, diff --git a/src/midst_toolkit/common/variables.py b/src/midst_toolkit/common/variables.py new file mode 100644 index 00000000..484b35f2 --- /dev/null +++ b/src/midst_toolkit/common/variables.py @@ -0,0 +1,4 @@ +import torch + + +DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") diff --git a/src/midst_toolkit/data_processing/utils.py b/src/midst_toolkit/data_processing/utils.py index 9a940c64..db4253c9 100644 --- a/src/midst_toolkit/data_processing/utils.py +++ b/src/midst_toolkit/data_processing/utils.py @@ -176,6 +176,8 @@ def get_categorical_columns(dataframe: pd.DataFrame, threshold: int) -> list[str it is deemed a categorical column. For example, a hurricane might be rated from 1 to 5 in an integer based column. With a threshold of 10, this column would be added to the set of categorical columns. + NOTE: A failure case is DateTimes, which will not be detected as categorical, but are not exactly numerical either. + Args: dataframe: Dataframe from which to extract column names corresponding to categorical variables. threshold: Threshold below which a column with numerical values (integer or float for example) is deemed to @@ -190,10 +192,8 @@ def get_categorical_columns(dataframe: pd.DataFrame, threshold: int) -> list[str for column_name in dataframe.columns: # If dtype is an object (as str columns are), assume categorical - if ( - dataframe[column_name].dtype == "object" - or is_column_type_numerical(dataframe, column_name) - and dataframe[column_name].nunique() <= threshold + if dataframe[column_name].dtype == "object" or ( + is_column_type_numerical(dataframe, column_name) and dataframe[column_name].nunique() <= threshold ): categorical_variables.append(column_name) diff --git a/src/midst_toolkit/evaluation/privacy/distance_closest_record.py b/src/midst_toolkit/evaluation/privacy/distance_closest_record.py index 5f0afa96..74351eda 100644 --- a/src/midst_toolkit/evaluation/privacy/distance_closest_record.py +++ b/src/midst_toolkit/evaluation/privacy/distance_closest_record.py @@ -1,133 +1,15 @@ from logging import INFO -from typing import Any, overload +from typing import Any -import numpy as np import pandas as pd import torch -from sklearn.preprocessing import OneHotEncoder from tqdm import tqdm from midst_toolkit.common.logger import log +from midst_toolkit.common.variables import DEVICE from midst_toolkit.evaluation.metrics_base import MetricBase +from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation from midst_toolkit.evaluation.privacy.distance_utils import NormType, minimum_distances -from midst_toolkit.evaluation.utils import extract_columns_based_on_meta_info - - -DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - -@overload -def preprocess( - meta_info: dict[str, Any], synthetic_data: pd.DataFrame, real_data_train: pd.DataFrame -) -> tuple[pd.DataFrame, pd.DataFrame]: ... 
- - -@overload -def preprocess( - meta_info: dict[str, Any], - synthetic_data: pd.DataFrame, - real_data_train: pd.DataFrame, - real_data_test: pd.DataFrame, -) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: ... - - -def preprocess( - meta_info: dict[str, Any], - synthetic_data: pd.DataFrame, - real_data_train: pd.DataFrame, - real_data_test: pd.DataFrame | None = None, -) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame]: - """ - This function performs preprocessing on Pandas dataframes to prepare for computation of the distance to closest - record score. Specifically, this function filters the provided raw dataframes to the appropriate numerical and - categorical columns based on the information of the ``meta_info`` JSON. For the numerical columns, it normalizes - values by the distance between the largest and smallest value of each column of the ``real_data_train`` numerical - values. The categorical columns are processed into one-hot encoding columns, where the transformation is fitted - on the concatenation of columns from each dataset. - - Args: - meta_info: JSON with meta information about the columns and their corresponding types that should be - considered. - synthetic_data: Dataframe containing all synthetically generated data. - real_data_train: Dataframe containing the real training data associated with the model that generated the - ``synthetic_data``. - real_data_test: Dataframe containing the real test data. It's important that this data was not seen by the - model that generated ``synthetic_data`` during training. If None, then it will, of course, not be - preprocessed. Defaults to None. - - Returns: - Processed Pandas dataframes with the synthetic data, real data for training, real data for testing if it was - provided. - """ - numerical_synthetic_data, categorical_synthetic_data = extract_columns_based_on_meta_info( - synthetic_data, meta_info - ) - numerical_real_data_train, categorical_real_data_train = extract_columns_based_on_meta_info( - real_data_train, meta_info - ) - - numerical_ranges = [ - numerical_real_data_train[index].max() - numerical_real_data_train[index].min() - for index in numerical_real_data_train.columns - ] - numerical_ranges_np = np.array(numerical_ranges) - - num_synthetic_data_np = numerical_synthetic_data.to_numpy() - num_real_data_train_np = numerical_real_data_train.to_numpy() - - # Normalize the values of the numerical columns of the different datasets by the ranges of the train set. - num_synthetic_data_np = num_synthetic_data_np / numerical_ranges_np - num_real_data_train_np = num_real_data_train_np / numerical_ranges_np - - cat_synthetic_data_np = categorical_synthetic_data.to_numpy().astype("str") - cat_real_data_train_np = categorical_real_data_train.to_numpy().astype("str") - - if real_data_test is not None: - numerical_real_data_test, categorical_real_data_test = extract_columns_based_on_meta_info( - real_data_test, meta_info - ) - num_real_data_test_np = numerical_real_data_test.to_numpy() - # Normalize the values of the numerical columns of the different datasets by the ranges of the train set. 
- num_real_data_test_np = num_real_data_test_np / numerical_ranges_np - cat_real_data_test_np = categorical_real_data_test.to_numpy().astype("str") - else: - num_real_data_test_np, cat_real_data_test_np = None, None - - if categorical_real_data_train.shape[1] > 0: - encoder = OneHotEncoder() - if cat_real_data_test_np is not None: - encoder.fit(np.concatenate((cat_synthetic_data_np, cat_real_data_train_np, cat_real_data_test_np), axis=0)) - else: - encoder.fit(np.concatenate((cat_synthetic_data_np, cat_real_data_train_np), axis=0)) - - cat_synthetic_data_oh = encoder.transform(cat_synthetic_data_np).toarray() - cat_real_data_train_oh = encoder.transform(cat_real_data_train_np).toarray() - if cat_real_data_test_np is not None: - cat_real_data_test_oh = encoder.transform(cat_real_data_test_np).toarray() - - else: - cat_synthetic_data_oh = np.empty((categorical_synthetic_data.shape[0], 0)) - cat_real_data_train_oh = np.empty((categorical_real_data_train.shape[0], 0)) - if categorical_real_data_test is not None: - cat_real_data_test_oh = np.empty((categorical_real_data_test.shape[0], 0)) - - processed_real_data_train = pd.DataFrame( - np.concatenate((num_real_data_train_np, cat_real_data_train_oh), axis=1) - ).astype(float) - processed_synthetic_data = pd.DataFrame( - np.concatenate((num_synthetic_data_np, cat_synthetic_data_oh), axis=1) - ).astype(float) - - if real_data_test is None: - return (processed_synthetic_data, processed_real_data_train) - - assert num_real_data_test_np is not None - assert cat_real_data_test_oh is not None - return ( - processed_synthetic_data, - processed_real_data_train, - pd.DataFrame(np.concatenate((num_real_data_test_np, cat_real_data_test_oh), axis=1)).astype(float), - ) class DistanceToClosestRecordScore(MetricBase): @@ -159,13 +41,15 @@ def __init__( Args: norm: Determines what norm the distances are computed in. Defaults to NormType.L1. batch_size: Batch size used to compute the DCR iteratively. Just needed to manage memory. Defaults to 1000. - device: What device the tensors should be sent to in order to perform the calculations. Defaults to DEVICE. + device: What device the tensors should be sent to in order to perform the calculations. Defaults to + "cuda" if CUDA is available, "cpu" otherwise. meta_info: This is only required/used if ``do_preprocess`` is True. JSON with meta information about the columns and their corresponding types that should be considered. At minimum, it should have the keys 'num_col_idx', 'cat_col_idx', 'target_col_idx', and 'task_type'. If None, then no preprocessing is expected to be done. Defaults to None. do_preprocess: Whether or not to preprocess the dataframes before performing the DCR computations. - Preprocessing is performed with the ``preprocess`` function Defaults to False. + Preprocessing is performed with the ``preprocess`` function. Note, ``meta_info`` must be provided in + order to perform the appropriate preprocessing steps. Defaults to False. """ self.norm = norm self.batch_size = batch_size @@ -190,8 +74,8 @@ def compute( NOTE: The dataframes provided need to be pre-processed into numerical values for each column in some way. That is, for example, the categorical variables should be one-hot encoded and the numerical values normalized in - some way. This can be done via the ``preprocess`` function beforehand or it can be done within compute if - ``do_preprocess`` is True and ``meta_info`` has been provided. + some way. 
This can be done via the ``preprocess`` function in ``distance_preprocess.py`` beforehand or it can
+        be done within ``compute`` if ``do_preprocess`` is True and ``meta_info`` has been provided.
 
         Args:
             real_data: Real data that was used to train the model that generated the ``synthetic_data``.
@@ -205,7 +89,7 @@
         assert holdout_data is not None, "For DCR score calculations, a holdout dataset is required"
 
         if self.do_preprocess:
-            synthetic_data, real_data, holdout_data = preprocess(
+            synthetic_data, real_data, holdout_data = preprocess_for_distance_computation(
                 self.meta_info, synthetic_data, real_data, holdout_data
             )
@@ -221,7 +105,7 @@
             end_index = min(start_index + self.batch_size, synthetic_data_tensor.size(0))
             synthetic_data_batch = synthetic_data_tensor[start_index:end_index]
 
-            # Calculate distances for real and test data in smaller batches
+            # Calculate distances from synthetic data points to real and test data in smaller batches
             dcr_train_batch = minimum_distances(
                 synthetic_data_batch, real_data_train_tensor, self.batch_size, self.norm
             )
@@ -260,13 +144,15 @@ def __init__(
         Args:
             norm: Determines what norm the distances are computed in. Defaults to NormType.L1.
             batch_size: Batch size used to compute the DCR iteratively. Just needed to manage memory. Defaults to 1000.
-            device: What device the tensors should be sent to in order to perform the calculations. Defaults to DEVICE.
+            device: What device the tensors should be sent to in order to perform the calculations. Defaults to
+                "cuda" if CUDA is available, "cpu" otherwise.
             meta_info: This is only required/used if ``do_preprocess`` is True. JSON with meta information about the
                 columns and their corresponding types that should be considered. At minimum, it should have the keys
                 'num_col_idx', 'cat_col_idx', 'target_col_idx', and 'task_type'. If None, then no preprocessing is
                 expected to be done. Defaults to None.
             do_preprocess: Whether or not to preprocess the dataframes before performing the DCR computations.
-                Preprocessing is performed with the ``preprocess`` function Defaults to False.
+                Preprocessing is performed with the ``preprocess`` function in ``distance_preprocess.py``. Note,
+                ``meta_info`` must be provided in order to perform the appropriate preprocessing steps. Defaults to
+                False.
         """
         self.norm = norm
         self.batch_size = batch_size
@@ -287,7 +173,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
         NOTE: The dataframes provided need to be pre-processed into numerical values for each column in some way. That
         is, for example, the categorical variables should be one-hot encoded and the numerical values normalized in
-        some way. This can be done via the ``preprocess`` function beforehand or it can be done within compute if
+        some way. This can be done via the ``preprocess`` function beforehand or it can be done within ``compute`` if
         ``do_preprocess`` is True and ``meta_info`` has been provided.
Args: @@ -301,7 +187,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict Example: { "median_dcr_score": 0.79 } """ if self.do_preprocess: - synthetic_data, real_data = preprocess(self.meta_info, synthetic_data, real_data) + synthetic_data, real_data = preprocess_for_distance_computation(self.meta_info, synthetic_data, real_data) real_data_tensor = torch.tensor(real_data.to_numpy()).to(self.device) synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device) diff --git a/src/midst_toolkit/evaluation/privacy/distance_preprocess.py b/src/midst_toolkit/evaluation/privacy/distance_preprocess.py new file mode 100644 index 00000000..6c132e8c --- /dev/null +++ b/src/midst_toolkit/evaluation/privacy/distance_preprocess.py @@ -0,0 +1,122 @@ +from typing import Any, overload + +import numpy as np +import pandas as pd +from sklearn.preprocessing import OneHotEncoder + +from midst_toolkit.evaluation.utils import extract_columns_based_on_meta_info + + +@overload +def preprocess_for_distance_computation( + meta_info: dict[str, Any], synthetic_data: pd.DataFrame, real_data_train: pd.DataFrame +) -> tuple[pd.DataFrame, pd.DataFrame]: ... + + +@overload +def preprocess_for_distance_computation( + meta_info: dict[str, Any], + synthetic_data: pd.DataFrame, + real_data_train: pd.DataFrame, + real_data_test: pd.DataFrame, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: ... + + +def preprocess_for_distance_computation( + meta_info: dict[str, Any], + synthetic_data: pd.DataFrame, + real_data_train: pd.DataFrame, + real_data_test: pd.DataFrame | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame]: + """ + This function performs preprocessing on Pandas dataframes to prepare for computation of various record-to-record + distances. This is used for computations like distance to closest record scores. Specifically, this function + filters the provided raw dataframes to the appropriate numerical and categorical columns based on the information + of the ``meta_info`` JSON. For the numerical columns, it normalizes values by the distance between the largest + and smallest value of each column of the ``real_data_train`` numerical values. The categorical columns are + processed into one-hot encoding columns, where the transformation is fitted on the concatenation of columns from + each dataset. + + Args: + meta_info: JSON with meta information about the columns and their corresponding types that should be + considered. + synthetic_data: Dataframe containing all synthetically generated data. + real_data_train: Dataframe containing the real training data associated with the model that generated the + ``synthetic_data``. + real_data_test: Dataframe containing the real test data. It's important that this data was not seen by the + model that generated ``synthetic_data`` during training. If None, then it will, of course, not be + preprocessed. Defaults to None. + + Returns: + Processed Pandas dataframes with the synthetic data, real data for training, real data for testing if it was + provided. 
+ """ + numerical_synthetic_data, categorical_synthetic_data = extract_columns_based_on_meta_info( + synthetic_data, meta_info + ) + numerical_real_data_train, categorical_real_data_train = extract_columns_based_on_meta_info( + real_data_train, meta_info + ) + + numerical_ranges = [ + numerical_real_data_train[index].max() - numerical_real_data_train[index].min() + for index in numerical_real_data_train.columns + ] + numerical_ranges_np = np.array(numerical_ranges) + + num_synthetic_data_np = numerical_synthetic_data.to_numpy() + num_real_data_train_np = numerical_real_data_train.to_numpy() + + # Normalize the values of the numerical columns of the different datasets by the ranges of the train set. + num_synthetic_data_np = num_synthetic_data_np / numerical_ranges_np + num_real_data_train_np = num_real_data_train_np / numerical_ranges_np + + cat_synthetic_data_np = categorical_synthetic_data.to_numpy().astype("str") + cat_real_data_train_np = categorical_real_data_train.to_numpy().astype("str") + + if real_data_test is not None: + numerical_real_data_test, categorical_real_data_test = extract_columns_based_on_meta_info( + real_data_test, meta_info + ) + num_real_data_test_np = numerical_real_data_test.to_numpy() + # Normalize the values of the numerical columns of the different datasets by the ranges of the train set. + num_real_data_test_np = num_real_data_test_np / numerical_ranges_np + cat_real_data_test_np = categorical_real_data_test.to_numpy().astype("str") + else: + num_real_data_test_np, cat_real_data_test_np = None, None + + if categorical_real_data_train.shape[1] > 0: + encoder = OneHotEncoder() + if cat_real_data_test_np is not None: + encoder.fit(np.concatenate((cat_synthetic_data_np, cat_real_data_train_np, cat_real_data_test_np), axis=0)) + else: + encoder.fit(np.concatenate((cat_synthetic_data_np, cat_real_data_train_np), axis=0)) + + cat_synthetic_data_oh = encoder.transform(cat_synthetic_data_np).toarray() + cat_real_data_train_oh = encoder.transform(cat_real_data_train_np).toarray() + if cat_real_data_test_np is not None: + cat_real_data_test_oh = encoder.transform(cat_real_data_test_np).toarray() + + else: + cat_synthetic_data_oh = np.empty((categorical_synthetic_data.shape[0], 0)) + cat_real_data_train_oh = np.empty((categorical_real_data_train.shape[0], 0)) + if categorical_real_data_test is not None: + cat_real_data_test_oh = np.empty((categorical_real_data_test.shape[0], 0)) + + processed_real_data_train = pd.DataFrame( + np.concatenate((num_real_data_train_np, cat_real_data_train_oh), axis=1) + ).astype(float) + processed_synthetic_data = pd.DataFrame( + np.concatenate((num_synthetic_data_np, cat_synthetic_data_oh), axis=1) + ).astype(float) + + if real_data_test is None: + return (processed_synthetic_data, processed_real_data_train) + + assert num_real_data_test_np is not None + assert cat_real_data_test_oh is not None + return ( + processed_synthetic_data, + processed_real_data_train, + pd.DataFrame(np.concatenate((num_real_data_test_np, cat_real_data_test_oh), axis=1)).astype(float), + ) diff --git a/src/midst_toolkit/evaluation/privacy/distance_utils.py b/src/midst_toolkit/evaluation/privacy/distance_utils.py index ab0a77a2..f8511955 100644 --- a/src/midst_toolkit/evaluation/privacy/distance_utils.py +++ b/src/midst_toolkit/evaluation/privacy/distance_utils.py @@ -8,81 +8,88 @@ class NormType(Enum): L2 = "l2" -def compute_l1_distance( - target_data: torch.Tensor, reference_data: torch.Tensor, skip_diagonal: bool = False -) -> torch.Tensor: +def 
compute_l1_distances(target_data: torch.Tensor, reference_data: torch.Tensor) -> torch.Tensor: """ - Compute the smallest l1 distance between each point in the target data tensor compared to all points in the - reference data tensor. + This computes the l1 distances between every point in ``target_data`` to every point in ``reference_data``. + The final distances are arranged with shape (n_target_points, n_ref_points), where rows correspond to the distance + of a single point in ``target_data`` to all points in ``reference_data``. Args: - target_data: Tensor of target data. Assumed to be a 2D tensor with batch size first, followed by - data dimension. - reference_data: Tensor of reference data. Assumed to be a 2D tensor with batch size first, followed by - data dimension. - skip_diagonal: Whether or not to skip computations on diagonal of distance matrix. This is generally only used - when ``target_data`` and ``reference_data`` are the same set. In this case, the diagonal elements are the - distance of the point from itself (which is 0). Defaults to False. + target_data: First tensor with shape (n_target_points, data_dim). + reference_data: Second tensor with shape (n_reference_points, data_dim). Returns: - A 1D tensor containing the l1 minimum distances between each data point in the target data and all points in - the reference data. Order will be the same as the target data. + A matrix with the l1 distances between all points in ``target_data`` to all points in ``reference_data``. Rows + correspond the distance of a single point in ``target_data`` to all points in ``reference_data``. Order will + be the same as ``target_data.`` """ - assert target_data.ndim == 2 and reference_data.ndim == 2, "Target and Reference data tensors should be 2D" - assert target_data.shape[1] == reference_data.shape[1], "Data dimensions do not match for the provided tensors" - # For target_data (n_target_points, data_dim), and reference_data (n_ref_points, data_dim), this subtracts # every point in reference_data from every point in target_data to create a tensor of shape # (n_target_points, n_ref_points, data_dim). point_differences = target_data[:, None] - reference_data - distances = (point_differences).abs().sum(dim=2) + return (point_differences).abs().sum(dim=2) + - # Minimum distance of points in n_target_points compared to all other points in reference_data. - if not skip_diagonal: - min_batch_distances, _ = distances.min(dim=1) - return min_batch_distances +def compute_l2_distances(target_data: torch.Tensor, reference_data: torch.Tensor) -> torch.Tensor: + """ + This computes the l2 distances between every point in ``target_data`` to every point in ``reference_data``. + The final distances are arranged with shape (n_target_points, n_ref_points), where rows correspond to the distance + of a single point in ``target_data`` to all points in ``reference_data``. + + Args: + target_data: First tensor with shape (n_target_points, data_dim). + reference_data: Second tensor with shape (n_reference_points, data_dim). - # Bottom two distances, because one of them might be the reference point to itself. - min_batch_distances, _ = torch.topk(distances, 2, dim=1, largest=False) - return min_batch_distances + Returns: + A matrix with the l2 distances between all points in ``target_data`` to all points in ``reference_data``. Rows + correspond the distance of a single point in ``target_data`` to all points in ``reference_data``. 
Order will + be the same as ``target_data.`` + """ + # For target_data (n_target_points, data_dim), and reference_data (n_reference_points, data_dim), this subtracts + # every point in reference_data from every point in target_data to create a tensor of shape + # (n_target_points, n_reference_points, data_dim). + point_differences = target_data[:, None] - reference_data + return torch.sqrt(torch.pow(point_differences, 2.0).sum(dim=2)) -def compute_l2_distance( - target_data: torch.Tensor, reference_data: torch.Tensor, skip_diagonal: bool = False +def compute_top_k_distances( + target_data: torch.Tensor, reference_data: torch.Tensor, norm: NormType = NormType.L1, top_k: int = 1 ) -> torch.Tensor: """ - Compute the smallest l2 distance between each point in the target data tensor compared to all points in the - reference data tensor. + This function computes the ``top_k`` SMALLEST distances for each point in ``target_data`` to points in + ``reference_data``. A matrix is returned whose rows correspond to the smallest distances from a point in + ``target_data`` to any points in reference data. The columns are in ascending order and ONLY the distances are + returned. Args: - target_data: Tensor of target data. Assumed to be a 2D tensor with batch size first, followed by - data dimension. - reference_data: Tensor of reference data. Assumed to be a 2D tensor with batch size first, followed by - data dimension. - skip_diagonal: Whether or not to skip computations on diagonal of distance matrix. This is generally only used - when ``target_data`` and ``reference_data`` are the same set. In this case, the diagonal elements are the - distance of the point from itself (which is 0). Defaults to False. + target_data: A 2-D tensor with shape (``n_target_datapoints``, ``data_dim``). Each point is compared to all + points in ``reference_data``. + reference_data: A 2-D tensor with shape (``n_reference_datapoints``, ``data_dim``). Each point in + ``target_data`` is compared to all points in this tensor. + norm: Type of norm to apply when measuring distance between two points. Defaults to NormType.L1. + top_k: Number of SMALLEST distances to return for each point in ``target_data``. Defaults to 1. + + Raises: + ValueError: Thrown if the requested distance measure is not supported. Returns: - A 1D tensor containing the l2 minimum distances between each data point in the target data and all points in - the reference data. Order will be the same as the target data. + A matrix of shape (``n_target_datapoints``, ``top_k``). Each row of this tensor corresponds to the SMALLEST + ``top_k`` distances from a point in ``target_data`` to any point in ``reference_data``. Order will be the same + as ``target_data.`` """ assert target_data.ndim == 2 and reference_data.ndim == 2, "Target and Reference data tensors should be 2D" assert target_data.shape[1] == reference_data.shape[1], "Data dimensions do not match for the provided tensors" - # For target_data (n_target_points, data_dim), and reference_data (n_reference_points, data_dim), this subtracts - # every point in reference_data from every point in target_data to create a tensor of shape - # (n_target_points, n_reference_points, data_dim). - point_differences = target_data[:, None] - reference_data - distances = torch.sqrt(torch.pow(point_differences, 2.0).sum(dim=2)) - # Minimum distance of points in n_target_points compared to all other points in reference_data. 
- if not skip_diagonal: - min_batch_distances, _ = distances.min(dim=1) - return min_batch_distances + if norm == NormType.L1: + distances = compute_l1_distances(target_data, reference_data) + elif norm == NormType.L2: + distances = compute_l2_distances(target_data, reference_data) + else: + raise ValueError(f"Unsupported NormType: {norm.value}") - # Bottom two distances, because one of them might be the reference point to itself. - min_batch_distances, _ = torch.topk(distances, 2, dim=1, largest=False) - return min_batch_distances + # Smallest top_k distances + top_k_distances, _ = torch.topk(distances, top_k, dim=1, largest=False) + return top_k_distances def minimum_distances( @@ -97,13 +104,14 @@ def minimum_distances( provided. This can be done in batches if specified. Otherwise, the entire computation is done at once. Args: - target_data: The complete set of target data, stacked as a tensor with shape (n_samples, data dimension). - reference_data: The complete set of reference data, stacked as a tensor with shape (n_samples, data dimension). + target_data: The complete set of target data, stacked as a tensor with shape (``n_samples``, data dimension). + reference_data: The complete set of reference data, stacked as a tensor with shape (``n_samples``, + data dimension). batch_size: Size of the batches to facilitate computing the minimum distances, if specified. Defaults to None. norm: Which type of norm to use as the distance metric. Defaults to NormType.L1. skip_diagonal: Whether or not to skip computations on diagonal of distance matrix. This is generally only used when ``target_data`` and ``reference_data`` are the same set. In this case, the diagonal elements are the - distance of the point from itself (which is 0). Defaults to False. + distance of the point from itself (which is always 0). Defaults to False. Returns: A 1D tensor with the minimum distances. Should be of length n_samples. Order will be the same as @@ -111,7 +119,7 @@ def minimum_distances( """ if batch_size is None: # If batch size isn't specified, do it all at once. - batch_size = target_data.size(0) + batch_size = reference_data.size(0) # Create a minimum distance for each target data sample if skip_diagonal: @@ -124,17 +132,14 @@ def minimum_distances( end_index = min(start_index + batch_size, reference_data.size(0)) reference_data_batch = reference_data[start_index:end_index] - if norm is NormType.L1: - min_batch_distances = compute_l1_distance(target_data, reference_data_batch, skip_diagonal) - elif norm is NormType.L2: - min_batch_distances = compute_l2_distance(target_data, reference_data_batch, skip_diagonal) - else: - raise ValueError(f"Unrecognized norm type: {str(norm)}") - if not skip_diagonal: - min_distances = torch.minimum(min_distances, min_batch_distances) - else: + if skip_diagonal: + min_batch_distances = compute_top_k_distances(target_data, reference_data_batch, norm, top_k=2) combined_distances = torch.cat((min_distances, min_batch_distances), dim=1) min_distances, _ = torch.topk(combined_distances, 2, dim=1, largest=False) + else: + min_batch_distances = compute_top_k_distances(target_data, reference_data_batch, norm, top_k=1) + min_distances = torch.minimum(min_distances, min_batch_distances.squeeze()) + if skip_diagonal: # Smallest distance should be point to itself. Second smallest is the rest. 
return min_distances[:, 1] diff --git a/src/midst_toolkit/evaluation/privacy/epsilon_identifiability_risk.py b/src/midst_toolkit/evaluation/privacy/epsilon_identifiability_risk.py new file mode 100644 index 00000000..31a1c052 --- /dev/null +++ b/src/midst_toolkit/evaluation/privacy/epsilon_identifiability_risk.py @@ -0,0 +1,142 @@ +from enum import Enum + +import pandas as pd +from syntheval.metrics.privacy.metric_epsilon_identifiability import EpsilonIdentifiability + +from midst_toolkit.evaluation.metrics_base import SynthEvalMetric + + +class EpsilonIdentifiabilityNorm(Enum): + """These are the valid norms for SynthEval measures.""" + + L2 = "euclid" + GOWER = "gower" + + +class EpsilonIdentifiabilityRisk(SynthEvalMetric): + def __init__( + self, + categorical_columns: list[str], + numerical_columns: list[str], + do_preprocess: bool = False, + norm: EpsilonIdentifiabilityNorm = EpsilonIdentifiabilityNorm.GOWER, + ): + """ + Class to compute the Epsilon Identifiability Risk. This computes the ratio of real data points that have a + synthetic data point closer than any other real data point in the set of data points. As such, a value closer + to 0 is better. + + If a holdout set is provided to the compute function, the same ratio is computed for holdout data points + compared with synthetic ones. The difference between the ratio for the real data points compared with the + holdout data points is then calculated. Ideally, these should be roughly the same (i.e. difference near zero) + or negative. In this scenario, it is typical that the real data was USED TO TRAIN a model that generated the + synthetic data and the holdout set represents real data that was NOT. + + NOTE: Columns are not uniformly weighted. They are weighted by their inverse column entropy to provide + greater attention to rare data points. This is formally defined in: + + Yoon, J., Drumright, L.N., Schaar, M.: Anonymization through data synthesis using generative adversarial + networks (ADS-GAN). IEEE J. Biomed. Health Informatics 24(8), 2378–2388 (2020) + https://doi.org/10.1109/JBHI.2020.2980262 + + NOTE: The dataframes provided need to be pre-processed into numerical values for each column in some way. That + is, for example, the categorical variables may be one-hot encoded and the numerical values normalized in + some way. This can be done via the ``preprocess`` function in ``distance_preprocess.py`` beforehand or it can + be done within ``compute`` if ``do_preprocess`` is True using the SynthEval pipeline. + + + + Args: + categorical_columns: Column names corresponding to the categorical variables of any provided dataframe. + numerical_columns: Column names corresponding to the numerical variables of any provided dataframe. + do_preprocess: Whether or not to preprocess the dataframes with the default pipeline used by SynthEval. + Defaults to False. + norm: The kind of norm to use when measuring distances between points. Only l2 and Gower norms are + currently supported. SynthEval defaults to Gower, so we do here as well. Note that if norm is + EpsilonIdentifiabilityNorm.L2, then distances only consider the columns specified by + ``numerical_columns``. Defaults to EpsilonIdentifiabilityNorm.GOWER. 
+ """ + super().__init__(categorical_columns, numerical_columns, do_preprocess) + self.norm = norm + self.all_columns = categorical_columns + numerical_columns + + def compute( + self, + real_data: pd.DataFrame, + synthetic_data: pd.DataFrame, + holdout_data: pd.DataFrame | None = None, + ) -> dict[str, float]: + """ + Computes the Epsilon Identifiability Risk. This is the ratio of data points from ``real_data`` that have a + point from ``synthetic_data` that is closer than any other real data point in ``real_data``. As such, a value + closer to 0 is better. + + If ``holdout_data`` is provided, the same ratio is computed for points in ``holdout_data`` compared with those + in ``synthetic_data``. The difference between the ratio for ``real_data`` compared with ``holdout_data`` is + then calculated. Ideally, these should be roughly the same (i.e. difference near zero) or negative. In this + scenario, it is typical that the real data was USED TO TRAIN a model that generated the synthetic data and the + holdout set represents real data that was NOT. + + NOTE: Columns are not uniformly weighted. They are weighted by their inverse column entropy to provide + greater attention to rare data points. This is formally defined in: + + Yoon, J., Drumright, L.N., Schaar, M.: Anonymization through data synthesis using generative adversarial + networks (ADS-GAN). IEEE J. Biomed. Health Informatics 24(8), 2378–2388 (2020) + https://doi.org/10.1109/JBHI.2020.2980262 + + NOTE: The dataframes provided need to be pre-processed into numerical values for each column in some way. That + is, for example, the categorical variables may be one-hot encoded and the numerical values normalized in + some way. This can be done via the ``preprocess`` function in ``distance_preprocess.py`` beforehand or it can + be done within ``compute`` if ``do_preprocess`` is True using the SynthEval pipeline. + + Args: + real_data: Real data to which the synthetic data may be compared. In many cases this will be data used + to TRAIN the model that generated the synthetic data, but not always. + synthetic_data: Synthetically generated data whose quality is to be assessed. + holdout_data: Real data to which the synthetic data may also be compared. In many cases this will be data + was NOT used in training the generating model. If none, then 'privacy_loss' is not computed. + + Returns: + A dictionary of Epsilon Identifiability Risk results. Regardless of input, the estimated epsilon + identifiability risk for ``real_data`` is reported, keyed by 'epsilon_identifiability_risk'. If + ``holdout_data`` is provided. The difference of the risk for ``real_data`` and ``holdout_data`` is + reported, keyed by 'privacy_loss'. + """ + if self.do_preprocess: + if holdout_data is None: + real_data, synthetic_data = self.preprocess(real_data, synthetic_data) + else: + real_data, synthetic_data, holdout_data = self.preprocess(real_data, synthetic_data, holdout_data) + + # When using the l2 distance, SynthEval aims to filter to only the numerical columns. However, there is a bug + # when providing a holdout set, where that set does not get filtered. So we'll do it here. 
+ if self.norm == EpsilonIdentifiabilityNorm.L2: + filtered_real_data = real_data[self.numerical_columns] + filtered_synthetic_data = synthetic_data[self.numerical_columns] + filtered_holdout_data = holdout_data[self.numerical_columns] if holdout_data is not None else None + elif self.norm == EpsilonIdentifiabilityNorm.GOWER: + # NOTE: The SynthEval class ignores column specifications by default. However, for other classes + # (correlation_matrix_difference for example), specifying less than all of the columns restricts the score + # computation to just those columns. To make this consistent we do that here, before passing to the + # SynthEval class. + filtered_real_data = real_data[self.all_columns] + filtered_synthetic_data = synthetic_data[self.all_columns] + filtered_holdout_data = holdout_data[self.all_columns] if holdout_data is not None else None + else: + raise ValueError(f"Unrecognized EpsilonIdentifiabilityNorm Option: {self.norm}") + + self.syntheval_metric = EpsilonIdentifiability( + real_data=filtered_real_data, + synt_data=filtered_synthetic_data, + hout_data=filtered_holdout_data, + cat_cols=self.categorical_columns, + num_cols=self.numerical_columns, + do_preprocessing=False, + verbose=False, + nn_dist=self.norm.value, + ) + result = self.syntheval_metric.evaluate() + result["epsilon_identifiability_risk"] = result.pop("eps_risk") + if holdout_data is not None: + result["privacy_loss"] = result.pop("priv_loss") + return result diff --git a/src/midst_toolkit/evaluation/privacy/nearest_neighbor_distance_ratio.py b/src/midst_toolkit/evaluation/privacy/nearest_neighbor_distance_ratio.py new file mode 100644 index 00000000..b47278da --- /dev/null +++ b/src/midst_toolkit/evaluation/privacy/nearest_neighbor_distance_ratio.py @@ -0,0 +1,155 @@ +import math +from typing import Any + +import pandas as pd +import torch +from tqdm import tqdm + +from midst_toolkit.common.variables import DEVICE +from midst_toolkit.evaluation.metrics_base import MetricBase +from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation +from midst_toolkit.evaluation.privacy.distance_utils import NormType, compute_top_k_distances + + +class NearestNeighborDistanceRatio(MetricBase): + def __init__( + self, + norm: NormType = NormType.L2, + batch_size: int = 1000, + device: torch.device = DEVICE, + meta_info: dict[str, Any] | None = None, + do_preprocess: bool = False, + epsilon: float = 1e-8, + ): + """ + This class computes the nearest neighbor distance ratio (NNDR) between synthetic and real datasets. The + primary, real dataset typically corresponds to the data used to train the model that generated the + corresponding synthetic dataset. For each point in the synthetic dataset, the top two nearest points in the + real dataset are computed. The ratio of the two distances (closes/second closest) is computed for all synthetic + points and averaged for the final score. + + See: https://arxiv.org/pdf/2501.03941 + + Intuitively, this measures whether the synthetic points are in "dense" areas of the real data or "sparse" + regions, potentially endangering outliers. If the area is dense, the two distances will be similar and the + ratio close to 1. If not, the second closest point may be much farther away, producing a ratio closer to 0. + + If a holdout dataset, composed of real data points that were NOT used to train the generating model, is + provided the same computation comparing the synthetic data to the holdout set is performed. 
The difference + between the two ratios (train and holdout comparisons) is a score comparing the "privacy loss." The more + positive, the more the synthetic data may reveal about the original training set. + + NOTE: The dataframes provided need to be pre-processed into numerical values for each column in some way. That + is, for example, the categorical variables may be one-hot encoded and the numerical values normalized in + some way. This can be done via the ``preprocess`` function in ``distance_preprocess.py`` beforehand or it can + be done within ``compute`` if ``do_preprocess`` is True and ``meta_info`` has been provided. + + Args: + norm: Determines what norm the distances are computed in. Defaults to NormType.L2. + batch_size: Batch size used to compute the NNDR iteratively. Just needed to manage memory. Defaults to + 1000. + device: What device the tensors should be sent to in order to perform the calculations. Defaults to + "cuda" if CUDA is available, "cpu" otherwise. + meta_info: This is only required/used if ``do_preprocess`` is True. JSON with meta information about the + columns and their corresponding types that should be considered. At minimum, it should have the keys + 'num_col_idx' and 'cat_col_idx'. If 'target_col_idx' is specified then 'task_type' must also exist. + If None, then no preprocessing is expected to be done. Defaults to None. + do_preprocess: Whether or not to preprocess the dataframes before performing the NNDR calculations. + Preprocessing is performed with the ``preprocess`` function of ``distance_preprocess.py``. Note, + ``meta_info`` must be provided in order to perform the appropriate preprocessing steps. Defaults to + False. + epsilon: Regularization term that ensures that we do not divide by 0. Defaults to 1e-8 + """ + self.norm = norm + self.batch_size = batch_size + self.device = device + self.do_preprocess = do_preprocess + if self.do_preprocess and meta_info is None: + raise ValueError("Preprocessing requires meta_info to be defined, but it is None.") + self.meta_info = meta_info if meta_info is not None else {} + self.epsilon = epsilon + + def compute( + self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame, holdout_data: pd.DataFrame | None = None + ) -> dict[str, float]: + """ + Computes the nearest neighbor distance ratio (NNDR) between synthetic and real datasets. The primary, real + dataset typically corresponds to the data used to train the model that generated the corresponding synthetic + dataset. For each point in the synthetic dataset, the top two nearest points in the real dataset are computed. + The ratio of the two distances (closes/second closest) is computed for all synthetic points and averaged for + the final score. + + If a holdout dataset, composed of real data points that were NOT used to train the generating model, is + provided the same computation comparing the synthetic data to the holdout set is performed. The difference + between the two ratios (train and holdout comparisons) is a score comparing the "privacy loss." The more + positive, the more the synthetic data may reveal about the original training set. + + NOTE: The dataframes provided need to be pre-processed into numerical values for each column in some way. That + is, for example, the categorical variables may be one-hot encoded and the numerical values normalized in + some way. 
This can be done via the ``preprocess`` function in ``distance_preprocess.py`` beforehand or it can + be done within ``compute`` if ``do_preprocess`` is True and ``meta_info`` has been provided. + + Args: + real_data: Real data to which the synthetic data may be compared. In many cases this will be data used + to TRAIN the model that generated the synthetic data, but not always. + synthetic_data: Synthetically generated data whose quality is to be assessed. + holdout_data: Real data to which the synthetic data may also be compared. In many cases this will be data + was NOT used in training the generating model. If none, then the metrics 'privacy_loss' and + 'privacy_loss_standard_error' are not reported. Defaults to None. + + Returns: + A dictionary of NNDR results. Regardless of input, the mean of the NNDR values for each synthetic data + point and standard error of the mean are reported, keyed by 'mean_nndr' and 'nndr_standard_error', + respectively. If ``holdout_data`` is provided. The difference of the mean nndr using ``real_data`` and + ``holdout_data`` is reported as 'privacy_loss', along with the pooled standard errors for both + mean nndr values (key: 'privacy_loss_standard_error'). + """ + if self.do_preprocess: + if holdout_data is None: + synthetic_data, real_data = preprocess_for_distance_computation( + self.meta_info, synthetic_data, real_data + ) + else: + synthetic_data, real_data, holdout_data = preprocess_for_distance_computation( + self.meta_info, synthetic_data, real_data, holdout_data + ) + + synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device) + real_data_tensor = torch.tensor(real_data.to_numpy()).to(self.device) + mean_nndr, nndr_standard_error = self._compute_mean_nearest_neighbor_distance_ratio( + synthetic_data_tensor, real_data_tensor + ) + + result = { + "mean_nndr": mean_nndr, + "nndr_standard_error": nndr_standard_error, + } + + if holdout_data is not None: + holdout_data_tensor = torch.tensor(holdout_data.to_numpy()).to(self.device) + mean_nndr_holdout, nndr_standard_error_holdout = self._compute_mean_nearest_neighbor_distance_ratio( + synthetic_data_tensor, holdout_data_tensor + ) + result["privacy_loss"] = mean_nndr - mean_nndr_holdout + result["privacy_loss_standard_error"] = math.sqrt(nndr_standard_error**2 + nndr_standard_error_holdout**2) + + return result + + def _compute_mean_nearest_neighbor_distance_ratio( + self, target_tensor: torch.Tensor, reference_tensor: torch.Tensor + ) -> tuple[float, float]: + ratios = [] + # Assumes that the tensors are 2D and arranged (n_samples, data dimension) + for start_index in tqdm(range(0, target_tensor.size(0), self.batch_size)): + end_index = min(start_index + self.batch_size, target_tensor.size(0)) + target_data_batch = target_tensor[start_index:end_index] + + # Calculate top-2 distances for real and test data in smaller batches + top_2_distances = compute_top_k_distances(target_data_batch, reference_tensor, self.norm, top_k=2) + ratios.append(top_2_distances[:, 0] / (top_2_distances[:, 1] + self.epsilon)) + + all_ratios = torch.cat(ratios) + mean_ratios = float(torch.mean(all_ratios).item()) + ratios_standard_error = torch.std(all_ratios).item() / math.sqrt(len(all_ratios)) + + return mean_ratios, ratios_standard_error diff --git a/src/midst_toolkit/evaluation/privacy/scripts/midst_dcr_score_eval.py b/src/midst_toolkit/evaluation/privacy/scripts/midst_dcr_score_eval.py index 3eb16cbb..300c409e 100644 --- a/src/midst_toolkit/evaluation/privacy/scripts/midst_dcr_score_eval.py +++ 
b/src/midst_toolkit/evaluation/privacy/scripts/midst_dcr_score_eval.py @@ -6,10 +6,8 @@ from midst_toolkit.common.logger import log from midst_toolkit.data_processing.midst_data_processing import load_midst_data_with_test -from midst_toolkit.evaluation.privacy.distance_closest_record import ( - DistanceToClosestRecordScore, - preprocess, -) +from midst_toolkit.evaluation.privacy.distance_closest_record import DistanceToClosestRecordScore +from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation # Killing a benign pandas warning @@ -108,7 +106,7 @@ real_data_train_path, synthetic_data_path, meta_info_path, real_data_test_path ) - synthetic_data, real_data_train, real_data_test = preprocess( + synthetic_data, real_data_train, real_data_test = preprocess_for_distance_computation( meta_info, synthetic_data, real_data_train, real_data_test ) metric = DistanceToClosestRecordScore() diff --git a/src/midst_toolkit/evaluation/quality/synthcity/metric.py b/src/midst_toolkit/evaluation/quality/synthcity/metric.py index d671f2e8..4e2ecb51 100644 --- a/src/midst_toolkit/evaluation/quality/synthcity/metric.py +++ b/src/midst_toolkit/evaluation/quality/synthcity/metric.py @@ -5,13 +5,11 @@ import numpy as np import torch +from midst_toolkit.common.variables import DEVICE from midst_toolkit.evaluation.quality.synthcity.dataloader import DataLoader from midst_toolkit.evaluation.quality.synthcity.one_class import OneClassLayer -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - class MetricEvaluator(metaclass=ABCMeta): def __init__( self, diff --git a/src/midst_toolkit/evaluation/utils.py b/src/midst_toolkit/evaluation/utils.py index fb6eb390..c4a8f6a2 100644 --- a/src/midst_toolkit/evaluation/utils.py +++ b/src/midst_toolkit/evaluation/utils.py @@ -53,8 +53,8 @@ def extract_columns_based_on_meta_info( Args: data: Dataframe to be filtered using the meta information meta_info: JSON with meta information about the columns and their corresponding types that should be - considered. At minimum, it should have the keys 'num_col_idx', 'cat_col_idx', 'target_col_idx', and - 'task_type' + considered. At minimum, it should have the keys 'num_col_idx', 'cat_col_idx'. If it also has a + 'target_col_idx' it must also specify a 'task_type'. Returns: Filtered dataframes. The first dataframe is the filtered set of columns associated with numerical data. The @@ -64,6 +64,7 @@ def extract_columns_based_on_meta_info( # Training the diffusion generators. 
     # Enumerate columns and replace column name with index
+    data = data.copy()
     data.columns = list(range(len(data.columns)))
 
     # Get numerical and categorical column indices from meta info
@@ -71,13 +72,14 @@
     numerical_column_idx = meta_info["num_col_idx"]
     categorical_column_idx = meta_info["cat_col_idx"]
 
-    # Target columns are also part of the generation, just need to add it to the right "category"
-    target_col_idx = meta_info["target_col_idx"]
-    task_type = TaskType(meta_info["task_type"])
-    if task_type == TaskType.REGRESSION:
-        numerical_column_idx = numerical_column_idx + target_col_idx
-    else:
-        categorical_column_idx = categorical_column_idx + target_col_idx
+    if "target_col_idx" in meta_info:
+        # Target columns are also part of the generation, just need to add it to the right "category"
+        target_col_idx = meta_info["target_col_idx"]
+        task_type = TaskType(meta_info["task_type"])
+        if task_type == TaskType.REGRESSION:
+            numerical_column_idx = numerical_column_idx + target_col_idx
+        else:
+            categorical_column_idx = categorical_column_idx + target_col_idx
 
     numerical_data = data[numerical_column_idx]
     categorical_data = data[categorical_column_idx]
diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py
index e088607a..dfe40210 100644
--- a/tests/integration/models/clavaddpm/test_model.py
+++ b/tests/integration/models/clavaddpm/test_model.py
@@ -397,7 +397,7 @@ def test_clustering_reload(tmp_path: Path):
     set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)
 
     # Act
-    tables, relation_order, dataset_meta = load_multi_table(Path("tests/integration/assets/multi_table/"))
+    tables, relation_order, _ = load_multi_table(Path("tests/integration/assets/multi_table/"))
     tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG)
 
     # Assert
diff --git a/tests/unit/data_processing/test_utils.py b/tests/unit/data_processing/test_utils.py
index 8ef3b4f0..b810aa3a 100644
--- a/tests/unit/data_processing/test_utils.py
+++ b/tests/unit/data_processing/test_utils.py
@@ -14,6 +14,13 @@
         "column_c": ["house", "cat", "cat", "car", "dog"],
         "column_d": [1, 1, 3, 2, 2],
         "column_e": [1.0, 3.0, 1.0, 2.0, 1.0],
+        "column_f": [
+            pd.Timestamp("2018-01-05"),
+            pd.Timestamp("2018-01-06"),
+            pd.Timestamp("2018-01-07"),
+            pd.Timestamp("2018-01-08"),
+            pd.Timestamp("2018-01-09"),
+        ],
     }
 )
 
@@ -24,15 +31,20 @@ def test_is_column_type_numerical() -> None:
     assert not is_column_type_numerical(TEST_DATAFRAME, "column_c")
     assert is_column_type_numerical(TEST_DATAFRAME, "column_d")
    assert is_column_type_numerical(TEST_DATAFRAME, "column_e")
+    assert not is_column_type_numerical(TEST_DATAFRAME, "column_f")
 
 
 def test_get_categorical_columns() -> None:
     # Low threshold
     categorical_columns = get_categorical_columns(TEST_DATAFRAME, 2)
+    # Note that this does not include the datetime column, which is not detected as categorical given how the
+    # detection algorithm currently functions.
     assert categorical_columns == ["column_c"]
 
     # Higher threshold
     categorical_columns = get_categorical_columns(TEST_DATAFRAME, 4)
+    # Note that this does not include the datetime column, which is not detected as categorical given how the
+    # detection algorithm currently functions.
     assert sorted(categorical_columns) == ["column_c", "column_d", "column_e"]
 
 
@@ -42,11 +54,13 @@ def test_infer_categorical_and_numerical_columns() -> None:
         TEST_DATAFRAME, categorical_threshold=2
     )
     assert categorical_columns == ["column_c"]
-    assert sorted(numerical_columns) == ["column_a", "column_b", "column_d", "column_e"]
+    # Note that this includes the datetime column among the numerical columns, since it is not detected as
+    # categorical given how the detection algorithm currently functions.
+    assert sorted(numerical_columns) == ["column_a", "column_b", "column_d", "column_e", "column_f"]
 
     # Higher threshold
     categorical_columns, numerical_columns = infer_categorical_and_numerical_columns(
         TEST_DATAFRAME, categorical_threshold=4
     )
     assert sorted(categorical_columns) == ["column_c", "column_d", "column_e"]
-    assert sorted(numerical_columns) == ["column_a", "column_b"]
+    assert sorted(numerical_columns) == ["column_a", "column_b", "column_f"]
diff --git a/tests/unit/evaluation/privacy/test_distance_closest_record.py b/tests/unit/evaluation/privacy/test_distance_closest_record.py
index bfa8df8d..7881ded1 100644
--- a/tests/unit/evaluation/privacy/test_distance_closest_record.py
+++ b/tests/unit/evaluation/privacy/test_distance_closest_record.py
@@ -8,8 +8,8 @@
 from midst_toolkit.evaluation.privacy.distance_closest_record import (
     DistanceToClosestRecordScore,
     minimum_distances,
-    preprocess,
 )
+from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation
 from midst_toolkit.evaluation.privacy.distance_utils import NormType
 
 
@@ -52,7 +52,7 @@ def test_dcr_score() -> None:
         REAL_DATA_TRAIN_PATH, SYNTHETIC_DATA_PATH, META_INFO_PATH, REAL_DATA_TEST_PATH
     )
 
-    synthetic_data, real_data_train, real_data_test = preprocess(
+    synthetic_data, real_data_train, real_data_test = preprocess_for_distance_computation(
         meta_info, synthetic_data, real_data_train, real_data_test
     )
     dcr_metric = DistanceToClosestRecordScore()
diff --git a/tests/unit/evaluation/privacy/test_epsilon_identifiability_risk.py b/tests/unit/evaluation/privacy/test_epsilon_identifiability_risk.py
new file mode 100644
index 00000000..c1839984
--- /dev/null
+++ b/tests/unit/evaluation/privacy/test_epsilon_identifiability_risk.py
@@ -0,0 +1,171 @@
+import pandas as pd
+import pytest
+
+from midst_toolkit.data_processing.midst_data_processing import load_midst_data_with_test
+from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation
+from midst_toolkit.evaluation.privacy.epsilon_identifiability_risk import (
+    EpsilonIdentifiabilityNorm,
+    EpsilonIdentifiabilityRisk,
+)
+
+
+REAL_DATA = pd.DataFrame(
+    {
+        "column_a": [1, 2, 3, 4, 5],
+        "column_b": [4, 5, 6, 7, 8],
+        "column_c": ["horse", "dog", "horse", "cat", "cat"],
+        "column_d": [-1, -2, -3, -2, -5],
+    }
+)
+
+SYNTHETIC_DATA = pd.DataFrame(
+    {
+        "column_a": [1, 2, 3, 4, 5],
+        "column_b": [4, 6, 6, -1, 1],
+        "column_c": ["cat", "dog", "horse", "cat", "cat"],
+        "column_d": [-1, -2, -3, -2, -5],
+    }
+)
+
+HOLDOUT_DATA = pd.DataFrame(
+    {
+        "column_a": [2, 3, 4, 5, 6],
+        "column_b": [4, 5, 6, 2, 3],
+        "column_c": ["cat", "dog", "horse", "cat", "cat"],
+        "column_d": [-1, -2, -3, -2, -5],
+    }
+)
+
+META_INFO = {
+    "num_col_idx": [0, 1, 3],
+    "cat_col_idx": [2],
+}
+
+
+SYNTHETIC_DATA_PATH = "tests/assets/synthetic_data_dcr.csv"
+REAL_DATA_TRAIN_PATH = "tests/assets/real_data_dcr.csv"
+REAL_DATA_TEST_PATH = "tests/assets/real_data_test.csv"
+META_INFO_PATH = "tests/assets/meta_info.json"
+
+
+def test_epsilon_identifiability_risk_small_data_l2() -> None:
+ eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=[], + numerical_columns=["column_a", "column_b", "column_d"], + norm=EpsilonIdentifiabilityNorm.L2, + ) + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA) + + assert len(results) == 1 + target = 3 / 5 + assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-5) == target + + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA, HOLDOUT_DATA) + + target_holdout = 5 / 5 + assert len(results) == 2 + assert pytest.approx(results["privacy_loss"], abs=1e-5) == target - target_holdout + + # Should get the same results even if we include cat columns, since we're using L2 + eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=["column_c"], + numerical_columns=["column_a", "column_b", "column_d"], + norm=EpsilonIdentifiabilityNorm.L2, + ) + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA) + + assert len(results) == 1 + target = 3 / 5 + assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-5) == target + + +def test_epsilon_identifiability_risk_small_data_gower() -> None: + eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=[], + numerical_columns=["column_a", "column_b", "column_d"], + norm=EpsilonIdentifiabilityNorm.GOWER, + ) + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA) + + assert len(results) == 1 + target = 5 / 5 + assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-5) == target + + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA, HOLDOUT_DATA) + + target_holdout = 5 / 5 + assert len(results) == 2 + assert pytest.approx(results["privacy_loss"], abs=1e-5) == target - target_holdout + + # Using Categorical columns too after preprocess + real_data, synthetic_data = preprocess_for_distance_computation(META_INFO, REAL_DATA, SYNTHETIC_DATA) + + eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=[3, 4, 5], + numerical_columns=[0, 1, 2], + norm=EpsilonIdentifiabilityNorm.GOWER, + ) + + results = eir_metric.compute(real_data, synthetic_data) + + assert len(results) == 1 + target = 4 / 5 + assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-5) == target + + +def test_epsilon_identifiability_risk_small_data_with_preprocess() -> None: + eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=["column_c"], + numerical_columns=["column_a", "column_b", "column_d"], + do_preprocess=True, + ) + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA) + + assert len(results) == 1 + target = 3 / 5 + assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-5) == target + + results = eir_metric.compute(REAL_DATA, SYNTHETIC_DATA, HOLDOUT_DATA) + target_holdout = 4 / 5 + + assert len(results) == 2 + assert pytest.approx(results["privacy_loss"], abs=1e-5) == target - target_holdout + + +def test_epsilon_identifiability_risk() -> None: + real_data, synthetic_data, holdout_data, meta_info = load_midst_data_with_test( + REAL_DATA_TRAIN_PATH, SYNTHETIC_DATA_PATH, META_INFO_PATH, REAL_DATA_TEST_PATH + ) + + synthetic_data, real_data, holdout_data = preprocess_for_distance_computation( + meta_info, synthetic_data, real_data, holdout_data + ) + + # After one-hot, we'll treat all the categoricals like numerical columns and leave off the target column + eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=[], + numerical_columns=list(meta_info["cat_col_idx"] + meta_info["num_col_idx"]), + norm=EpsilonIdentifiabilityNorm.L2, + ) + results = eir_metric.compute(real_data, synthetic_data, holdout_data) + + 
assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-8) == 0.21739130434782608 + assert pytest.approx(results["privacy_loss"], abs=1e-8) == 0.02006688963210701 + + +def test_epsilon_identifiability_risk_with_preprocess() -> None: + real_data, synthetic_data, holdout_data, meta_info = load_midst_data_with_test( + REAL_DATA_TRAIN_PATH, SYNTHETIC_DATA_PATH, META_INFO_PATH, REAL_DATA_TEST_PATH + ) + categorical_columns = [real_data.columns[i] for i in meta_info["cat_col_idx"] + meta_info["target_col_idx"]] + numerical_columns = [real_data.columns[i] for i in meta_info["num_col_idx"]] + eir_metric = EpsilonIdentifiabilityRisk( + categorical_columns=categorical_columns, + numerical_columns=numerical_columns, + do_preprocess=True, + norm=EpsilonIdentifiabilityNorm.GOWER, + ) + results = eir_metric.compute(real_data, synthetic_data, holdout_data) + + assert pytest.approx(results["epsilon_identifiability_risk"], abs=1e-8) == 0.46488294314381273 + assert pytest.approx(results["privacy_loss"], abs=1e-8) == 0.023411371237458234 diff --git a/tests/unit/evaluation/privacy/test_median_dcr.py b/tests/unit/evaluation/privacy/test_median_dcr.py index 09d90f52..280b363b 100644 --- a/tests/unit/evaluation/privacy/test_median_dcr.py +++ b/tests/unit/evaluation/privacy/test_median_dcr.py @@ -8,8 +8,8 @@ from midst_toolkit.evaluation.privacy.distance_closest_record import ( MedianDistanceToClosestRecordScore, minimum_distances, - preprocess, ) +from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation from midst_toolkit.evaluation.privacy.distance_utils import NormType @@ -49,7 +49,7 @@ def test_minimum_distance_l2_no_skip_diagonal() -> None: def test_median_dcr_score() -> None: real_data, synthetic_data, meta_info = load_midst_data(REAL_DATA_TRAIN_PATH, SYNTHETIC_DATA_PATH, META_INFO_PATH) - synthetic_data, real_data = preprocess(meta_info, synthetic_data, real_data) + synthetic_data, real_data = preprocess_for_distance_computation(meta_info, synthetic_data, real_data) dcr_metric = MedianDistanceToClosestRecordScore() dcr_score = dcr_metric.compute(real_data, synthetic_data) assert pytest.approx(dcr_score["median_dcr_score"], abs=1e-8) == 6.540543187576836 diff --git a/tests/unit/evaluation/privacy/test_nearest_neighbor_distance_ratio.py b/tests/unit/evaluation/privacy/test_nearest_neighbor_distance_ratio.py new file mode 100644 index 00000000..8f76d4ae --- /dev/null +++ b/tests/unit/evaluation/privacy/test_nearest_neighbor_distance_ratio.py @@ -0,0 +1,146 @@ +import pandas as pd +import pytest + +from midst_toolkit.data_processing.midst_data_processing import load_midst_data_with_test +from midst_toolkit.evaluation.privacy.distance_closest_record import NormType +from midst_toolkit.evaluation.privacy.distance_preprocess import preprocess_for_distance_computation +from midst_toolkit.evaluation.privacy.nearest_neighbor_distance_ratio import ( + NearestNeighborDistanceRatio, +) + + +REAL_DATA = pd.DataFrame( + { + "column_a": [1, 2, 3, 4, 5], + "column_b": [4, 5, 6, 7, 8], + "column_c": ["horse", "dog", "horse", "cat", "cat"], + "column_d": [-1, -2, -3, -2, -5], + } +) + +SYNTHETIC_DATA = pd.DataFrame( + { + "column_a": [1, 2, 3, 4, 5], + "column_b": [4, 6, 6, -1, 1], + "column_c": ["cat", "dog", "horse", "cat", "cat"], + "column_d": [-1, -2, -3, -2, -5], + } +) + +HOLDOUT_DATA = pd.DataFrame( + { + "column_a": [2, 3, 4, 5, 6], + "column_b": [4, 5, 6, 2, 3], + "column_c": ["cat", "dog", "horse", "cat", "cat"], + "column_d": [-1, -2, -3, -2, -5], + } +) + 
+META_INFO = {
+    "num_col_idx": [0, 1, 3],
+    "cat_col_idx": [2],
+}
+
+
+SYNTHETIC_DATA_PATH = "tests/assets/synthetic_data_dcr.csv"
+REAL_DATA_TRAIN_PATH = "tests/assets/real_data_dcr.csv"
+REAL_DATA_TEST_PATH = "tests/assets/real_data_test.csv"
+META_INFO_PATH = "tests/assets/meta_info.json"
+
+
+def test_nndr_score_small_data() -> None:
+    filtered_real_data = REAL_DATA[["column_a", "column_b", "column_d"]]
+    filtered_synthetic_data = SYNTHETIC_DATA[["column_a", "column_b", "column_d"]]
+    filtered_holdout_data = HOLDOUT_DATA[["column_a", "column_b", "column_d"]]
+
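+    # For each synthetic record, the NNDR is its distance to the nearest real record divided by its distance to the
+    # second-nearest real record; mean_nndr averages these ratios. The hand-computed target below applies that
+    # recipe to the five synthetic rows.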
pytest.approx(results["mean_nndr"], abs=1e-4) == target + + results = nndr_metric.compute(real_data, synthetic_data, holdout_data) + target_holdout = ( + (0.2500 / 1.1456) + (0.3536 / 1.5207) + (0.2500 / 1.4577) + (0.7906 / 1.3463) + (0.5590 / 0.7906) + ) / 5.0 + assert len(results) == 4 + assert pytest.approx(results["privacy_loss"], abs=1e-5) == target - target_holdout + + +def test_nndr_score() -> None: + real_data, synthetic_data, holdout_data, meta_info = load_midst_data_with_test( + REAL_DATA_TRAIN_PATH, SYNTHETIC_DATA_PATH, META_INFO_PATH, REAL_DATA_TEST_PATH + ) + + synthetic_data, real_data, holdout_data = preprocess_for_distance_computation( + meta_info, synthetic_data, real_data, holdout_data + ) + nndr_metric = NearestNeighborDistanceRatio() + results = nndr_metric.compute(real_data, synthetic_data, holdout_data) + assert pytest.approx(results["mean_nndr"], abs=1e-8) == 0.9782823717907417 + assert pytest.approx(results["privacy_loss"], abs=1e-8) == 0.005370743246908338 + + +def test_nndr_score_with_preprocess() -> None: + real_data, synthetic_data, holdout_data, meta_info = load_midst_data_with_test( + REAL_DATA_TRAIN_PATH, SYNTHETIC_DATA_PATH, META_INFO_PATH, REAL_DATA_TEST_PATH + ) + + # Preprocessing internally should return the same result + nndr_metric = NearestNeighborDistanceRatio(meta_info=meta_info, do_preprocess=True) + results = nndr_metric.compute(real_data, synthetic_data, holdout_data) + assert pytest.approx(results["mean_nndr"], abs=1e-8) == 0.9782823717907417 + assert pytest.approx(results["privacy_loss"], abs=1e-8) == 0.005370743246908338 diff --git a/tests/unit/evaluation/quality/test_mean_f1_score_difference.py b/tests/unit/evaluation/quality/test_mean_f1_score_difference.py index 97c62c4a..b015c4d7 100644 --- a/tests/unit/evaluation/quality/test_mean_f1_score_difference.py +++ b/tests/unit/evaluation/quality/test_mean_f1_score_difference.py @@ -58,7 +58,6 @@ def test_mean_f1_score_diff_with_preprocess() -> None: assert pytest.approx(0.49960000000000004, abs=1e-8) == score["svm_synthetic_train_f1"] assert pytest.approx(0.49720000000000003, abs=1e-8) == score["logreg_real_train_f1"] assert pytest.approx(0.49960000000000004, abs=1e-8) == score["logreg_synthetic_train_f1"] - unset_all_random_seeds()