Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion doc/ssm/ssm_mar.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ df = pd.read_csv("../../results/ssm/ssm_mar_ate_coverage.csv", index_col=None)
assert df["repetition"].nunique() == 1
n_rep = df["repetition"].unique()[0]

display_columns = ["Learner g", "Learner m", "Learner pi", "Bias", "CI Length", "Coverage"]
display_columns = ["Learner g", "Learner m", "Learner pi", "Bias", "CI Length", "Coverage", "Loss g_d0", "Loss g_d1", "Loss m", "Loss pi"]
```

```{python}
Expand All @@ -73,3 +73,56 @@ generate_and_show_styled_table(
coverage_highlight_cols=["Coverage"]
)
```


## Tuning

The simulations are based on the [make_ssm_data](https://docs.doubleml.org/stable/api/generated/doubleml.irm.datasets.make_ssm_data.html)-DGP with $2000$ observations. The simulation considers data under [missingness at random](https://docs.doubleml.org/stable/guide/models.html#missingness-at-random). This is only an example as the untuned version just relies on the default configuration.

::: {.callout-note title="Metadata" collapse="true"}

```{python}
#| echo: false
metadata_file = '../../results/ssm/ssm_mar_ate_tune_metadata.csv'
metadata_df = pd.read_csv(metadata_file)
print(metadata_df.T.to_string(header=False))
```

:::


```{python}
#| echo: false

# set up data and rename columns
df = pd.read_csv("../../results/ssm/ssm_mar_ate_tune_coverage.csv", index_col=None)

assert df["repetition"].nunique() == 1
n_rep = df["repetition"].unique()[0]

