Skip to content

Commit 322f8e5

Browse files
committed
Add verbose parameter to GCM auto assign function
The auto assignment function now provides more details about the fitting process when verbose is set to True (the default). Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
1 parent 86fcb28 commit 322f8e5

File tree

9 files changed

+131
-10
lines changed

9 files changed

+131
-10
lines changed

dowhy/gcm/auto.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dowhy.gcm import config
1616
from dowhy.gcm.causal_mechanisms import AdditiveNoiseModel, ClassifierFCM
1717
from dowhy.gcm.causal_models import CAUSAL_MECHANISM, ProbabilisticCausalModel, validate_causal_model_assignment
18+
from dowhy.gcm.config import add_info_log_msg
1819
from dowhy.gcm.ml import (
1920
ClassificationModel,
2021
PredictionModel,
@@ -92,6 +93,7 @@ def assign_causal_mechanisms(
9293
based_on: pd.DataFrame,
9394
quality: AssignmentQuality = AssignmentQuality.GOOD,
9495
override_models: bool = False,
96+
verbose: bool = True,
9597
) -> None:
9698
"""Automatically assigns appropriate causal models. If causal models are already assigned to nodes and
9799
override_models is set to False, this function only validates the assignments with respect to the graph structure.
@@ -129,41 +131,70 @@ def assign_causal_mechanisms(
129131
130132
:return: None
131133
"""
134+
add_info_log_msg("----- Starting automatic model assignment -----", verbose)
132135
for node in causal_model.graph.nodes:
136+
add_info_log_msg("--Node: %s" % node, verbose)
133137
if not override_models and CAUSAL_MECHANISM in causal_model.graph.nodes[node]:
138+
add_info_log_msg(
139+
"Node %s already has a model assigned and the override parameter is false. "
140+
"Will skip this node." % node,
141+
verbose,
142+
)
134143
validate_causal_model_assignment(causal_model.graph, node)
135144
continue
136-
assign_causal_mechanism_node(causal_model, node, based_on, quality)
145+
146+
assign_causal_mechanism_node(causal_model, node, based_on, quality, verbose)
147+
148+
add_info_log_msg("----- Finished automatic model assignment -----", verbose)
137149

138150

139151
def assign_causal_mechanism_node(
140152
causal_model: ProbabilisticCausalModel,
141153
node: str,
142154
based_on: pd.DataFrame,
143155
quality: AssignmentQuality = AssignmentQuality.GOOD,
156+
verbose: bool = True,
144157
) -> None:
145158
if is_root_node(causal_model.graph, node):
146159
causal_model.set_causal_mechanism(node, EmpiricalDistribution())
160+
161+
add_info_log_msg(
162+
"Identified %s as a root node. Assigning %s to the node." % (node, causal_model.causal_mechanism(node)),
163+
verbose,
164+
)
147165
else:
166+
add_info_log_msg("Identified %s as a non-root node." % node, verbose)
167+
148168
prediction_model = select_model(
149169
based_on[get_ordered_predecessors(causal_model.graph, node)].to_numpy(),
150170
based_on[node].to_numpy(),
151171
quality,
172+
verbose,
152173
)
153174

154175
if isinstance(prediction_model, ClassificationModel):
155176
causal_model.set_causal_mechanism(node, ClassifierFCM(prediction_model))
177+
178+
add_info_log_msg("Assigning %s to the node %s." % (causal_model.causal_mechanism(node), node), verbose)
156179
else:
157180
causal_model.set_causal_mechanism(node, AdditiveNoiseModel(prediction_model))
158181

182+
add_info_log_msg(
183+
"Assigning a %s to the node %s." % (causal_model.causal_mechanism(node), node),
184+
verbose,
185+
)
186+
159187

160188
def select_model(
161-
X: np.ndarray, Y: np.ndarray, model_selection_quality: AssignmentQuality
189+
X: np.ndarray, Y: np.ndarray, model_selection_quality: AssignmentQuality, verbose: bool
162190
) -> Union[PredictionModel, ClassificationModel]:
191+
add_info_log_msg("Looking for the best prediction model based on the %s." % model_selection_quality, verbose)
163192
if model_selection_quality == AssignmentQuality.BEST:
164193
try:
165194
from dowhy.gcm.ml.autogluon import AutoGluonClassifier, AutoGluonRegressor
166195

196+
add_info_log_msg("Using an autogluon model", verbose)
197+
167198
if is_categorical(Y):
168199
return AutoGluonClassifier()
169200
else:
@@ -190,8 +221,12 @@ def select_model(
190221
list_of_classifier += [partial(create_polynom_logistic_regression_classifier, max_iter=1000)]
191222

192223
if is_categorical(Y):
224+
add_info_log_msg("The node seems to be categorical. Checking classification models...", verbose)
225+
193226
return find_best_model(list_of_classifier, X, Y, model_selection_splits=model_selection_splits)()
194227
else:
228+
add_info_log_msg("The node seems to be continuous. Checking regression models....", verbose)
229+
195230
return find_best_model(list_of_regressor, X, Y, model_selection_splits=model_selection_splits)()
196231

197232

@@ -253,19 +288,24 @@ def find_best_model(
253288
max_samples_per_split: int = 10000,
254289
model_selection_splits: int = 5,
255290
n_jobs: Optional[int] = None,
291+
verbose: bool = True,
256292
) -> Callable[[], PredictionModel]:
257293
n_jobs = config.default_n_jobs if n_jobs is None else n_jobs
258294

259295
X, Y = shape_into_2d(X, Y)
260296

261297
is_classification_problem = isinstance(prediction_model_factories[0](), ClassificationModel)
262298

299+
metric_name = "given"
300+
263301
if metric is None:
302+
metric_name = "(negative) F1"
264303
if is_classification_problem:
265304
metric = lambda y_true, y_preds: -metrics.f1_score(
266305
y_true, y_preds, average="macro", zero_division=0
267306
) # Higher score is better
268307
else:
308+
metric_name = "mean squared error (MSE)"
269309
metric = metrics.mean_squared_error
270310

271311
labelBinarizer = None
@@ -301,5 +341,22 @@ def estimate_average_score(prediction_model_factory: Callable[[], PredictionMode
301341
delayed(estimate_average_score)(prediction_model_factory, int(random_seed))
302342
for prediction_model_factory, random_seed in zip(prediction_model_factories, random_seeds)
303343
)
344+
sorted_results = sorted(zip(prediction_model_factories, average_metric_scores), key=lambda x: x[1])
345+
best_model = sorted_results[0]
346+
347+
add_info_log_msg(
348+
"Using %d splits and the %s metric. The results are:\n-%s"
349+
% (
350+
model_selection_splits,
351+
metric_name,
352+
"\n-".join(["%s: %s" % (str(result[0]()).replace("()", ""), str(result[1])) for result in sorted_results]),
353+
),
354+
verbose,
355+
)
356+
add_info_log_msg(
357+
"Based on this, selecting %s as the best model to minimize the %s metric."
358+
% (str(best_model[0]()).replace("()", ""), metric_name),
359+
verbose,
360+
)
304361

305-
return sorted(zip(prediction_model_factories, average_metric_scores), key=lambda x: x[1])[0][0]
362+
return best_model[0]

dowhy/gcm/causal_mechanisms.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111

1212
from dowhy.gcm.ml import ClassificationModel, PredictionModel
13-
from dowhy.gcm.ml.regression import InvertibleFunction
13+
from dowhy.gcm.ml.regression import InvertibleFunction, SklearnRegressionModel
1414
from dowhy.gcm.util.general import is_categorical, shape_into_2d
1515

1616

@@ -155,9 +155,14 @@ def evaluate(self, parent_samples: np.ndarray, noise_samples: np.ndarray) -> np.
155155
return self._invertible_function.evaluate(predictions + noise_samples)
156156

157157
def __str__(self) -> str:
158+
if isinstance(self._prediction_model, SklearnRegressionModel):
159+
prediction_model_string = self._prediction_model.sklearn_model.__class__.__name__
160+
else:
161+
prediction_model_string = self._prediction_model.__class__.__name__
162+
158163
return "%s with %s and an %s" % (
159164
self.__class__.__name__,
160-
self._prediction_model.__class__.__name__,
165+
prediction_model_string,
161166
self._invertible_function.__class__.__name__,
162167
)
163168

@@ -207,6 +212,14 @@ def __init__(self, prediction_model: PredictionModel, noise_model: Optional[Stoc
207212
def clone(self):
208213
return AdditiveNoiseModel(prediction_model=self.prediction_model.clone(), noise_model=self.noise_model.clone())
209214

215+
def __str__(self) -> str:
216+
if isinstance(self._prediction_model, SklearnRegressionModel):
217+
prediction_model_string = self._prediction_model.sklearn_model.__class__.__name__
218+
else:
219+
prediction_model_string = self._prediction_model.__class__.__name__
220+
221+
return "AdditiveNoiseModel using %s" % prediction_model_string
222+
210223

211224
class ProbabilityEstimatorModel(ABC):
212225
@abstractmethod
@@ -291,3 +304,6 @@ def get_class_names(self, class_indices: np.ndarray) -> List[str]:
291304
@property
292305
def classifier_model(self) -> ClassificationModel:
293306
return self._classifier_model
307+
308+
def __repr__(self):
309+
return "Classifier FCM based on %s" % self.classifier_model

dowhy/gcm/confidence_intervals_cms.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def fit_and_compute(
5353
bootstrap_training_data: pd.DataFrame,
5454
bootstrap_data_subset_size_fraction: float = 0.75,
5555
auto_assign_quality: Optional[auto.AssignmentQuality] = None,
56+
auto_assign_verbose: bool = False,
5657
*args,
5758
**kwargs,
5859
):
@@ -78,6 +79,8 @@ def fit_and_compute(
7879
:param auto_assign_quality: If a quality is provided, then the existing causal mechanisms in the given causal_model
7980
are overridden by new automatically inferred mechanisms based on the provided
8081
AssignmentQuality. If None is given, the existing assigned mechanisms are used.
82+
:param auto_assign_verbose: If True, the auto assignment logs additional information about the model selection
83+
process.
8184
:param args: Args passed through verbatim to the causal queries.
8285
:param kwargs: Keyword args passed through verbatim to the causal queries.
8386
:return: A tuple containing (1) the median of causal query results and (2) the confidence intervals.
@@ -94,7 +97,9 @@ def snapshot():
9497
]
9598

9699
if auto_assign_quality is not None:
97-
auto.assign_causal_mechanisms(causal_model_copy, sampled_data, auto_assign_quality, override_models=True)
100+
auto.assign_causal_mechanisms(
101+
causal_model_copy, sampled_data, auto_assign_quality, override_models=True, verbose=auto_assign_verbose
102+
)
98103

99104
fit(causal_model_copy, sampled_data)
100105
return f(causal_model_copy, *args, **kwargs)

dowhy/gcm/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import logging
2+
import sys
3+
14
show_progress_bars = True
25
default_n_jobs = -1
36

7+
_logger = logging.getLogger(__name__)
8+
_logger.setLevel(logging.INFO)
9+
_logger.addHandler(logging.StreamHandler(sys.stdout))
10+
411

512
def enable_progress_bars():
613
global show_progress_bars
@@ -15,3 +22,8 @@ def disable_progress_bars():
1522
def set_default_n_jobs(n_jobs: int) -> None:
1623
global default_n_jobs
1724
default_n_jobs = n_jobs
25+
26+
27+
def add_info_log_msg(msg: str, verbose: bool) -> None:
28+
if verbose:
29+
_logger.info(msg)

dowhy/gcm/distribution_change.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def distribution_change(
101101
mechanism_change_test_significance_level: float = 0.05,
102102
mechanism_change_test_fdr_control_method: Optional[str] = "fdr_bh",
103103
auto_assignment_quality: Optional[AssignmentQuality] = None,
104+
auto_assignment_verbose: bool = False,
104105
return_additional_info: bool = False,
105106
shapley_config: Optional[ShapleyConfig] = None,
106107
graph_factory: Callable[[Any], DirectedGraph] = nx.DiGraph,
@@ -139,6 +140,8 @@ def distribution_change(
139140
old and new graph. However, they are re-fitted on the given data.
140141
If set to a valid assignment quality, new models are automatically assigned to the
141142
old and new graph based on the respective data.
143+
:param auto_assignment_verbose: If True, the auto assignment logs additional information about the model selection
144+
process.
142145
:param return_additional_info: If set to True, three additional items are returned: a dictionary indicating
143146
whether each node's mechanism changed, the causal DAG whose causal models are
144147
learned from old data, and the causal DAG whose causal models are learned from new
@@ -160,7 +163,13 @@ def distribution_change(
160163
if auto_assignment_quality is None:
161164
clone_causal_models(causal_model.graph, causal_model_old.graph)
162165
else:
163-
assign_causal_mechanisms(causal_model_old, old_data, override_models=True, quality=auto_assignment_quality)
166+
assign_causal_mechanisms(
167+
causal_model_old,
168+
old_data,
169+
override_models=True,
170+
quality=auto_assignment_quality,
171+
verbose=auto_assignment_verbose,
172+
)
164173
invariant_nodes = list(set(invariant_nodes).intersection(set(causal_graph_old.nodes)))
165174
_remove_invariant_nodes(invariant_nodes, causal_model_old, old_data, auto_assignment_quality)
166175

@@ -169,7 +178,13 @@ def distribution_change(
169178
if auto_assignment_quality is None:
170179
clone_causal_models(causal_graph_old, causal_model_new.graph)
171180
else:
172-
assign_causal_mechanisms(causal_model_new, new_data, override_models=True, quality=auto_assignment_quality)
181+
assign_causal_mechanisms(
182+
causal_model_new,
183+
new_data,
184+
override_models=True,
185+
quality=auto_assignment_quality,
186+
verbose=auto_assignment_verbose,
187+
)
173188

174189
mechanism_changes = _fit_accounting_for_mechanism_change(
175190
causal_model_old,

dowhy/gcm/fitting_sampling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
33
Functions in this module should be considered experimental, meaning there might be breaking API changes in the future.
44
"""
5-
65
from typing import Any
76

87
import networkx as nx

dowhy/gcm/influence.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def intrinsic_causal_influence(
226226
num_samples_baseline: int = 2000,
227227
max_batch_size: int = 250,
228228
auto_assign_quality: auto.AssignmentQuality = auto.AssignmentQuality.GOOD,
229+
auto_assign_verbose: bool = False,
229230
shapley_config: Optional[ShapleyConfig] = None,
230231
) -> Dict[Any, float]:
231232
"""Computes the causal contribution of each upstream noise term of the target node (including the noise of the
@@ -268,6 +269,8 @@ def intrinsic_causal_influence(
268269
significant impact on the overall memory usage. If set to -1, all samples are used in one
269270
batch.
270271
:param auto_assign_quality: Auto assign quality for the 'approx' prediction_model option.
272+
:param auto_assign_verbose: If True, the auto assignment logs additional information about the model selection
273+
process if 'approx' is selected for the prediction_model option.
271274
:param shapley_config: :class:`~dowhy.gcm.shapley.ShapleyConfig` for the Shapley estimator.
272275
:return: Intrinsic causal contribution of each ancestor node to the statistical property defined by the
273276
attribution_func of the target node.
@@ -292,6 +295,7 @@ def intrinsic_causal_influence(
292295
target_samples,
293296
auto_assign_quality,
294297
target_is_categorical,
298+
auto_assign_verbose,
295299
)
296300

297301
if attribution_func is None:
@@ -332,6 +336,7 @@ def intrinsic_causal_influence_sample(
332336
num_noise_feature_samples: int = 5000,
333337
max_batch_size: int = 100,
334338
auto_assign_quality: auto.AssignmentQuality = auto.AssignmentQuality.GOOD,
339+
auto_assign_verbose: bool = False,
335340
shapley_config: Optional[ShapleyConfig] = None,
336341
) -> List[Dict[Any, Any]]:
337342
"""Estimates the intrinsic causal impact of upstream nodes on a specified target_node, using the provided
@@ -367,6 +372,8 @@ def intrinsic_causal_influence_sample(
367372
:param max_batch_size: Maximum batch size for estimating multiple predictions at once. This has a significant influence on the
368373
overall memory usage. If set to -1, all samples are used in one batch.
369374
:param auto_assign_quality: Auto assign quality for the 'approx' prediction_model option.
375+
:param auto_assign_verbose: If True, the auto assignment logs additional information about the model selection
376+
process if 'approx' is selected for the prediction_model option.
370377
:param shapley_config: :class:`~dowhy.gcm.shapley.ShapleyConfig` for the Shapley estimator.
371378
:return: A list of dictionaries indicating the intrinsic causal influence of a node on the target for a particular
372379
sample. That is, each dictionary belongs to one baseline sample.
@@ -403,6 +410,7 @@ def intrinsic_causal_influence_sample(
403410
target_samples,
404411
auto_assign_quality,
405412
False, # Currently only supports a continuous target since we need to reconstruct its noise term.
413+
auto_assign_verbose,
406414
)
407415

408416
shapley_vales = feature_relevance_sample(
@@ -465,6 +473,7 @@ def _get_icc_noise_function(
465473
target_samples: np.ndarray,
466474
auto_assign_quality: auto.AssignmentQuality,
467475
target_is_categorical: bool,
476+
auto_assign_verbose: bool,
468477
) -> Callable[[np.ndarray], np.ndarray]:
469478
if isinstance(prediction_model, str) and prediction_model not in ("approx", "exact"):
470479
raise ValueError(
@@ -478,7 +487,9 @@ def _get_icc_noise_function(
478487
return prediction_model.predict
479488

480489
if prediction_model == "approx":
481-
prediction_model = auto.select_model(noise_samples, target_samples, auto_assign_quality)
490+
prediction_model = auto.select_model(
491+
noise_samples, target_samples, auto_assign_quality, verbose=auto_assign_verbose
492+
)
482493
prediction_model.fit(noise_samples, target_samples)
483494

484495
if target_is_categorical:

dowhy/gcm/ml/regression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def clone(self):
5757
"""
5858
return SklearnRegressionModel(sklearn_mdl=sklearn.clone(self._sklearn_mdl))
5959

60+
def __str__(self):
61+
return str(self._sklearn_mdl)
62+
6063

6164
def create_linear_regressor_with_given_parameters(
6265
coefficients: np.ndarray, intercept: float = 0, **kwargs

dowhy/gcm/stochastic_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ def draw_samples(self, num_samples: int) -> np.ndarray:
202202
def clone(self):
203203
return EmpiricalDistribution()
204204

205+
def __str__(self):
206+
return "Empirical Distribution"
207+
205208

206209
class BayesianGaussianMixtureDistribution(StochasticModel):
207210
def __init__(self) -> None:

0 commit comments

Comments
 (0)