Ensemble attack: Meta classifier pipeline #37
base: main
Changes from all commits
@@ -3,44 +3,146 @@
provided resources and data.
"""

import pickle
from datetime import datetime
from logging import INFO
from pathlib import Path

import hydra
import numpy as np
from omegaconf import DictConfig

from examples.ensemble_attack.real_data_collection import collect_population_data_ensemble
from midst_toolkit.attacks.ensemble.blending import BlendingPlusPlus, MetaClassifierType
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
from midst_toolkit.attacks.ensemble.process_split_data import process_split_data
from midst_toolkit.common.logger import log


def run_data_processing(config: DictConfig) -> None:
    """
    Run the data processing pipeline.

    Args:
        config: Configuration object set in config.yaml.
    """
    log(INFO, "Running data processing pipeline...")
    # Collect the real data from the MIDST challenge resources.
    population_data = collect_population_data_ensemble(
        midst_data_input_dir=Path(config.data_paths.midst_data_path),
        data_processing_config=config.data_processing_config,
        save_dir=Path(config.data_paths.population_path),
    )
    # The following function saves the required dataframe splits in the specified
    # processed_attack_data_path directory.
    process_split_data(
        all_population_data=population_data,
        processed_attack_data_path=Path(config.data_paths.processed_attack_data_path),
        # TODO: column_to_stratify value is not documented in the original codebase.
        column_to_stratify=config.data_processing_config.column_to_stratify,
        num_total_samples=config.data_processing_config.population_sample_size,
        random_seed=config.random_seed,
    )
    log(INFO, "Data processing pipeline finished.")

Review thread on the column_to_stratify TODO:
- "I believe this is not true anymore? I see docstrings in that function right now."
- "I think by 'original codebase', @fatemetkl meant the submission repository (link), but I'm not sure what the TODO is for. My guess is we test things with stratified columns specified."
- "Hi, as Sara mentioned, since this parameter wasn't documented in the original attack codebase, I added a TODO to experiment with other columns in case the one I specified isn't the correct one."
def run_metaclassifier_training(config: DictConfig) -> None:
    """
    Run the metaclassifier training and evaluation.

    Args:
        config: Configuration object set in config.yaml.
    """
    log(INFO, "Running metaclassifier training...")
    # Load the processed data splits.
    df_meta_train = load_dataframe(
        Path(config.data_paths.processed_attack_data_path),
        "master_challenge_train.csv",
    )
    y_meta_train = np.load(
        Path(config.data_paths.processed_attack_data_path) / "master_challenge_train_labels.npy",
    )
    df_meta_test = load_dataframe(
        Path(config.data_paths.processed_attack_data_path),
        "master_challenge_test.csv",
    )
    y_meta_test = np.load(
        Path(config.data_paths.processed_attack_data_path) / "master_challenge_test_labels.npy",
    )

    # Synthetic data borrowed from the attack implementation repository.
    # TODO: Change this file path to the path where the synthetic data is stored.
    df_synthetic = load_dataframe(
        Path(config.data_paths.processed_attack_data_path),
        "synth.csv",
    )

    df_reference = load_dataframe(
        Path(config.data_paths.population_path),
        "population_all_with_challenge_no_id.csv",
    )

    # Set up and fit the metaclassifier.
    meta_classifier_enum = MetaClassifierType(config.metaclassifier.model_type)

    # 1. Initialize the attacker.
    blending_attacker = BlendingPlusPlus(
        data_configs=config.data_configs,
        meta_classifier_type=meta_classifier_enum,
        random_seed=config.random_seed,
    )
    log(INFO, f"{meta_classifier_enum} created with random seed {config.random_seed}, starting training...")

    # 2. Train the attacker on the meta-train set.
    blending_attacker.fit(
        df_train=df_meta_train,
        y_train=y_meta_train,
        df_synthetic=df_synthetic,
        df_reference=df_reference,
        use_gpu=config.metaclassifier.use_gpu,
        epochs=config.metaclassifier.epochs,
    )

    log(INFO, "Metaclassifier training finished.")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"{timestamp}_{config.metaclassifier.model_type}_trained_metaclassifier.pkl"
    with open(Path(config.model_paths.metaclassifier_model_path) / model_filename, "wb") as f:
        pickle.dump(blending_attacker.trained_model, f)

    log(INFO, "Metaclassifier model saved, starting evaluation...")
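
Aside (not part of the diff): reloading the pickled metaclassifier later is symmetric. A minimal sketch, with the directory and timestamped filename as placeholders for the names produced above:

import pickle
from pathlib import Path

model_path = Path("results/models") / "20250101_000000_xgboost_trained_metaclassifier.pkl"  # hypothetical
with open(model_path, "rb") as f:
    trained_model = pickle.load(f)  # the object stored from blending_attacker.trained_model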
|
||
# 3. Get predictions on the test set | ||
probabilities, pred_score = blending_attacker.predict( | ||
df_test=df_meta_test, | ||
df_synthetic=df_synthetic, | ||
df_reference=df_reference, | ||
y_test=y_meta_test, | ||
) | ||
|
||
# Save the prediction probabilities | ||
np.save( | ||
Path(config.data_paths.attack_results_path) | ||
/ f"{timestamp}_{config.metaclassifier.model_type}_test_pred_proba.npy", | ||
probabilities, | ||
) | ||
log(INFO, "Test set prediction probabilities saved.") | ||
|
||
if pred_score is not None: | ||
log(INFO, f"TPR at FPR=0.1: {pred_score:.4f}") | ||
|
||
|
||
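
For readers unfamiliar with the metric: TPR at FPR=0.1 is the true-positive rate of the membership predictions at the decision threshold where the false-positive rate reaches 0.1. It can also be recomputed offline from the saved arrays; a sketch assuming probabilities is a 1-D array of membership scores (both file paths are placeholders):

import numpy as np
from sklearn.metrics import roc_curve

y_true = np.load("data/processed/master_challenge_test_labels.npy")  # placeholder path
scores = np.load("results/20250101_000000_xgboost_test_pred_proba.npy")  # placeholder path

fpr, tpr, _ = roc_curve(y_true, scores)
# Interpolate the ROC curve at FPR = 0.1 (fpr from roc_curve is non-decreasing).
tpr_at_fpr_01 = float(np.interp(0.1, fpr, tpr))
print(f"TPR at FPR=0.1: {tpr_at_fpr_01:.4f}")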

@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    """
    Run the Ensemble Attack example pipeline.

    As the first step, data processing is done.

    Args:
        config: Attack configuration as an OmegaConf DictConfig object.
    """
    if config.pipeline.run_data_processing:
        run_data_processing(config)
    if config.pipeline.run_metaclassifier_training:
        run_metaclassifier_training(config)


if __name__ == "__main__":
    main()
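
Context (not part of the diff): the script is driven by a Hydra config.yaml in the same directory. The structure below is inferred from the keys accessed in this file; every value is an illustrative placeholder, not a project default:

random_seed: 42
pipeline:
  run_data_processing: true
  run_metaclassifier_training: true
data_paths:
  midst_data_path: data/midst_resources
  population_path: data/population
  processed_attack_data_path: data/processed
  attack_results_path: results
model_paths:
  metaclassifier_model_path: results/models
data_processing_config:
  column_to_stratify: trans_type  # hypothetical column; see the TODO thread above
  population_sample_size: 40000
data_configs: {}  # passed through to BlendingPlusPlus; schema not shown in this diff
metaclassifier:
  model_type: xgboost  # must match a MetaClassifierType value; "xgboost" is a guess
  use_gpu: false
  epochs: 100

With Hydra, individual stages can then be toggled from the command line without editing the file, e.g. python <script>.py pipeline.run_data_processing=false (the script's filename is not shown in this diff).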

@@ -36,3 +36,6 @@ ignore_missing_imports = True

[mypy-category_encoders.*]
ignore_missing_imports = True

[mypy-gower.*]
ignore_missing_imports = True

Review thread on the gower override:
- "Why do we need this?"
- "I kept getting a 'Skipping analyzing "gower": module is installed, but missing library stubs or py.typed' error, and no stub files are available."
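
Aside: an alternative to widening the ini config would be suppressing the error at the single import site. A sketch (on recent mypy versions the relevant error code is import-untyped; older versions use the blanket import code):

import gower  # type: ignore[import-untyped]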
Review comment on the data metadata in config.yaml: "Not sure if this is a good idea, but can we create a trans_metadata.json or real_metadata.json file to store this information? I am suggesting this because we have several data config files like trans_domain.json and info.json with similar types of metadata information. We can keep this config.yaml for only attack pipeline related configurations. Then we can set the path to this metadata json file here, similar to trans_domain_file_path, and load it in BlendingPlusPlus init."
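
A hypothetical sketch of that suggestion, assuming the metadata moves into a real_metadata.json alongside the other data config files (the file name, config key, and the expected data_configs schema are all assumptions, not the project's actual API):

import json
from pathlib import Path

def load_data_configs(metadata_file_path: Path) -> dict:
    # Load the attack's data metadata from a JSON file instead of config.yaml.
    with open(metadata_file_path) as f:
        return json.load(f)

# In run_metaclassifier_training, the attacker could then be built as:
# data_configs = load_data_configs(Path(config.data_paths.real_metadata_file_path))  # hypothetical key
# blending_attacker = BlendingPlusPlus(
#     data_configs=data_configs,
#     meta_classifier_type=meta_classifier_enum,
#     random_seed=config.random_seed,
# )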