generated from VectorInstitute/aieng-template-uv
-
Notifications
You must be signed in to change notification settings - Fork 1
Ensemble attack: Meta classifier pipeline #37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 60 commits
Commits
Show all changes
62 commits
Select commit
Hold shift + click to select a range
8b7a781
Data processing and shadow models
fatemetkl 71178c0
Merged main
fatemetkl 5e4e3d9
mypy fixes
fatemetkl f0a849c
Removed rmia component
fatemetkl 7c4d47c
Added data collection, processing, and tests
fatemetkl 4f963b3
Small fixes
fatemetkl 7746ab3
Updated README
fatemetkl a812665
Merged main
fatemetkl a352d36
Removed some parts that should go with the next PR
fatemetkl bb662c7
Small fixes
fatemetkl ca0e2d5
Simplifying mypy legacy type check to only check .py files
fatemetkl ba22d92
Added an example for MIDST competition ensemble attack, updates tests
fatemetkl 6a6b99d
Fixed docstrings
fatemetkl 8e764e1
fix
fatemetkl e740b76
mypy fixes
fatemetkl 8574d22
mypy fix
fatemetkl aec7458
Added hydra and omegaconf to pyproject.toml
fatemetkl 5516989
David's comments, added a simple bash script to example
fatemetkl be2ac7d
Improved comments and function name
fatemetkl 3dadde3
Updated readme with diagrams
fatemetkl 0442cbb
Updated readme
fatemetkl aad598b
Updated readme
fatemetkl 86d1e0a
Modify file structure
fb3036a
Modify basic file structure
cfa821b
Modify high-level pipeline. (Needs a lot of cleanup.)
sarakodeiri eb38a75
Add DOMIAS calculation (#31)
sarakodeiri a6c06dd
Add Gower distance (#32)
sarakodeiri 653d9ac
Add gower to uv
sarakodeiri 62d0c1f
Metaclassifier Training (#36)
sarakodeiri cdf9f4a
Add predict and TPR@FPR
sarakodeiri 91b827e
Fix uv packages
sarakodeiri 4f1c55c
Add tests
sarakodeiri da4e1f8
Minor cleanup
sarakodeiri b655f6f
Merge branch 'main' into sk/meta_classifier
sarakodeiri ad59425
Add packages
sarakodeiri afce376
Merge branch 'sk/meta_classifier' of https://github.com/VectorInstitu…
sarakodeiri f3c2aee
Fix end of file (.gitignore)
sarakodeiri 377cb14
Merge branch 'main' into sk/meta_classifier
sarakodeiri 1ff1ec1
Change run script format
sarakodeiri d26a22e
Fix XGBoost Docstrings
sarakodeiri 421e045
Merge remote-tracking branch 'origin/main' into sk/meta_classifier
sarakodeiri 6feaf63
Add meta classifier type enum
sarakodeiri 02eb5ae
Resolved Marcelo's comments
sarakodeiri 0710806
Fix tests
sarakodeiri 5fecc68
Update
sarakodeiri 605052b
Merge branch 'main' into sk/meta_classifier
sarakodeiri e116f9d
Fix build
sarakodeiri 4f9b5b1
Remove ensemble_attack_examples
sarakodeiri 4ec2631
Ruff fix
sarakodeiri a068418
Apply David's comments
sarakodeiri b77b802
Apply David's comments, pt. 2.
sarakodeiri a2837ae
Apply Fatemeh's comments
sarakodeiri 418d65b
Remove bounds and col_type
sarakodeiri 2c034cc
Merge branch 'main' into sk/meta_classifier
sarakodeiri 651b2e2
Remove xgboost.py
sarakodeiri 129699d
Expand comment
sarakodeiri ac4e5f5
Merge branch 'main' into sk/meta_classifier
emersodb 5f17c07
Merge branch 'sk/meta_classifier' of https://github.com/VectorInstitu…
sarakodeiri 6bfc6c9
Resolve some comments
sarakodeiri d38ef66
Resolve all comments
sarakodeiri a63604d
Modify test
sarakodeiri 7572177
Minor fixes
sarakodeiri File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| { | ||
| "numerical": ["trans_date", "amount", "balance", "account"], | ||
| "categorical": ["trans_type", "operation", "k_symbol", "bank"], | ||
| "variable_to_predict": "trans_type" | ||
| } | ||
|
|
||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,44 +3,147 @@ | |
| provided resources and data. | ||
| """ | ||
|
|
||
| import pickle | ||
| from datetime import datetime | ||
| from logging import INFO | ||
| from pathlib import Path | ||
|
|
||
| import hydra | ||
| import numpy as np | ||
| from omegaconf import DictConfig | ||
|
|
||
| from examples.ensemble_attack.real_data_collection import collect_population_data_ensemble | ||
| from midst_toolkit.attacks.ensemble.blending import BlendingPlusPlus, MetaClassifierType | ||
| from midst_toolkit.attacks.ensemble.data_utils import load_dataframe | ||
| from midst_toolkit.attacks.ensemble.process_split_data import process_split_data | ||
| from midst_toolkit.common.logger import log | ||
|
|
||
|
|
||
| def run_data_processing(config: DictConfig) -> None: | ||
| """ | ||
| Function to run the data processing pipeline. | ||
|
|
||
| Args: | ||
| config: Configuration object set in config.yaml. | ||
| """ | ||
| log(INFO, "Running data processing pipeline...") | ||
| # Collect the real data from the MIDST challenge resources. | ||
| population_data = collect_population_data_ensemble( | ||
| midst_data_input_dir=Path(config.data_paths.midst_data_path), | ||
| data_processing_config=config.data_processing_config, | ||
| save_dir=Path(config.data_paths.population_path), | ||
| ) | ||
| # The following function saves the required dataframe splits in the specified processed_attack_data_path path. | ||
| process_split_data( | ||
| all_population_data=population_data, | ||
| processed_attack_data_path=Path(config.data_paths.processed_attack_data_path), | ||
| # TODO: column_to_stratify value is not documented in the original codebase. | ||
| column_to_stratify=config.data_processing_config.column_to_stratify, | ||
| num_total_samples=config.data_processing_config.population_sample_size, | ||
| random_seed=config.random_seed, | ||
| ) | ||
| log(INFO, "Data processing pipeline finished.") | ||
|
|
||
|
|
||
| def run_metaclassifier_training(config: DictConfig) -> None: | ||
| """ | ||
| Fuction to run the metaclassifier training and evaluation. | ||
|
|
||
| Args: | ||
| config: Configuration object set in config.yaml. | ||
| """ | ||
| log(INFO, "Running metaclassifier training...") | ||
| # Load the processed data splits. | ||
| df_meta_train = load_dataframe( | ||
| Path(config.data_paths.processed_attack_data_path), | ||
| "master_challenge_train.csv", | ||
| ) | ||
| y_meta_train = np.load( | ||
| Path(config.data_paths.processed_attack_data_path) / "master_challenge_train_labels.npy", | ||
| ) | ||
| df_meta_test = load_dataframe( | ||
| Path(config.data_paths.processed_attack_data_path), | ||
| "master_challenge_test.csv", | ||
| ) | ||
| y_meta_test = np.load( | ||
| Path(config.data_paths.processed_attack_data_path) / "master_challenge_test_labels.npy", | ||
| ) | ||
|
|
||
| # Synthetic data borrowed from the attack implementation repository. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd add a link to where this was borrowed from. |
||
| # From (https://github.com/CRCHUM-CITADEL/ensemble-mia/tree/main/input/tabddpm_black_box/meta_classifier) | ||
| # TODO: Change this file path to the path where the synthetic data is stored. | ||
| df_synthetic = load_dataframe( | ||
| Path(config.data_paths.processed_attack_data_path), | ||
| "synth.csv", | ||
| ) | ||
|
|
||
| df_reference = load_dataframe( | ||
| Path(config.data_paths.population_path), | ||
| "population_all_with_challenge_no_id.csv", | ||
| ) | ||
|
|
||
| # Fit the metaclassifier. | ||
| meta_classifier_enum = MetaClassifierType(config.metaclassifier.model_type) | ||
|
|
||
| # 1. Initialize the attacker | ||
| blending_attacker = BlendingPlusPlus( | ||
| config=config, meta_classifier_type=meta_classifier_enum, random_seed=config.random_seed | ||
| ) | ||
| log(INFO, f"{meta_classifier_enum} created with random seed {config.random_seed}, starting training...") | ||
|
|
||
| # 2. Train the attacker on the meta-train set | ||
|
|
||
| blending_attacker.fit( | ||
| df_train=df_meta_train, | ||
| y_train=y_meta_train, | ||
| df_synthetic=df_synthetic, | ||
| df_reference=df_reference, | ||
| use_gpu=config.metaclassifier.use_gpu, | ||
| epochs=config.metaclassifier.epochs, | ||
| ) | ||
|
|
||
| log(INFO, "Metaclassifier training finished.") | ||
|
|
||
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | ||
| model_filename = f"{timestamp}_{config.metaclassifier.model_type}_trained_metaclassifier.pkl" | ||
| with open(Path(config.model_paths.metaclassifier_model_path) / model_filename, "wb") as f: | ||
| pickle.dump(blending_attacker.trained_model, f) | ||
|
|
||
| log(INFO, "Metaclassifier model saved, starting evaluation...") | ||
|
|
||
| # 3. Get predictions on the test set | ||
| probabilities, pred_score = blending_attacker.predict( | ||
| df_test=df_meta_test, | ||
| df_synthetic=df_synthetic, | ||
| df_reference=df_reference, | ||
| y_test=y_meta_test, | ||
| ) | ||
|
|
||
| # Save the prediction probabilities | ||
| np.save( | ||
| Path(config.data_paths.attack_results_path) | ||
| / f"{timestamp}_{config.metaclassifier.model_type}_test_pred_proba.npy", | ||
| probabilities, | ||
| ) | ||
| log(INFO, "Test set prediction probabilities saved.") | ||
|
|
||
| if pred_score is not None: | ||
| log(INFO, f"TPR at FPR=0.1: {pred_score:.4f}") | ||
|
|
||
|
|
||
| @hydra.main(config_path=".", config_name="config", version_base=None) | ||
| def main(cfg: DictConfig) -> None: | ||
| def main(config: DictConfig) -> None: | ||
| """ | ||
| Run the Ensemble Attack example pipeline. | ||
| As the first step, data processing is done. | ||
|
|
||
| Args: | ||
| cfg: Attack OmegaConf DictConfig object. | ||
| config: Attack configuration as an OmegaConf DictConfig object. | ||
| """ | ||
| if cfg.pipeline.run_data_processing: | ||
| log(INFO, "Running data processing pipeline...") | ||
| # Collect the real data from the MIDST challenge resources. | ||
| population_data = collect_population_data_ensemble( | ||
| midst_data_input_dir=Path(cfg.data_paths.midst_data_path), | ||
| data_processing_config=cfg.data_processing_config, | ||
| save_dir=Path(cfg.data_paths.population_path), | ||
| ) | ||
| # The following function saves the required dataframe splits in the specified processed_attack_data_path path. | ||
| process_split_data( | ||
| all_population_data=population_data, | ||
| processed_attack_data_path=Path(cfg.data_paths.processed_attack_data_path), | ||
| # TODO: column_to_stratify value is not documented in the original codebase. | ||
| column_to_stratify=cfg.data_processing_config.column_to_stratify, | ||
| num_total_samples=cfg.data_processing_config.population_sample_size, | ||
| random_seed=cfg.random_seed, | ||
| ) | ||
| log(INFO, "Data processing pipeline finished.") | ||
| if config.pipeline.run_data_processing: | ||
| run_data_processing(config) | ||
| if config.pipeline.run_metaclassifier_training: | ||
| run_metaclassifier_training(config) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this is not true anymore? I see docstrings in that function right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think by "original codebase", @fatemetkl meant the submission repository (link), but I'm not sure what the "TODO" is for. My guess is we test things with stratified columns specified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
The "original codebase" in this comment refers to the attack submission. I should have added a link. So sorry for the confusion! I will fix it in my PR.
As Sara mentioned, since this parameter wasn’t documented in the original attack codebase, I added a TODO to experiment with other columns in case the one I specified isn’t the correct one.