From c3e72401b807f120d99d4d9910c0b923b193457b Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Tue, 9 Dec 2025 16:07:05 +0100 Subject: [PATCH] extend ssm sim and tuning --- doc/ssm/ssm_mar.qmd | 55 +++++- monte-cover/src/montecover/ssm/__init__.py | 2 + monte-cover/src/montecover/ssm/ssm_mar_ate.py | 9 + .../src/montecover/ssm/ssm_mar_ate_tune.py | 162 ++++++++++++++++++ results/ssm/ssm_mar_ate_config.yml | 2 +- results/ssm/ssm_mar_ate_coverage.csv | 38 ++-- results/ssm/ssm_mar_ate_metadata.csv | 2 +- results/ssm/ssm_mar_ate_tune_config.yml | 26 +++ results/ssm/ssm_mar_ate_tune_coverage.csv | 5 + results/ssm/ssm_mar_ate_tune_metadata.csv | 2 + scripts/ssm/ssm_mar_ate_config.yml | 2 +- scripts/ssm/ssm_mar_ate_tune.py | 13 ++ scripts/ssm/ssm_mar_ate_tune_config.yml | 30 ++++ 13 files changed, 325 insertions(+), 23 deletions(-) create mode 100644 monte-cover/src/montecover/ssm/ssm_mar_ate_tune.py create mode 100644 results/ssm/ssm_mar_ate_tune_config.yml create mode 100644 results/ssm/ssm_mar_ate_tune_coverage.csv create mode 100644 results/ssm/ssm_mar_ate_tune_metadata.csv create mode 100644 scripts/ssm/ssm_mar_ate_tune.py create mode 100644 scripts/ssm/ssm_mar_ate_tune_config.yml diff --git a/doc/ssm/ssm_mar.qmd b/doc/ssm/ssm_mar.qmd index 36334158..d6b2a33f 100644 --- a/doc/ssm/ssm_mar.qmd +++ b/doc/ssm/ssm_mar.qmd @@ -47,7 +47,7 @@ df = pd.read_csv("../../results/ssm/ssm_mar_ate_coverage.csv", index_col=None) assert df["repetition"].nunique() == 1 n_rep = df["repetition"].unique()[0] -display_columns = ["Learner g", "Learner m", "Learner pi", "Bias", "CI Length", "Coverage"] +display_columns = ["Learner g", "Learner m", "Learner pi", "Bias", "CI Length", "Coverage", "Loss g_d0", "Loss g_d1", "Loss m", "Loss pi"] ``` ```{python} @@ -73,3 +73,56 @@ generate_and_show_styled_table( coverage_highlight_cols=["Coverage"] ) ``` + + +## Tuning + +The simulations are based on the 
[make_ssm_data](https://docs.doubleml.org/stable/api/generated/doubleml.irm.datasets.make_ssm_data.html)-DGP with $2000$ observations. The simulation considers data under [missingness at random](https://docs.doubleml.org/stable/guide/models.html#missingness-at-random). This is only an example as the untuned version just relies on the default configuration. + +::: {.callout-note title="Metadata" collapse="true"} + +```{python} +#| echo: false +metadata_file = '../../results/ssm/ssm_mar_ate_tune_metadata.csv' +metadata_df = pd.read_csv(metadata_file) +print(metadata_df.T.to_string(header=False)) +``` + +::: + + +```{python} +#| echo: false + +# set up data and rename columns +df = pd.read_csv("../../results/ssm/ssm_mar_ate_tune_coverage.csv", index_col=None) + +assert df["repetition"].nunique() == 1 +n_rep = df["repetition"].unique()[0] + +display_columns = ["Learner g", "Learner m", "Learner pi", "Tuned", "Bias", "CI Length", "Coverage", "Loss g_d0", "Loss g_d1", "Loss m", "Loss pi"] +``` + +```{python} +#| echo: false +generate_and_show_styled_table( + main_df=df, + filters={"level": 0.95}, + display_cols=display_columns, + n_rep=n_rep, + level_col="level", + coverage_highlight_cols=["Coverage"] +) +``` + +```{python} +#| echo: false +generate_and_show_styled_table( + main_df=df, + filters={"level": 0.9}, + display_cols=display_columns, + n_rep=n_rep, + level_col="level", + coverage_highlight_cols=["Coverage"] +) +``` \ No newline at end of file diff --git a/monte-cover/src/montecover/ssm/__init__.py b/monte-cover/src/montecover/ssm/__init__.py index 86d02b5a..48f64ac7 100644 --- a/monte-cover/src/montecover/ssm/__init__.py +++ b/monte-cover/src/montecover/ssm/__init__.py @@ -1,9 +1,11 @@ """Monte Carlo coverage simulations for SSM.""" from montecover.ssm.ssm_mar_ate import SSMMarATECoverageSimulation +from montecover.ssm.ssm_mar_ate_tune import SSMMarATETuningCoverageSimulation from montecover.ssm.ssm_nonig_ate import SSMNonIgnorableATECoverageSimulation __all__ = 
[ "SSMMarATECoverageSimulation", + "SSMMarATETuningCoverageSimulation", "SSMNonIgnorableATECoverageSimulation", ] diff --git a/monte-cover/src/montecover/ssm/ssm_mar_ate.py b/monte-cover/src/montecover/ssm/ssm_mar_ate.py index fe6dc0b8..2081b9c2 100644 --- a/monte-cover/src/montecover/ssm/ssm_mar_ate.py +++ b/monte-cover/src/montecover/ssm/ssm_mar_ate.py @@ -65,6 +65,7 @@ def run_single_rep( score="missing-at-random", ) dml_model.fit() + nuisance_loss = dml_model.nuisance_loss result = { "coverage": [], @@ -86,6 +87,10 @@ def run_single_rep( "Learner m": learner_m_name, "Learner pi": learner_pi_name, "level": level, + "Loss g_d0": nuisance_loss["ml_g_d0"].mean(), + "Loss g_d1": nuisance_loss["ml_g_d1"].mean(), + "Loss m": nuisance_loss["ml_m"].mean(), + "Loss pi": nuisance_loss["ml_pi"].mean(), } ) for key, res in level_result.items(): @@ -103,6 +108,10 @@ def summarize_results(self): "Coverage": "mean", "CI Length": "mean", "Bias": "mean", + "Loss g_d0": "mean", + "Loss g_d1": "mean", + "Loss m": "mean", + "Loss pi": "mean", "repetition": "count", } diff --git a/monte-cover/src/montecover/ssm/ssm_mar_ate_tune.py b/monte-cover/src/montecover/ssm/ssm_mar_ate_tune.py new file mode 100644 index 00000000..cd76aed1 --- /dev/null +++ b/monte-cover/src/montecover/ssm/ssm_mar_ate_tune.py @@ -0,0 +1,162 @@ +from typing import Any, Dict, Optional + +import doubleml as dml +import optuna +from doubleml.irm.datasets import make_ssm_data + +from montecover.base import BaseSimulation +from montecover.utils import create_learner_from_config +from montecover.utils_tuning import lgbm_reg_params, lgbm_cls_params + + +class SSMMarATETuningCoverageSimulation(BaseSimulation): + """Simulation class for coverage properties of DoubleMLSSM with missing at random for ATE estimation with tuning.""" + + def __init__( + self, + config_file: str, + suppress_warnings: bool = True, + log_level: str = "INFO", + log_file: Optional[str] = None, + ): + super().__init__( + config_file=config_file, + 
suppress_warnings=suppress_warnings, + log_level=log_level, + log_file=log_file, + ) + + # Calculate oracle values + self._calculate_oracle_values() + # tuning specific settings + self._param_space = {"ml_g": lgbm_reg_params, "ml_m": lgbm_cls_params, "ml_pi": lgbm_cls_params} + + self._optuna_settings = { + "n_trials": 50, + "show_progress_bar": False, + "verbosity": optuna.logging.WARNING, # Suppress Optuna logs + } + + def _process_config_parameters(self): + """Process simulation-specific parameters from config""" + # Process ML models in parameter grid + assert ( + "learners" in self.dml_parameters + ), "No learners specified in the config file" + + required_learners = ["ml_g", "ml_m", "ml_pi"] + for learner in self.dml_parameters["learners"]: + for ml in required_learners: + assert ml in learner, f"No {ml} specified in the config file" + + def _calculate_oracle_values(self): + """Calculate oracle values for the simulation.""" + self.logger.info("Calculating oracle values") + + self.oracle_values = dict() + self.oracle_values["theta"] = self.dgp_parameters["theta"] + + def run_single_rep( + self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any] + ) -> Dict[str, Any]: + """Run a single repetition with the given parameters.""" + # Extract parameters + learner_config = dml_params["learners"] + learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"]) + learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"]) + learner_pi_name, ml_pi = create_learner_from_config(learner_config["ml_pi"]) + + # Model + dml_model = dml.DoubleMLSSM( + obj_dml_data=dml_data, + ml_g=ml_g, + ml_m=ml_m, + ml_pi=ml_pi, + score="missing-at-random", + ) + + dml_model_tuned = dml.DoubleMLSSM( + obj_dml_data=dml_data, + ml_g=ml_g, + ml_m=ml_m, + ml_pi=ml_pi, + score="missing-at-random", + ) + dml_model_tuned.tune_ml_models( + ml_param_space=self._param_space, + optuna_settings=self._optuna_settings, + ) + + result = { + "coverage": [], + } + for model in 
[dml_model, dml_model_tuned]: + model.fit() + nuisance_loss = model.nuisance_loss + for level in self.confidence_parameters["level"]: + level_result = dict() + level_result["coverage"] = self._compute_coverage( + thetas=model.coef, + oracle_thetas=self.oracle_values["theta"], + confint=model.confint(level=level), + joint_confint=None, + ) + + # add parameters to the result + for res_metric in level_result.values(): + res_metric.update( + { + "Learner g": learner_g_name, + "Learner m": learner_m_name, + "Learner pi": learner_pi_name, + "level": level, + "Tuned": model is dml_model_tuned, + "Loss g_d0": nuisance_loss["ml_g_d0"].mean(), + "Loss g_d1": nuisance_loss["ml_g_d1"].mean(), + "Loss m": nuisance_loss["ml_m"].mean(), + "Loss pi": nuisance_loss["ml_pi"].mean(), + } + ) + for key, res in level_result.items(): + result[key].append(res) + + return result + + def summarize_results(self): + """Summarize the simulation results.""" + self.logger.info("Summarizing simulation results") + + # Group by parameter combinations + groupby_cols = ["Learner g", "Learner m", "Learner pi", "level", "Tuned"] + aggregation_dict = { + "Coverage": "mean", + "CI Length": "mean", + "Bias": "mean", + "Loss g_d0": "mean", + "Loss g_d1": "mean", + "Loss m": "mean", + "Loss pi": "mean", + "repetition": "count", + } + + # Aggregate results (possibly multiple result dfs) + result_summary = dict() + for result_name, result_df in self.results.items(): + result_summary[result_name] = ( + result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index() + ) + self.logger.debug(f"Summarized {result_name} results") + + return result_summary + + def _generate_dml_data(self, dgp_params: Dict[str, Any]) -> dml.DoubleMLData: + """Generate data for the simulation.""" + data = make_ssm_data( + theta=dgp_params["theta"], + n_obs=dgp_params["n_obs"], + dim_x=dgp_params["dim_x"], + mar=True, + return_type="DataFrame", + ) + dml_data = dml.data.DoubleMLSSMData(data, "y", "d", s_col="s") + return dml_data 
diff --git a/results/ssm/ssm_mar_ate_config.yml b/results/ssm/ssm_mar_ate_config.yml index 6c5f9261..91f9299a 100644 --- a/results/ssm/ssm_mar_ate_config.yml +++ b/results/ssm/ssm_mar_ate_config.yml @@ -1,5 +1,5 @@ simulation_parameters: - repetitions: 1000 + repetitions: 500 max_runtime: 19800 random_seed: 42 n_jobs: -2 diff --git a/results/ssm/ssm_mar_ate_coverage.csv b/results/ssm/ssm_mar_ate_coverage.csv index 86a35c6b..04411483 100644 --- a/results/ssm/ssm_mar_ate_coverage.csv +++ b/results/ssm/ssm_mar_ate_coverage.csv @@ -1,19 +1,19 @@ -Learner g,Learner m,Learner pi,level,Coverage,CI Length,Bias,repetition -LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,0.932,1.1096490607618295,0.2537753697730454,1000 -LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,0.973,1.3222284092249192,0.2537753697730454,1000 -LGBM Regr.,LGBM Clas.,Logistic,0.9,0.927,0.9202034782128122,0.21795635797590815,1000 -LGBM Regr.,LGBM Clas.,Logistic,0.95,0.98,1.0964900743711041,0.21795635797590815,1000 -LGBM Regr.,Logistic,LGBM Clas.,0.9,0.925,0.7859610293391193,0.17778968043589227,1000 -LGBM Regr.,Logistic,LGBM Clas.,0.95,0.97,0.9365303304293047,0.17778968043589227,1000 -LassoCV,LGBM Clas.,LGBM Clas.,0.9,0.937,1.06664811597998,0.24252593303321004,1000 -LassoCV,LGBM Clas.,LGBM Clas.,0.95,0.978,1.2709896231757172,0.24252593303321004,1000 -LassoCV,Logistic,Logistic,0.9,0.929,0.5870360234747519,0.1352163741591098,1000 -LassoCV,Logistic,Logistic,0.95,0.966,0.6994965660078569,0.1352163741591098,1000 -LassoCV,RF Clas.,RF Clas.,0.9,0.912,0.5180218173196215,0.12230587140635803,1000 -LassoCV,RF Clas.,RF Clas.,0.95,0.957,0.6172610671954945,0.12230587140635803,1000 -RF Regr.,Logistic,RF Clas.,0.9,0.907,0.5844857310821764,0.1396924993073941,1000 -RF Regr.,Logistic,RF Clas.,0.95,0.958,0.6964577051891234,0.1396924993073941,1000 -RF Regr.,RF Clas.,Logistic,0.9,0.917,0.5605579339809742,0.12931631817619357,1000 -RF Regr.,RF Clas.,Logistic,0.95,0.962,0.6679459763767203,0.12931631817619357,1000 -RF Regr.,RF Clas.,RF 
Clas.,0.9,0.91,0.5264596768137552,0.12286829856248617,1000 -RF Regr.,RF Clas.,RF Clas.,0.95,0.953,0.6273153969207249,0.12286829856248617,1000 +Learner g,Learner m,Learner pi,level,Coverage,CI Length,Bias,Loss g_d0,Loss g_d1,Loss m,Loss pi,repetition +LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,0.948,1.1131909010003327,0.24831669167598444,1.1741754097195407,1.163266373059539,0.7278358998629711,0.5898654853722893,500 +LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,0.978,1.3264487721755889,0.24831669167598444,1.1741754097195407,1.163266373059539,0.7278358998629711,0.5898654853722893,500 +LGBM Regr.,LGBM Clas.,Logistic,0.9,0.93,0.9374187726638454,0.21604194586480147,1.1744072789316446,1.1652216220418317,0.7267603771267632,0.5291022373942739,500 +LGBM Regr.,LGBM Clas.,Logistic,0.95,0.978,1.117003362942448,0.21604194586480147,1.1744072789316446,1.1652216220418317,0.7267603771267632,0.5291022373942739,500 +LGBM Regr.,Logistic,LGBM Clas.,0.9,0.92,0.7678357332655885,0.1751825561586277,1.1737576580451303,1.1653530986342269,0.6558418076554349,0.5887502027779351,500 +LGBM Regr.,Logistic,LGBM Clas.,0.95,0.97,0.9149327080444567,0.1751825561586277,1.1737576580451303,1.1653530986342269,0.6558418076554349,0.5887502027779351,500 +LassoCV,LGBM Clas.,LGBM Clas.,0.9,0.93,1.035562492786874,0.234860111118871,1.1269581915040925,1.1113999782983293,0.7263941131278941,0.5884520344901399,500 +LassoCV,LGBM Clas.,LGBM Clas.,0.95,0.972,1.2339488185125143,0.234860111118871,1.1269581915040925,1.1113999782983293,0.7263941131278941,0.5884520344901399,500 +LassoCV,Logistic,Logistic,0.9,0.924,0.60534225681428,0.1450666149444554,1.1265484870298061,1.1120018067070034,0.6555441137864938,0.5292784350095167,500 +LassoCV,Logistic,Logistic,0.95,0.968,0.7213097884430713,0.1450666149444554,1.1265484870298061,1.1120018067070034,0.6555441137864938,0.5292784350095167,500 +LassoCV,RF Clas.,RF Clas.,0.9,0.906,0.515238254750361,0.12187481387157952,1.1270654166551928,1.1113696711379897,0.6562679741378069,0.5360320761650595,500 
+LassoCV,RF Clas.,RF Clas.,0.95,0.968,0.6139442478171186,0.12187481387157952,1.1270654166551928,1.1113696711379897,0.6562679741378069,0.5360320761650595,500 +RF Regr.,Logistic,RF Clas.,0.9,0.914,0.5796627327333214,0.14392844815850958,1.1433176517482628,1.1268174098574868,0.6556467509221011,0.5348783258079363,500 +RF Regr.,Logistic,RF Clas.,0.95,0.974,0.6907107481916357,0.14392844815850958,1.1433176517482628,1.1268174098574868,0.6556467509221011,0.5348783258079363,500 +RF Regr.,RF Clas.,Logistic,0.9,0.892,0.5525698364999777,0.13113512350389206,1.1437489701377335,1.1261227655007,0.6561396929395531,0.5281982168495511,500 +RF Regr.,RF Clas.,Logistic,0.95,0.964,0.6584275711452678,0.13113512350389206,1.1437489701377335,1.1261227655007,0.6561396929395531,0.5281982168495511,500 +RF Regr.,RF Clas.,RF Clas.,0.9,0.908,0.5220365874708703,0.1260710270982447,1.1429305895422688,1.126981021703091,0.6561581464421511,0.5363056603896544,500 +RF Regr.,RF Clas.,RF Clas.,0.95,0.968,0.6220449608950439,0.1260710270982447,1.1429305895422688,1.126981021703091,0.6561581464421511,0.5363056603896544,500 diff --git a/results/ssm/ssm_mar_ate_metadata.csv b/results/ssm/ssm_mar_ate_metadata.csv index 88c4652d..94586479 100644 --- a/results/ssm/ssm_mar_ate_metadata.csv +++ b/results/ssm/ssm_mar_ate_metadata.csv @@ -1,2 +1,2 @@ DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File -0.12.dev0,SSMMarATECoverageSimulation,2025-12-04 21:25,255.4998017311096,3.12.3,scripts/ssm/ssm_mar_ate_config.yml +0.12.dev0,SSMMarATECoverageSimulation,2025-12-09 15:29,13.831887169679005,3.12.9,scripts/ssm/ssm_mar_ate_config.yml diff --git a/results/ssm/ssm_mar_ate_tune_config.yml b/results/ssm/ssm_mar_ate_tune_config.yml new file mode 100644 index 00000000..04a35561 --- /dev/null +++ b/results/ssm/ssm_mar_ate_tune_config.yml @@ -0,0 +1,26 @@ +simulation_parameters: + repetitions: 200 + max_runtime: 19800 + random_seed: 42 + n_jobs: -2 +dgp_parameters: + theta: + - 1.0 + n_obs: + - 2000 + 
dim_x: + - 20 +learner_definitions: + lgbmr: &id001 + name: LGBM Regr. + lgbmc: &id002 + name: LGBM Clas. +dml_parameters: + learners: + - ml_g: *id001 + ml_m: *id002 + ml_pi: *id002 +confidence_parameters: + level: + - 0.95 + - 0.9 diff --git a/results/ssm/ssm_mar_ate_tune_coverage.csv b/results/ssm/ssm_mar_ate_tune_coverage.csv new file mode 100644 index 00000000..3847a25c --- /dev/null +++ b/results/ssm/ssm_mar_ate_tune_coverage.csv @@ -0,0 +1,5 @@ +Learner g,Learner m,Learner pi,level,Tuned,Coverage,CI Length,Bias,Loss g_d0,Loss g_d1,Loss m,Loss pi,repetition +LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,False,0.97,0.8371879519257365,0.17197778171949135,1.1814391661801091,1.1464920976592006,0.7343828611710117,0.6073260528329507,200 +LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,True,0.86,0.2378523929611131,0.05988962532541387,1.1234779915527298,1.1159925042652648,0.6491732261779875,0.5248390719094591,200 +LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,False,0.985,0.9975709735986762,0.17197778171949135,1.1814391661801091,1.1464920976592006,0.7343828611710117,0.6073260528329507,200 +LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,True,0.935,0.28341860710393996,0.05988962532541387,1.1234779915527298,1.1159925042652648,0.6491732261779875,0.5248390719094591,200 diff --git a/results/ssm/ssm_mar_ate_tune_metadata.csv b/results/ssm/ssm_mar_ate_tune_metadata.csv new file mode 100644 index 00000000..fbeb341e --- /dev/null +++ b/results/ssm/ssm_mar_ate_tune_metadata.csv @@ -0,0 +1,2 @@ +DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File +0.12.dev0,SSMMarATETuningCoverageSimulation,2025-12-09 15:59,15.788600877920787,3.12.9,scripts/ssm/ssm_mar_ate_tune_config.yml diff --git a/scripts/ssm/ssm_mar_ate_config.yml b/scripts/ssm/ssm_mar_ate_config.yml index ca857513..41e18211 100644 --- a/scripts/ssm/ssm_mar_ate_config.yml +++ b/scripts/ssm/ssm_mar_ate_config.yml @@ -1,7 +1,7 @@ # Simulation parameters for IRM ATE Coverage simulation_parameters: - repetitions: 1000 + repetitions: 500 
   max_runtime: 19800 # 5.5 hours in seconds
   random_seed: 42
   n_jobs: -2
