From d5dfb997ee4ae0ed3be9955ceb45f36c09acddf1 Mon Sep 17 00:00:00 2001 From: Shadan Khan <140972016+shadankhan108@users.noreply.github.com> Date: Mon, 19 May 2025 01:34:44 +1000 Subject: [PATCH] Add files via upload Signed-off-by: Shadan Khan <140972016+shadankhan108@users.noreply.github.com> --- Models/pateGAN_letter.ipynb | 624 ++++++++++++++++++++++++++++++++++++ 1 file changed, 624 insertions(+) create mode 100644 Models/pateGAN_letter.ipynb diff --git a/Models/pateGAN_letter.ipynb b/Models/pateGAN_letter.ipynb new file mode 100644 index 00000000..b2f58cca --- /dev/null +++ b/Models/pateGAN_letter.ipynb @@ -0,0 +1,624 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d0a5476a-bcf0-40c7-99f4-22dce6d830c5", + "metadata": {}, + "source": [ + "## Loading the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "52c36154-7a0e-4cca-91ef-ce752c8cc32e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from loguru import logger\n", + "import sys\n", + "\n", + "logger.remove()\n", + "logger.add(sys.stdout, level=\"INFO\") # or \"DEBUG\" for even more\n" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "51a47467-849f-4020-93ee-3e3a8db61614", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " letter xbox ybox width height onpix xbar ybar x2bar y2bar xybar \\\n", + "0 B 4 2 5 4 4 8 7 6 6 7 \n", + "1 A 1 1 3 2 1 8 2 2 2 8 \n", + "2 B 5 9 7 7 10 9 8 4 4 6 \n", + "3 A 3 7 5 5 3 12 2 3 2 10 \n", + "4 A 3 8 5 6 3 9 2 2 3 8 \n", + "\n", + " x2ybar xy2bar xedge xedgey yedge yedgex \n", + "0 6 6 2 8 7 10 \n", + "1 2 8 1 6 2 7 \n", + "2 8 6 6 11 8 7 \n", + "3 2 9 2 6 3 8 \n", + "4 2 8 2 6 3 7 \n", + "\n", + "RangeIndex: 1555 entries, 0 to 1554\n", + "Data columns (total 17 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 letter 1555 non-null object\n", + " 1 xbox 1555 non-null int64 \n", + " 2 ybox 1555 non-null int64 \n", + " 3 width 1555 non-null int64 \n", + " 4 height 1555 non-null int64 \n", + " 5 onpix 1555 non-null int64 \n", + " 6 xbar 1555 non-null int64 \n", + " 7 ybar 1555 non-null int64 \n", + " 8 x2bar 1555 non-null int64 \n", + " 9 y2bar 1555 non-null int64 \n", + " 10 xybar 1555 non-null int64 \n", + " 11 x2ybar 1555 non-null int64 \n", + " 12 xy2bar 1555 non-null int64 \n", + " 13 xedge 1555 non-null int64 \n", + " 14 xedgey 1555 non-null int64 \n", + " 15 yedge 1555 non-null int64 \n", + " 16 yedgex 1555 non-null int64 \n", + "dtypes: int64(16), object(1)\n", + "memory usage: 206.7+ KB\n", + "None\n" + ] + } + ], + "source": [ + "\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "\n", + "csv_path = \"letter_a_b.csv\" \n", + "TARGET_COLUMN = 'letter'\n", + "df = pd.read_csv(csv_path)\n", + "\n", + "\n", + "df = df.dropna()\n", + "\n", + "print(df.head())\n", + "print(df.info())\n" + ] + }, + { + "cell_type": "markdown", + "id": "b58ec4d5-2bbc-4f51-a90c-ccb0d74259c5", + "metadata": {}, + "source": [ + "### Initialise PATE-GAN" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "7e943780-cdd3-49e8-a55a-803e8ff1f10c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2025-05-19 01:16:45.484\u001b[0m | \u001b[41m\u001b[1mCRITICAL\u001b[0m | 
\u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[41m\u001b[1mload failed: module 'synthcity.plugins.generic.plugin_great' has no attribute 'plugin'\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:45.485\u001b[0m | \u001b[41m\u001b[1mCRITICAL\u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[41m\u001b[1mload failed: module 'synthcity.plugins.generic.plugin_great' has no attribute 'plugin'\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:45.486\u001b[0m | \u001b[41m\u001b[1mCRITICAL\u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[41m\u001b[1mmodule plugin_great load failed\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:45.487\u001b[0m | \u001b[41m\u001b[1mCRITICAL\u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[41m\u001b[1mmodule disabled: /home/seyam-omar/jupyenv/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "from synthcity.plugins import Plugins\n",
+    "\n",
+    "pategan = Plugins().get(\n",
+    "    \"pategan\",\n",
+    "    # --- Differential Privacy ---\n",
+    "    epsilon=2.0,  # less strict for better utility\n",
+    "    delta=1e-5,\n",
+    "    lamda=0.01,  # PATE noise size ('lamda' is synthcity's spelling)\n",
+    "    n_teachers=1,\n",
+    "    teacher_template=\"xgboost\",\n",
+    "    clipping_value=1,\n",
+    "\n",
+    "    # --- Generator Config ---\n",
+    "    generator_n_layers_hidden=5,\n",
+    "    generator_n_units_hidden=256,\n",
+    "    generator_nonlin=\"leaky_relu\",\n",
+    "    generator_dropout=0.2,\n",
+    "    generator_n_iter=25,  # more updates per loop\n",
+    "\n",
+    "    # --- Discriminator Config ---\n",
+    "    discriminator_n_layers_hidden=5,\n",
+    "    discriminator_n_units_hidden=256,\n",
+    "    discriminator_nonlin=\"leaky_relu\",\n",
+    "    discriminator_dropout=0.2,\n",
+    "    discriminator_n_iter=25,\n",
+    "\n",
+    "    # --- Training ---\n",
+    "    n_iter=100,\n",
+    "    lr=1e-4,\n",
+    "    weight_decay=1e-5,\n",
+    "    batch_size=1024,\n",
+    "    random_state=42,\n",
+    "\n",
+    "    # --- Encoding ---\n",
+    "    encoder_max_clusters=10,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d63c3d06-2d3d-478c-937c-99c56ab48259",
+   "metadata": {},
+   "source": [
+    "### Fitting PATE-GAN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 75,
+   "id": "03c4d51d-08a4-4a91-bbbc-7f50810a8085",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2025-05-19 01:16:48.281\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1m[pategan] using delta = 1.6308139207105265e-05\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:48.286\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding letter 5767584641373809913\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:48.290\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding xbox 12695041101220943635\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:48.979\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding ybox 13921920367518933482\u001b[0m\n",
+      "\u001b[32m2025-05-19 01:16:49.676\u001b[0m | \u001b[1mINFO    \u001b[0m | 
\u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding width 6687162974100422244\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:49.955\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding height 7386505415593150607\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:49.962\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding onpix 9598955325056794481\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:50.487\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding xbar 8496188470491360011\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:50.496\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding ybar 5798536606865221659\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:50.654\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding x2bar 12008679059154499337\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.237\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding y2bar 5987834430205051737\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding xybar 8798025242305133887\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.249\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding x2ybar 4502825646803912971\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.256\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding xy2bar 9799609071340211071\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.262\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding xedge 8679319888590132864\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.268\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding xedgey 7574980522037697118\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.553\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding yedge 15830187168449519616\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:51.956\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mEncoding yedgex 5497041692268137304\u001b[0m\n", + "\u001b[32m2025-05-19 01:16:52.445\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mTraining GAN on device cuda. 
features = 169\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|███████████████████████████████████████████| 25/25 [03:10<00:00,  7.61s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2025-05-19 01:20:03.539\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36msynthcity.logger\u001b[0m:\u001b[36mlog_and_print\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1m[pategan it 1] epsilon_hat = 171.84304623692026. self.epsilon = 2.0\u001b[0m\n",
+      "Training complete. Fitted status: True\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from synthcity.plugins.core.dataloader import GenericDataLoader\n",
+    "from synthcity.plugins.core.schema import Schema\n",
+    "\n",
+    "# Define the target column\n",
+    "target_col = TARGET_COLUMN\n",
+    "\n",
+    "# Identify all feature columns\n",
+    "feature_cols = [col for col in df.columns if col != target_col]\n",
+    "\n",
+    "# Create a domain dictionary marking all features as numerical\n",
+    "domain = {col: \"numerical\" for col in feature_cols}\n",
+    "\n",
+    "# Initialize the schema with the domain\n",
+    "schema = Schema(domain=domain)\n",
+    "\n",
+    "# Create the data loader with the specified schema\n",
+    "dataloader = GenericDataLoader(df, target_column=target_col, schema=schema)\n",
+    "\n",
+    "# Fit the generator\n",
+    "pategan.fit(dataloader)\n",
+    "\n",
+    "print(\"Training complete. Fitted status:\", pategan.fitted)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "af4cad68-c9df-4021-b006-da725c67f7b3",
+   "metadata": {},
+   "source": [
+    "## Generate synthetic data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 76,
+   "id": "61fb3caf-f23f-488c-8094-34b670b2ad1c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Total synthetic rows: 776\n",
+      "Per-class counts:\n",
+      "letter\n",
+      "A    388\n",
+      "B    388\n",
+      "Name: count, dtype: int64\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pandas as pd\n",
+    "\n",
+    "TARGET_COLUMN = 'letter'\n",
+    "# --- Configuration ---\n",
+    "target_col = TARGET_COLUMN  # your real target column\n",
+    "total_samples = int(len(df) * 0.5)\n",
+    "unique_classes = df[target_col].unique()\n",
+    "num_classes = len(unique_classes)\n",
+    "desired_per_class = total_samples // num_classes\n",
+    "\n",
+    "# Accumulate per-class chunks here\n",
+    "accumulated = {cls: [] for cls in unique_classes}\n",
+    "\n",
+    "# You can tweak this batch size up or down\n",
+    "batch_size = desired_per_class * num_classes\n",
+    "\n",
+    "# Keep generating until every class has enough rows\n",
+    "# (assumes the generator eventually emits every class; otherwise this loop never terminates)\n",
+    "while True:\n",
+    "    # Check if we’re done\n",
+    "    if all(\n",
+    "        sum(len(chunk) for chunk in accumulated[cls]) >= desired_per_class\n",
+    "        for cls in unique_classes\n",
+    "    ):\n",
+    "        break\n",
+    "\n",
+    "    # Generate a fresh batch (unconditional)\n",
+    "    batch_loader = pategan.generate(count=batch_size)\n",
+    "    batch_df = batch_loader.dataframe()\n",
+    "\n",
+    "    # For each class, grab as many as still needed (if any)\n",
+    "    for cls in unique_classes:\n",
+    "        current_count = sum(len(chunk) for chunk in accumulated[cls])\n",
+    "        need = desired_per_class - current_count\n",
+    "        if need > 0:\n",
+    "            cls_rows = batch_df[batch_df[target_col] == cls]\n",
+    "            if not cls_rows.empty:\n",
+    "                # take up to `need` rows\n",
+    "                accumulated[cls].append(cls_rows.iloc[:need])\n",
+    "\n",
+    "# Stitch together exactly `desired_per_class` rows of each class\n",
+    "parts = []\n",
+    "for cls in unique_classes:\n",
+    "    cls_df = pd.concat(accumulated[cls], ignore_index=True)\n",
+    "    parts.append(cls_df)\n",
+    "\n",
+    "synthetic_df = pd.concat(parts, ignore_index=True)\n",
+    "synthetic_df = synthetic_df.sample(frac=1, random_state=42)  # shuffle\n",
+    "\n",
+    "# Verify\n",
+    "print(\"Total synthetic rows:\", len(synthetic_df))\n",
+    "print(\"Per-class counts:\")\n",
+    "print(synthetic_df[target_col].value_counts())\n"
+   ]
+  },
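+  {
+   "cell_type": "markdown",
+   "id": "added-sketch-01",
+   "metadata": {},
+   "source": [
+    "#### Added sketch: real vs. synthetic marginals\n",
+    "\n",
+    "A quick sanity check, not part of the original pipeline: compare per-column means and standard deviations of the real and synthetic tables. A minimal sketch, assuming `df`, `synthetic_df`, and `TARGET_COLUMN` from the cells above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "added-sketch-02",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged sketch: side-by-side marginal statistics (real vs. synthetic).\n",
+    "# Assumes df (real) and synthetic_df (generated above) share the same columns.\n",
+    "num_cols = [c for c in df.columns if c != TARGET_COLUMN]\n",
+    "\n",
+    "comparison = pd.concat(\n",
+    "    {\n",
+    "        \"real\": df[num_cols].describe().T[[\"mean\", \"std\"]],\n",
+    "        \"synthetic\": synthetic_df[num_cols].describe().T[[\"mean\", \"std\"]],\n",
+    "    },\n",
+    "    axis=1,\n",
+    ")\n",
+    "print(comparison.round(3))\n"
+   ]
+  },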
+  {
+   "cell_type": "markdown",
+   "id": "89faf44c-ce1d-4f38-80cd-900d7aa9d38d",
+   "metadata": {},
+   "source": [
+    "### TSTR"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 77,
+   "id": "b059ae93-2ed1-4d33-a082-eb35d3d3bbe2",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Class mapping: {'A': 0, 'B': 1}\n",
+      "\n",
+      "--- 3-Fold Cross-Validation Results (All Folds) ---\n",
+      "\n",
+      "Logistic Regression - Per Fold Scores:\n",
+      "  Fold 1:\n",
+      "    Accuracy: 0.5174\n",
+      "    Precision: 0.5176\n",
+      "    Recall: 0.5174\n",
+      "    F1: 0.5169\n",
+      "  Fold 2:\n",
+      "    Accuracy: 0.4981\n",
+      "    Precision: 0.4979\n",
+      "    Recall: 0.4981\n",
+      "    F1: 0.4969\n",
+      "  Fold 3:\n",
+      "    Accuracy: 0.4574\n",
+      "    Precision: 0.4569\n",
+      "    Recall: 0.4574\n",
+      "    F1: 0.4558\n",
+      "\n",
+      "Random Forest - Per Fold Scores:\n",
+      "  Fold 1:\n",
+      "    Accuracy: 0.4865\n",
+      "    Precision: 0.4865\n",
+      "    Recall: 0.4865\n",
+      "    F1: 0.4865\n",
+      "  Fold 2:\n",
+      "    Accuracy: 0.4672\n",
+      "    Precision: 0.4672\n",
+      "    Recall: 0.4672\n",
+      "    F1: 0.4667\n",
+      "  Fold 3:\n",
+      "    Accuracy: 0.4612\n",
+      "    Precision: 0.4606\n",
+      "    Recall: 0.4612\n",
+      "    F1: 0.4589\n",
+      "\n",
+      "XGBoost - Per Fold Scores:\n",
+      "  Fold 1:\n",
+      "    Accuracy: 0.4865\n",
+      "    Precision: 0.4864\n",
+      "    Recall: 0.4865\n",
+      "    F1: 0.4863\n",
+      "  Fold 2:\n",
+      "    Accuracy: 0.5290\n",
+      "    Precision: 0.5290\n",
+      "    Recall: 0.5290\n",
+      "    F1: 0.5289\n",
+      "  Fold 3:\n",
+      "    Accuracy: 0.4690\n",
+      "    Precision: 0.4689\n",
+      "    Recall: 0.4690\n",
+      "    F1: 0.4686\n",
+      "\n",
+      "MLP - Per Fold Scores:\n",
+      "  Fold 1:\n",
+      "    Accuracy: 0.4788\n",
+      "    Precision: 0.4787\n",
+      "    Recall: 0.4788\n",
+      "    F1: 0.4771\n",
+      "  Fold 2:\n",
+      "    Accuracy: 0.5097\n",
+      "    Precision: 0.5104\n",
+      "    Recall: 0.5097\n",
+      "    F1: 0.5052\n",
+      "  Fold 3:\n",
+      "    Accuracy: 0.5465\n",
+      "    Precision: 0.5466\n",
+      "    Recall: 0.5465\n",
+      "    F1: 0.5463\n",
+      "\n",
+      "Cross-Validation Summary (Averaged Across Folds):\n",
+      "                     accuracy  precision    recall        f1\n",
+      "Logistic Regression  0.490936   0.490772  0.490936  0.489839\n",
+      "Random Forest        0.471636   0.471403  0.471636  0.470699\n",
+      "XGBoost              0.494812   0.494761  0.494812  0.494600\n",
+      "MLP                  0.511643   0.511899  0.511643  0.509566\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from sklearn.model_selection import cross_validate\n",
+    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
+    "from sklearn.neural_network import MLPClassifier\n",
+    "from sklearn.ensemble import RandomForestClassifier\n",
+    "from sklearn.linear_model import LogisticRegression\n",
+    "import xgboost as xgb\n",
+    "from sklearn.pipeline import Pipeline\n",
+    "\n",
+    "# --- Configuration ---\n",
+    "RANDOM_STATE = 42\n",
+    "\n",
+    "# Assume synthetic_df is already loaded in the environment\n",
+    "# Validate target column\n",
+    "if TARGET_COLUMN not in synthetic_df.columns:\n",
+    "    raise ValueError(f\"Target column '{TARGET_COLUMN}' not found.\")\n",
+    "\n",
+    "# Separate features and target\n",
+    "X_original = synthetic_df.drop(TARGET_COLUMN, axis=1)\n",
+    "y_original = synthetic_df[TARGET_COLUMN]\n",
+    "\n",
+    "# Encode target\n",
+    "le = LabelEncoder()\n",
+    "y = le.fit_transform(y_original)\n",
+    "print(f\"Class mapping: {dict(zip(le.classes_, le.transform(le.classes_)))}\")\n",
+    "\n",
+    "# One-hot encode any non-numeric feature columns\n",
+    "non_numeric_cols = X_original.select_dtypes(exclude=['number', 'bool']).columns\n",
+    "if len(non_numeric_cols) > 0:\n",
+    "    X = pd.get_dummies(X_original, columns=non_numeric_cols, drop_first=True, dummy_na=False)\n",
+    "else:\n",
+    "    X = X_original\n",
+    "\n",
+    "# Define models\n",
+    "models = {\n",
+    "    \"Logistic Regression\": LogisticRegression(random_state=RANDOM_STATE, max_iter=1000),\n",
+    "    \"Random Forest\": RandomForestClassifier(random_state=RANDOM_STATE, n_estimators=100),\n",
+    "    # use_label_encoder is a deprecated no-op on recent xgboost releases\n",
+    "    \"XGBoost\": xgb.XGBClassifier(random_state=RANDOM_STATE, use_label_encoder=False, eval_metric='logloss'),\n",
+    "    \"MLP\": MLPClassifier(random_state=RANDOM_STATE, max_iter=500, early_stopping=True, hidden_layer_sizes=(64, 32))\n",
+    "}\n",
+    "\n",
+    "# 3-fold cross-validation on the synthetic data; each pipeline scales inside\n",
+    "# the fold, so no separate holdout split or pre-scaling is needed\n",
+    "print(\"\\n--- 3-Fold Cross-Validation Results (All Folds) ---\")\n",
+    "pipelines = {\n",
+    "    name: Pipeline([('scaler', StandardScaler()), ('clf', model)])\n",
+    "    for name, model in models.items()\n",
+    "}\n",
+    "scoring = {\n",
+    "    'accuracy': 'accuracy',\n",
+    "    'precision': 'precision_weighted',\n",
+    "    'recall': 'recall_weighted',\n",
+    "    'f1': 'f1_weighted',\n",
+    "}\n",
+    "cv_results = {}\n",
+    "for name, pipe in pipelines.items():\n",
+    "    scores = cross_validate(pipe, X, y, cv=3, scoring=scoring, return_train_score=False)\n",
+    "\n",
+    "    print(f\"\\n{name} - Per Fold Scores:\")\n",
+    "    for i in range(3):  # 3 folds\n",
+    "        print(f\"  Fold {i+1}:\")\n",
+    "        for metric in scoring:\n",
+    "            score_value = scores[f'test_{metric}'][i]\n",
+    "            print(f\"    {metric.capitalize()}: {score_value:.4f}\")\n",
+    "\n",
+    "    avg_scores = {metric: np.mean(scores[f'test_{metric}']) for metric in scoring}\n",
+    "    cv_results[name] = avg_scores\n",
+    "\n",
+    "cv_df = pd.DataFrame(cv_results).T\n",
+    "print(\"\\nCross-Validation Summary (Averaged Across Folds):\")\n",
+    "print(cv_df)\n"
+   ]
+  },
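+  {
+   "cell_type": "markdown",
+   "id": "added-sketch-03",
+   "metadata": {},
+   "source": [
+    "#### Added sketch: TSTR proper (train on synthetic, test on real)\n",
+    "\n",
+    "The cross-validation above both fits and scores within `synthetic_df`. A minimal sketch of the train-on-synthetic, test-on-real protocol itself is given below; it assumes `df`, `synthetic_df`, and `TARGET_COLUMN` from earlier cells, and logistic regression stands in for any of the four models."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "added-sketch-04",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged TSTR sketch: fit on the synthetic rows, score on the real rows.\n",
+    "from sklearn.linear_model import LogisticRegression\n",
+    "from sklearn.metrics import accuracy_score, f1_score\n",
+    "from sklearn.pipeline import Pipeline\n",
+    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
+    "\n",
+    "feat_cols = [c for c in df.columns if c != TARGET_COLUMN]\n",
+    "\n",
+    "# One encoder for both frames so class labels map identically\n",
+    "le_tstr = LabelEncoder().fit(df[TARGET_COLUMN])\n",
+    "y_syn = le_tstr.transform(synthetic_df[TARGET_COLUMN])\n",
+    "y_real = le_tstr.transform(df[TARGET_COLUMN])\n",
+    "\n",
+    "tstr_model = Pipeline([\n",
+    "    (\"scaler\", StandardScaler()),\n",
+    "    (\"clf\", LogisticRegression(max_iter=1000, random_state=42)),\n",
+    "]).fit(synthetic_df[feat_cols], y_syn)\n",
+    "\n",
+    "y_pred = tstr_model.predict(df[feat_cols])\n",
+    "print(\"TSTR accuracy:\", round(accuracy_score(y_real, y_pred), 4))\n",
+    "print(\"TSTR weighted F1:\", round(f1_score(y_real, y_pred, average=\"weighted\"), 4))\n"
+   ]
+  },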
ValueError(f\"Target column '{TARGET_COLUMN}' not found.\")\n", + "\n", + "# Separate features and target\n", + "X_original = synthetic_df.drop(TARGET_COLUMN, axis=1)\n", + "y_original = synthetic_df[TARGET_COLUMN]\n", + "\n", + "# Encode target\n", + "le = LabelEncoder()\n", + "y = le.fit_transform(y_original)\n", + "print(f\"Class mapping: {dict(zip(le.classes_, le.transform(le.classes_)))}\")\n", + "\n", + "# Encode features\n", + "non_numeric_cols = X_original.select_dtypes(exclude=['number', 'bool']).columns\n", + "if len(non_numeric_cols) > 0:\n", + " X = pd.get_dummies(X_original, columns=non_numeric_cols, drop_first=True, dummy_na=False)\n", + "else:\n", + " X = X_original\n", + "\n", + "# Train-test split\n", + "try:\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=y\n", + " )\n", + "except ValueError:\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE\n", + " )\n", + "\n", + "# Scale features\n", + "scaler = StandardScaler()\n", + "X_train_scaled = scaler.fit_transform(X_train)\n", + "X_test_scaled = scaler.transform(X_test)\n", + "\n", + "# Define models\n", + "models = {\n", + " \"Logistic Regression\": LogisticRegression(random_state=RANDOM_STATE, max_iter=1000),\n", + " \"Random Forest\": RandomForestClassifier(random_state=RANDOM_STATE, n_estimators=100),\n", + " \"XGBoost\": xgb.XGBClassifier(random_state=RANDOM_STATE, use_label_encoder=False, eval_metric='logloss'),\n", + " \"MLP\": MLPClassifier(random_state=RANDOM_STATE, max_iter=500, early_stopping=True, hidden_layer_sizes=(64, 32))\n", + "}\n", + "\n", + "# 3-Fold Cross-Validation\n", + "print(\"\\n--- 3-Fold Cross-Validation Results (All Folds) ---\")\n", + "pipelines = {\n", + " name: Pipeline([('scaler', StandardScaler()), ('clf', model)])\n", + " for name, model in models.items()\n", + "}\n", + "scoring = {\n", + " 'accuracy': 'accuracy',\n", + " 'precision': 'precision_weighted',\n", + " 'recall': 'recall_weighted',\n", + " 'f1': 'f1_weighted',\n", + "}\n", + "cv_results = {}\n", + "for name, pipe in pipelines.items():\n", + " scores = cross_validate(pipe, X, y, cv=3, scoring=scoring, return_train_score=False)\n", + " \n", + " print(f\"\\n{name} - Per Fold Scores:\")\n", + " for i in range(3): # 3 folds\n", + " print(f\" Fold {i+1}:\")\n", + " for metric in scoring:\n", + " score_value = scores[f'test_{metric}'][i]\n", + " print(f\" {metric.capitalize()}: {score_value:.4f}\")\n", + " \n", + " avg_scores = {metric: np.mean(scores[f'test_{metric}']) for metric in scoring}\n", + " cv_results[name] = avg_scores\n", + "\n", + "cv_df = pd.DataFrame(cv_results).T\n", + "print(\"\\nCross-Validation Summary (Averaged Across Folds):\")\n", + "print(cv_df)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "396c9370-dd1d-4be9-aff8-43faf32961ed", + "metadata": {}, + "source": [ + "## JSD and WD" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "d79d9e26-8dcb-4fcf-ae5b-02d913328665", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average JSD across all columns: 0.018843\n", + "Average WD across all columns: 0.114645\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from scipy.spatial.distance import jensenshannon\n", + "\n", + "# --- Compute JSD and WD for each categorical column ---\n", + "jsd_results = {}\n", + "wd_results = {}\n", + "\n", + "for col in 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "679712fb-cfad-4dad-89dc-c7428063cd3e",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}