display_columns = ["Learner g", "Learner m", "Learner pi", "Tuned", "Bias", "CI Length", "Coverage", "Loss g_d0", "Loss g_d1", "Loss m", "Loss pi"]
```

```{python}
#| echo: false
generate_and_show_styled_table(
main_df=df,
filters={"level": 0.95},
display_cols=display_columns,
n_rep=n_rep,
level_col="level",
coverage_highlight_cols=["Coverage"]
)
```

```{python}
#| echo: false
generate_and_show_styled_table(
main_df=df,
filters={"level": 0.9},
display_cols=display_columns,
n_rep=n_rep,
level_col="level",
coverage_highlight_cols=["Coverage"]
)
```
2 changes: 2 additions & 0 deletions monte-cover/src/montecover/ssm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Monte Carlo coverage simulations for SSM."""

from montecover.ssm.ssm_mar_ate import SSMMarATECoverageSimulation
from montecover.ssm.ssm_mar_ate_tune import SSMMarATETuningCoverageSimulation
from montecover.ssm.ssm_nonig_ate import SSMNonIgnorableATECoverageSimulation

__all__ = [
"SSMMarATECoverageSimulation",
"SSMMarATETuningCoverageSimulation",
"SSMNonIgnorableATECoverageSimulation",
]
9 changes: 9 additions & 0 deletions monte-cover/src/montecover/ssm/ssm_mar_ate.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def run_single_rep(
score="missing-at-random",
)
dml_model.fit()
nuisance_loss = dml_model.nuisance_loss

result = {
"coverage": [],
Expand All @@ -86,6 +87,10 @@ def run_single_rep(
"Learner m": learner_m_name,
"Learner pi": learner_pi_name,
"level": level,
"Loss g_d0": nuisance_loss["ml_g_d0"].mean(),
"Loss g_d1": nuisance_loss["ml_g_d1"].mean(),
"Loss m": nuisance_loss["ml_m"].mean(),
"Loss pi": nuisance_loss["ml_pi"].mean(),
}
)
for key, res in level_result.items():
Expand All @@ -103,6 +108,10 @@ def summarize_results(self):
"Coverage": "mean",
"CI Length": "mean",
"Bias": "mean",
"Loss g_d0": "mean",
"Loss g_d1": "mean",
"Loss m": "mean",
"Loss pi": "mean",
"repetition": "count",
}

Expand Down
162 changes: 162 additions & 0 deletions monte-cover/src/montecover/ssm/ssm_mar_ate_tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from typing import Any, Dict, Optional

import doubleml as dml
import optuna
from doubleml.irm.datasets import make_ssm_data

from montecover.base import BaseSimulation
from montecover.utils import create_learner_from_config
from montecover.utils_tuning import lgbm_reg_params, lgbm_cls_params


class SSMMarATETuningCoverageSimulation(BaseSimulation):
"""Simulation class for coverage properties of DoubleMLSSM with missing at random for ATE estimation with tuning."""

def __init__(
self,
config_file: str,
suppress_warnings: bool = True,
log_level: str = "INFO",
log_file: Optional[str] = None,
):
super().__init__(
config_file=config_file,
suppress_warnings=suppress_warnings,
log_level=log_level,
log_file=log_file,
)

# Calculate oracle values
self._calculate_oracle_values()
# tuning specific settings
self._param_space = {"ml_g": lgbm_reg_params, "ml_m": lgbm_cls_params, "ml_pi": lgbm_cls_params}

self._optuna_settings = {
"n_trials": 50,
"show_progress_bar": False,
"verbosity": optuna.logging.WARNING, # Suppress Optuna logs
}

def _process_config_parameters(self):
"""Process simulation-specific parameters from config"""
# Process ML models in parameter grid
assert (
"learners" in self.dml_parameters
), "No learners specified in the config file"

required_learners = ["ml_g", "ml_m", "ml_pi"]
for learner in self.dml_parameters["learners"]:
for ml in required_learners:
assert ml in learner, f"No {ml} specified in the config file"

def _calculate_oracle_values(self):
"""Calculate oracle values for the simulation."""
self.logger.info("Calculating oracle values")

self.oracle_values = dict()
self.oracle_values["theta"] = self.dgp_parameters["theta"]

def run_single_rep(
self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]
) -> Dict[str, Any]:
"""Run a single repetition with the given parameters."""
# Extract parameters
learner_config = dml_params["learners"]
learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"])
learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])
learner_pi_name, ml_pi = create_learner_from_config(learner_config["ml_pi"])

# Model
dml_model = dml.DoubleMLSSM(
obj_dml_data=dml_data,
ml_g=ml_g,
ml_m=ml_m,
ml_pi=ml_pi,
score="missing-at-random",
)

dml_model_tuned = dml.DoubleMLSSM(
obj_dml_data=dml_data,
ml_g=ml_g,
ml_m=ml_m,
ml_pi=ml_pi,
score="missing-at-random",
)
dml_model_tuned.tune_ml_models(
ml_param_space=self._param_space,
optuna_settings=self._optuna_settings,
)

result = {
"coverage": [],
}
for model in [dml_model, dml_model_tuned]:
model.fit()
nuisance_loss = model.nuisance_loss
for level in self.confidence_parameters["level"]:
level_result = dict()
level_result["coverage"] = self._compute_coverage(
thetas=model.coef,
oracle_thetas=self.oracle_values["theta"],
confint=model.confint(level=level),
joint_confint=None,
)

# add parameters to the result
for res_metric in level_result.values():
res_metric.update(
{
"Learner g": learner_g_name,
"Learner m": learner_m_name,
"Learner pi": learner_pi_name,
"level": level,
"Tuned": model is dml_model_tuned,
"Loss g_d0": nuisance_loss["ml_g_d0"].mean(),
"Loss g_d1": nuisance_loss["ml_g_d1"].mean(),
"Loss m": nuisance_loss["ml_m"].mean(),
"Loss pi": nuisance_loss["ml_pi"].mean(),
}
)
for key, res in level_result.items():
result[key].append(res)

return result

def summarize_results(self):
"""Summarize the simulation results."""
self.logger.info("Summarizing simulation results")

# Group by parameter combinations
groupby_cols = ["Learner g", "Learner m", "Learner pi", "level", "Tuned"]
aggregation_dict = {
"Coverage": "mean",
"CI Length": "mean",
"Bias": "mean",
"Loss g_d0": "mean",
"Loss g_d1": "mean",
"Loss m": "mean",
"Loss pi": "mean",
"repetition": "count",
}

# Aggregate results (possibly multiple result dfs)
result_summary = dict()
for result_name, result_df in self.results.items():
result_summary[result_name] = (
result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
)
self.logger.debug(f"Summarized {result_name} results")

return result_summary

def _generate_dml_data(self, dgp_params: Dict[str, Any]) -> dml.DoubleMLData:
"""Generate data for the simulation."""
data = make_ssm_data(
theta=dgp_params["theta"],
n_obs=dgp_params["n_obs"],
dim_x=dgp_params["dim_x"],
mar=True,
return_type="DataFrame",
)
dml_data = dml.data.DoubleMLSSMData(data, "y", "d", s_col="s")
return dml_data
2 changes: 1 addition & 1 deletion results/ssm/ssm_mar_ate_config.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
simulation_parameters:
repetitions: 1000
repetitions: 500
max_runtime: 19800
random_seed: 42
n_jobs: -2
Expand Down
38 changes: 19 additions & 19 deletions results/ssm/ssm_mar_ate_coverage.csv
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
Learner g,Learner m,Learner pi,level,Coverage,CI Length,Bias,repetition
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,0.932,1.1096490607618295,0.2537753697730454,1000
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,0.973,1.3222284092249192,0.2537753697730454,1000
LGBM Regr.,LGBM Clas.,Logistic,0.9,0.927,0.9202034782128122,0.21795635797590815,1000
LGBM Regr.,LGBM Clas.,Logistic,0.95,0.98,1.0964900743711041,0.21795635797590815,1000
LGBM Regr.,Logistic,LGBM Clas.,0.9,0.925,0.7859610293391193,0.17778968043589227,1000
LGBM Regr.,Logistic,LGBM Clas.,0.95,0.97,0.9365303304293047,0.17778968043589227,1000
LassoCV,LGBM Clas.,LGBM Clas.,0.9,0.937,1.06664811597998,0.24252593303321004,1000
LassoCV,LGBM Clas.,LGBM Clas.,0.95,0.978,1.2709896231757172,0.24252593303321004,1000
LassoCV,Logistic,Logistic,0.9,0.929,0.5870360234747519,0.1352163741591098,1000
LassoCV,Logistic,Logistic,0.95,0.966,0.6994965660078569,0.1352163741591098,1000
LassoCV,RF Clas.,RF Clas.,0.9,0.912,0.5180218173196215,0.12230587140635803,1000
LassoCV,RF Clas.,RF Clas.,0.95,0.957,0.6172610671954945,0.12230587140635803,1000
RF Regr.,Logistic,RF Clas.,0.9,0.907,0.5844857310821764,0.1396924993073941,1000
RF Regr.,Logistic,RF Clas.,0.95,0.958,0.6964577051891234,0.1396924993073941,1000
RF Regr.,RF Clas.,Logistic,0.9,0.917,0.5605579339809742,0.12931631817619357,1000
RF Regr.,RF Clas.,Logistic,0.95,0.962,0.6679459763767203,0.12931631817619357,1000
RF Regr.,RF Clas.,RF Clas.,0.9,0.91,0.5264596768137552,0.12286829856248617,1000
RF Regr.,RF Clas.,RF Clas.,0.95,0.953,0.6273153969207249,0.12286829856248617,1000
Learner g,Learner m,Learner pi,level,Coverage,CI Length,Bias,Loss g_d0,Loss g_d1,Loss m,Loss pi,repetition
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,0.948,1.1131909010003327,0.24831669167598444,1.1741754097195407,1.163266373059539,0.7278358998629711,0.5898654853722893,500
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,0.978,1.3264487721755889,0.24831669167598444,1.1741754097195407,1.163266373059539,0.7278358998629711,0.5898654853722893,500
LGBM Regr.,LGBM Clas.,Logistic,0.9,0.93,0.9374187726638454,0.21604194586480147,1.1744072789316446,1.1652216220418317,0.7267603771267632,0.5291022373942739,500
LGBM Regr.,LGBM Clas.,Logistic,0.95,0.978,1.117003362942448,0.21604194586480147,1.1744072789316446,1.1652216220418317,0.7267603771267632,0.5291022373942739,500
LGBM Regr.,Logistic,LGBM Clas.,0.9,0.92,0.7678357332655885,0.1751825561586277,1.1737576580451303,1.1653530986342269,0.6558418076554349,0.5887502027779351,500
LGBM Regr.,Logistic,LGBM Clas.,0.95,0.97,0.9149327080444567,0.1751825561586277,1.1737576580451303,1.1653530986342269,0.6558418076554349,0.5887502027779351,500
LassoCV,LGBM Clas.,LGBM Clas.,0.9,0.93,1.035562492786874,0.234860111118871,1.1269581915040925,1.1113999782983293,0.7263941131278941,0.5884520344901399,500
LassoCV,LGBM Clas.,LGBM Clas.,0.95,0.972,1.2339488185125143,0.234860111118871,1.1269581915040925,1.1113999782983293,0.7263941131278941,0.5884520344901399,500
LassoCV,Logistic,Logistic,0.9,0.924,0.60534225681428,0.1450666149444554,1.1265484870298061,1.1120018067070034,0.6555441137864938,0.5292784350095167,500
LassoCV,Logistic,Logistic,0.95,0.968,0.7213097884430713,0.1450666149444554,1.1265484870298061,1.1120018067070034,0.6555441137864938,0.5292784350095167,500
LassoCV,RF Clas.,RF Clas.,0.9,0.906,0.515238254750361,0.12187481387157952,1.1270654166551928,1.1113696711379897,0.6562679741378069,0.5360320761650595,500
LassoCV,RF Clas.,RF Clas.,0.95,0.968,0.6139442478171186,0.12187481387157952,1.1270654166551928,1.1113696711379897,0.6562679741378069,0.5360320761650595,500
RF Regr.,Logistic,RF Clas.,0.9,0.914,0.5796627327333214,0.14392844815850958,1.1433176517482628,1.1268174098574868,0.6556467509221011,0.5348783258079363,500
RF Regr.,Logistic,RF Clas.,0.95,0.974,0.6907107481916357,0.14392844815850958,1.1433176517482628,1.1268174098574868,0.6556467509221011,0.5348783258079363,500
RF Regr.,RF Clas.,Logistic,0.9,0.892,0.5525698364999777,0.13113512350389206,1.1437489701377335,1.1261227655007,0.6561396929395531,0.5281982168495511,500
RF Regr.,RF Clas.,Logistic,0.95,0.964,0.6584275711452678,0.13113512350389206,1.1437489701377335,1.1261227655007,0.6561396929395531,0.5281982168495511,500
RF Regr.,RF Clas.,RF Clas.,0.9,0.908,0.5220365874708703,0.1260710270982447,1.1429305895422688,1.126981021703091,0.6561581464421511,0.5363056603896544,500
RF Regr.,RF Clas.,RF Clas.,0.95,0.968,0.6220449608950439,0.1260710270982447,1.1429305895422688,1.126981021703091,0.6561581464421511,0.5363056603896544,500
2 changes: 1 addition & 1 deletion results/ssm/ssm_mar_ate_metadata.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
0.12.dev0,SSMMarATECoverageSimulation,2025-12-04 21:25,255.4998017311096,3.12.3,scripts/ssm/ssm_mar_ate_config.yml
0.12.dev0,SSMMarATECoverageSimulation,2025-12-09 15:29,13.831887169679005,3.12.9,scripts/ssm/ssm_mar_ate_config.yml
26 changes: 26 additions & 0 deletions results/ssm/ssm_mar_ate_tune_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
simulation_parameters:
repetitions: 200
max_runtime: 19800
random_seed: 42
n_jobs: -2
dgp_parameters:
theta:
- 1.0
n_obs:
- 2000
dim_x:
- 20
learner_definitions:
lgbmr: &id001
name: LGBM Regr.
lgbmc: &id002
name: LGBM Clas.
dml_parameters:
learners:
- ml_g: *id001
ml_m: *id002
ml_pi: *id002
confidence_parameters:
level:
- 0.95
- 0.9
5 changes: 5 additions & 0 deletions results/ssm/ssm_mar_ate_tune_coverage.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Learner g,Learner m,Learner pi,level,Tuned,Coverage,CI Length,Bias,Loss g_d0,Loss g_d1,Loss m,Loss pi,repetition
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,False,0.97,0.8371879519257365,0.17197778171949135,1.1814391661801091,1.1464920976592006,0.7343828611710117,0.6073260528329507,200
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.9,True,0.86,0.2378523929611131,0.05988962532541387,1.1234779915527298,1.1159925042652648,0.6491732261779875,0.5248390719094591,200
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,False,0.985,0.9975709735986762,0.17197778171949135,1.1814391661801091,1.1464920976592006,0.7343828611710117,0.6073260528329507,200
LGBM Regr.,LGBM Clas.,LGBM Clas.,0.95,True,0.935,0.28341860710393996,0.05988962532541387,1.1234779915527298,1.1159925042652648,0.6491732261779875,0.5248390719094591,200
2 changes: 2 additions & 0 deletions results/ssm/ssm_mar_ate_tune_metadata.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
0.12.dev0,SSMMarATETuningCoverageSimulation,2025-12-09 15:59,15.788600877920787,3.12.9,scripts/ssm/ssm_mar_ate_tune_config.yml
2 changes: 1 addition & 1 deletion scripts/ssm/ssm_mar_ate_config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Simulation parameters for IRM ATE Coverage

simulation_parameters:
repetitions: 1000
repetitions: 500
max_runtime: 19800 # 5.5 hours in seconds
random_seed: 42
n_jobs: -2
Expand Down
13 changes: 13 additions & 0 deletions scripts/ssm/ssm_mar_ate_tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from montecover.ssm import SSMMarATETuningCoverageSimulation

# Create and run simulation with config file
sim = SSMMarATETuningCoverageSimulation(
config_file="scripts/ssm/ssm_mar_ate_tune_config.yml",
log_level="INFO",
log_file="logs/ssm/ssm_mar_ate_tune_sim.log",
)
sim.run_simulation()
sim.save_results(output_path="results/ssm/", file_prefix="ssm_mar_ate_tune")

# Save config file for reproducibility
sim.save_config("results/ssm/ssm_mar_ate_tune_config.yml")
Loading