diff --git a/scripts/ssm/ssm_mar_ate_tune.py b/scripts/ssm/ssm_mar_ate_tune.py
new file mode 100644
index 00000000..ab4021c1
--- /dev/null
+++ b/scripts/ssm/ssm_mar_ate_tune.py
@@ -0,0 +1,13 @@
+from montecover.ssm import SSMMarATETuningCoverageSimulation
+
+# Create and run simulation with config file
+sim = SSMMarATETuningCoverageSimulation(
+    config_file="scripts/ssm/ssm_mar_ate_tune_config.yml",
+    log_level="INFO",
+    log_file="logs/ssm/ssm_mar_ate_tune_sim.log",
+)
+sim.run_simulation()
+sim.save_results(output_path="results/ssm/", file_prefix="ssm_mar_ate_tune")
+
+# Save config file for reproducibility
+sim.save_config("results/ssm/ssm_mar_ate_tune_config.yml")
diff --git a/scripts/ssm/ssm_mar_ate_tune_config.yml b/scripts/ssm/ssm_mar_ate_tune_config.yml
new file mode 100644
index 00000000..82ea6707
--- /dev/null
+++ b/scripts/ssm/ssm_mar_ate_tune_config.yml
@@ -0,0 +1,30 @@
+# Simulation parameters for SSM ATE Coverage
+
+simulation_parameters:
+  repetitions: 200
+  max_runtime: 19800 # 5.5 hours in seconds
+  random_seed: 42
+  n_jobs: -2
+
+dgp_parameters:
+  theta: [1.0] # Treatment effect
+  n_obs: [2000] # Sample size
+  dim_x: [20] # Number of covariates
+
+# Define reusable learner configurations
+learner_definitions:
+  lgbmr: &lgbmr
+    name: "LGBM Regr."
+
+  lgbmc: &lgbmc
+    name: "LGBM Clas."
+
+dml_parameters:
+  learners:
+    - ml_g: *lgbmr
+      ml_m: *lgbmc
+      ml_pi: *lgbmc
+
+
+confidence_parameters:
+  level: [0.95, 0.90] # Confidence levels