Skip to content

Commit 30412cc

Browse files
committed
Add explicit support for discrete ANMs
- Add new Discrete Additive Noise Model class that enforces the outputs to be discrete. This should help in generating more consistent data. - As part of this, revised the auto assignment function and revised its docstring. - Revise the auto assignment summary. - Revise the evaluation summary. Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
1 parent 7c015b7 commit 30412cc

File tree

12 files changed

+539
-143
lines changed

12 files changed

+539
-143
lines changed

docs/source/user_guide/modeling_gcm/model_evaluation.rst

Lines changed: 93 additions & 50 deletions
Large diffs are not rendered by default.

dowhy/gcm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
MedianDeviationScorer,
1111
RescaledMedianCDFQuantileScorer,
1212
)
13-
from .causal_mechanisms import AdditiveNoiseModel, ClassifierFCM, PostNonlinearModel
13+
from .causal_mechanisms import AdditiveNoiseModel, ClassifierFCM, DiscreteAdditiveNoiseModel, PostNonlinearModel
1414
from .causal_models import InvertibleStructuralCausalModel, ProbabilisticCausalModel, StructuralCausalModel
1515
from .confidence_intervals import confidence_intervals
1616
from .confidence_intervals_cms import bootstrap_sampling, fit_and_compute

dowhy/gcm/auto.py

Lines changed: 153 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sklearn.preprocessing import MultiLabelBinarizer
1515

1616
from dowhy.gcm import config
17-
from dowhy.gcm.causal_mechanisms import AdditiveNoiseModel, ClassifierFCM
17+
from dowhy.gcm.causal_mechanisms import AdditiveNoiseModel, ClassifierFCM, DiscreteAdditiveNoiseModel
1818
from dowhy.gcm.causal_models import CAUSAL_MECHANISM, ProbabilisticCausalModel, validate_causal_model_assignment
1919
from dowhy.gcm.ml import (
2020
ClassificationModel,
@@ -48,6 +48,7 @@
4848
auto_apply_encoders,
4949
auto_fit_encoders,
5050
is_categorical,
51+
is_discrete,
5152
set_random_seed,
5253
shape_into_2d,
5354
)
@@ -108,7 +109,43 @@ def add_model_performance(self, node, model: str, performance: str, metric_name:
108109
def __str__(self):
109110
summary_strings = []
110111

111-
summary_strings.append("Analyzed %d nodes." % len(list(self._nodes)))
112+
summary_strings.append(
113+
"When using this auto assignment function, the given data is used to automatically assign a causal "
114+
"mechanism to each node. Note that causal mechanisms can also be customized and assigned manually.\n"
115+
"The following types of causal mechanisms are considered for the automatic selection:"
116+
)
117+
summary_strings.append("\nIf root node:")
118+
summary_strings.append(
119+
"An empirical distribution, i.e., the distribution is represented by randomly sampling from the provided "
120+
"data. This provides a flexible and non-parametric way to model the marginal distribution and is valid for "
121+
"all types of data modalities."
122+
)
123+
summary_strings.append("\nIf non-root node and the data is continuous:")
124+
summary_strings.append(
125+
"Additive Noise Models (ANM) of the form X_i = f(PA_i) + N_i, where PA_i are the "
126+
"parents of X_i and the unobserved noise N_i is assumed to be independent of PA_i."
127+
"To select the best model for f, different regression models are evaluated and the model "
128+
"with the smallest mean squared error is selected."
129+
"Note that minimizing the mean squared error here is equivalent to selecting the best "
130+
"choice of an ANM."
131+
)
132+
summary_strings.append("\nIf non-root node and the data is discrete:")
133+
summary_strings.append(
134+
"Discrete Additive Noise Models have almost the same definition as non-discrete ANMs, but come with an "
135+
"additional constraint for f to only return discrete values.\n"
136+
"Note that 'discrete' here refers to numerical values with an order. If the data is categorical, consider "
137+
"representing them as strings to ensure proper model selection."
138+
)
139+
summary_strings.append("\nIf non-root node and the data is categorical:")
140+
summary_strings.append(
141+
"A functional causal model based on a classifier, i.e., X_i = f(PA_i, N_i).\n"
142+
"Here, N_i follows a uniform distribution on [0, 1] and is used to randomly sample a "
143+
"class (category) using the conditional probability distribution produced by a "
144+
"classification model."
145+
"Here, different model classes are evaluated using the (negative) F1 score and the best"
146+
" performing model class is selected."
147+
)
148+
summary_strings.append("\nIn total, %d nodes were analyzed:" % len(list(self._nodes)))
112149

113150
for node in self._nodes:
114151
summary_strings.append("\n--- Node: %s" % node)
@@ -123,11 +160,13 @@ def __str__(self):
123160
for (model, performance, metric_name) in self._nodes[node]["model_performances"]:
124161
summary_strings.append("%s: %s" % (str(model()).replace("()", ""), str(performance)))
125162

126-
summary_strings.append(
127-
"Based on the type of causal mechanism, the model with the lowest metric value "
128-
"represents the best choice."
129-
)
130-
163+
summary_strings.append(
164+
"\n===Note===\nNote, based on the selected auto assignment quality, the set of " "evaluated models changes."
165+
)
166+
summary_strings.append(
167+
"For more insights toward the quality of the fitted graphical causal model, consider "
168+
"using the evaluate_causal_model function after fitting the causal mechanisms."
169+
)
131170
return "\n".join(summary_strings)
132171

133172

@@ -137,26 +176,86 @@ def assign_causal_mechanisms(
137176
quality: AssignmentQuality = AssignmentQuality.GOOD,
138177
override_models: bool = False,
139178
) -> AutoAssignmentSummary:
140-
"""Automatically assigns appropriate causal models. If causal models are already assigned to nodes and
141-
override_models is set to False, this function only validates the assignments with respect to the graph structure.
142-
Here, the validation checks whether root nodes have StochasticModels and non-root ConditionalStochasticModels
143-
assigned.
179+
"""Automatically assigns appropriate causal mechanisms to nodes. If causal mechanisms are already assigned to nodes
180+
and override_models is set to False, this function only validates the assignments with respect to the graph
181+
structure. That is, the validation checks whether root nodes have StochasticModels and non-root
182+
ConditionalStochasticModels assigned.
183+
184+
The following types of causal mechanisms are considered for the automatic selection:
185+
186+
If root node:
187+
An empirical distribution, i.e., the distribution is represented by randomly sampling from the provided data.
188+
This provides a flexible and non-parametric way to model the marginal distribution and is valid for all types of
189+
data modalities.
190+
191+
If non-root node and the data is continuous:
192+
Additive Noise Models (ANM) of the form X_i = f(PA_i) + N_i, where PA_i are the parents of X_i and the unobserved
193+
noise N_i is assumed to be independent of PA_i. To select the best model for f, different regression models are
194+
evaluated and the model with the smallest mean squared error is selected. Note that minimizing the mean squared
195+
error here is equivalent to selecting the best choice of an ANM.
196+
197+
If non-root node and the data is discrete:
198+
Discrete Additive Noise Models have almost the same definition as non-discrete ANMs, but come with an additional
199+
constraint to return discrete values. Note that 'discrete' here refers to numerical values with an order. If the
200+
data is categorical, consider representing them as strings to ensure proper model selection.
201+
202+
If non-root node and the data is categorical:
203+
A functional causal model based on a classifier, i.e., X_i = f(PA_i, N_i).
204+
Here, N_i follows a uniform distribution on [0, 1] and is used to randomly sample a class (category) using the
205+
conditional probability distribution produced by a classification model. Here, different model classes are evaluated
206+
using the (negative) F1 score and the best performing model class is selected.
207+
208+
The current model zoo is:
209+
210+
With "GOOD" quality:
211+
Numerical:
212+
- Linear Regressor
213+
- Linear Regressor with polynomial features
214+
- Histogram Gradient Boost Regressor
215+
216+
Categorical:
217+
- Logistic Regressor
218+
- Logistic Regressor with polynomial features
219+
- Histogram Gradient Boost Classifier
220+
221+
With "BETTER" quality:
222+
Numerical:
223+
- Linear Regressor
224+
- Linear Regressor with polynomial features
225+
- Gradient Boost Regressor
226+
- Ridge Regressor
227+
- Lasso Regressor
228+
- Random Forest Regressor
229+
- Support Vector Regressor
230+
- Extra Trees Regressor
231+
- KNN Regressor
232+
- Ada Boost Regressor
233+
234+
Categorical:
235+
- Logistic Regressor
236+
- Logistic Regressor with polynomial features
237+
- Histogram Gradient Boost Classifier
238+
- Random Forest Classifier
239+
- Extra Trees Classifier
240+
- Support Vector Classifier
241+
- KNN Classifier
242+
- Gaussian Naive Bayes Classifier
243+
- Ada Boost Classifier
244+
245+
With "BEST" quality:
246+
An auto ML model based on AutoGluon (optional dependency, needs to be installed).
144247
145248
:param causal_model: The causal model to whose nodes to assign causal models.
146249
:param based_on: Jointly sampled data corresponding to the nodes of the given graph.
147250
:param quality: AssignmentQuality for the automatic model selection and model accuracy. This changes the type of
148-
prediction model and time spent on the selection. Options are:
149-
- AssignmentQuality.GOOD: Compares a linear, polynomial and gradient boost model on small test-training split
150-
of the data. The best performing model is then selected.
251+
prediction model and time spent on the selection. See the docstring for a list of potential models.
252+
The options for the quality are:
253+
- AssignmentQuality.GOOD: Only a small set of models are evaluated.
151254
Model selection speed: Fast
152255
Model training speed: Fast
153256
Model inference speed: Fast
154257
Model accuracy: Medium
155-
- AssignmentQuality.BETTER: Compares multiple model types and uses the one with the best performance
156-
averaged over multiple splits of the training data. By default, the model with the smallest root mean
157-
squared error is selected for regression problems and the model with the highest F1 score is selected for
158-
classification problems. For a list of possible models, see _LIST_OF_POTENTIAL_REGRESSORS_BETTER and
159-
_LIST_OF_POTENTIAL_CLASSIFIERS_BETTER, respectively.
258+
- AssignmentQuality.BETTER: A larger set of models are evaluated.
160259
Model selection speed: Medium
161260
Model training speed: Fast
162261
Model inference speed: Fast
@@ -168,8 +267,8 @@ def assign_causal_mechanisms(
168267
Model training speed: Slow
169268
Model inference speed: Slow-Medium
170269
Model accuracy: Best
171-
:param override_models: If set to True, existing model assignments are replaced with automatically selected
172-
ones. If set to False, the assigned models are only validated with respect to the graph
270+
:param override_models: If set to True, existing mechanism assignments are replaced with automatically selected
271+
ones. If set to False, the assigned mechanisms are only validated with respect to the graph
173272
structure.
174273
:return: A summary object containing details about the model selection process.
175274
"""
@@ -179,7 +278,8 @@ def assign_causal_mechanisms(
179278
if not override_models and CAUSAL_MECHANISM in causal_model.graph.nodes[node]:
180279
auto_assignment_summary.add_node_log_message(
181280
node,
182-
"Node %s already has a model assigned and the override parameter is False. Skipping this node." % node,
281+
"Node %s already has a causal mechanism assigned and the override parameter is False. Skipping this "
282+
"node." % node,
183283
)
184284
validate_causal_model_assignment(causal_model.graph, node)
185285
continue
@@ -189,16 +289,36 @@ def assign_causal_mechanisms(
189289
if is_root_node(causal_model.graph, node):
190290
auto_assignment_summary.add_node_log_message(
191291
node,
192-
"Node %s is a root node. Assigning '%s' to the node representing the marginal distribution."
292+
"Node %s is a root node. Therefore, assigning '%s' to the node representing the marginal distribution."
193293
% (node, causal_model.causal_mechanism(node)),
194294
)
195295
else:
296+
data_type = "continuous"
297+
if isinstance(causal_model.causal_mechanism(node), ClassifierFCM):
298+
data_type = "categorical"
299+
elif isinstance(causal_model.causal_mechanism(node), DiscreteAdditiveNoiseModel):
300+
data_type = "discrete"
301+
196302
auto_assignment_summary.add_node_log_message(
197303
node,
198-
"Node %s is a non-root node. Assigning '%s' to the node." % (node, causal_model.causal_mechanism(node)),
304+
"Node %s is a non-root node with %s data. Assigning '%s' to the node."
305+
% (
306+
node,
307+
data_type,
308+
causal_model.causal_mechanism(node),
309+
),
199310
)
200311

201-
if isinstance(causal_model.causal_mechanism(node), AdditiveNoiseModel):
312+
if isinstance(causal_model.causal_mechanism(node), DiscreteAdditiveNoiseModel):
313+
auto_assignment_summary.add_node_log_message(
314+
node,
315+
"This represents the discrete causal relationship as "
316+
+ str(node)
317+
+ " := f("
318+
+ ",".join([str(parent) for parent in get_ordered_predecessors(causal_model.graph, node)])
319+
+ ") + N.",
320+
)
321+
elif isinstance(causal_model.causal_mechanism(node), AdditiveNoiseModel):
202322
auto_assignment_summary.add_node_log_message(
203323
node,
204324
"This represents the causal relationship as "
@@ -230,16 +350,21 @@ def assign_causal_mechanism_node(
230350
causal_model.set_causal_mechanism(node, EmpiricalDistribution())
231351
model_performances = []
232352
else:
353+
node_data = based_on[node].to_numpy()
354+
233355
best_model, model_performances = select_model(
234356
based_on[get_ordered_predecessors(causal_model.graph, node)].to_numpy(),
235-
based_on[node].to_numpy(),
357+
node_data,
236358
quality,
237359
)
238360

239361
if isinstance(best_model, ClassificationModel):
240362
causal_model.set_causal_mechanism(node, ClassifierFCM(best_model))
241363
else:
242-
causal_model.set_causal_mechanism(node, AdditiveNoiseModel(best_model))
364+
if is_discrete(node_data):
365+
causal_model.set_causal_mechanism(node, DiscreteAdditiveNoiseModel(best_model))
366+
else:
367+
causal_model.set_causal_mechanism(node, AdditiveNoiseModel(best_model))
243368

244369
return model_performances
245370

@@ -263,7 +388,7 @@ def select_model(
263388
elif model_selection_quality == AssignmentQuality.GOOD:
264389
list_of_regressor = list(_LIST_OF_POTENTIAL_REGRESSORS_GOOD)
265390
list_of_classifier = list(_LIST_OF_POTENTIAL_CLASSIFIERS_GOOD)
266-
model_selection_splits = 2
391+
model_selection_splits = 5
267392
elif model_selection_quality == AssignmentQuality.BETTER:
268393
list_of_regressor = list(_LIST_OF_POTENTIAL_REGRESSORS_BETTER)
269394
list_of_classifier = list(_LIST_OF_POTENTIAL_CLASSIFIERS_BETTER)

dowhy/gcm/causal_mechanisms.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dowhy.gcm.ml import ClassificationModel, PredictionModel
1010
from dowhy.gcm.ml.regression import InvertibleFunction, SklearnRegressionModel
11-
from dowhy.gcm.util.general import is_categorical, shape_into_2d
11+
from dowhy.gcm.util.general import is_categorical, is_discrete, shape_into_2d
1212

1313

1414
class StochasticModel(ABC):
@@ -218,6 +218,52 @@ def __str__(self) -> str:
218218
return "AdditiveNoiseModel using %s" % prediction_model_string
219219

220220

221+
class DiscreteAdditiveNoiseModel(AdditiveNoiseModel):
    """Additive noise model of the form Y = f(X) + N with discrete outputs.

    N is assumed to be independent of X. To allow for flexible models, f can be any regression model; its
    predictions are rounded to the nearest integer so that f only produces discrete values. This remains a valid
    additive noise model, but assumes that Y can take any integer value.
    """

    def fit(self, X: np.ndarray, Y: np.ndarray) -> None:
        """Fits the prediction model on (X, Y) and the noise model on the residuals Y - f(X).

        :param X: Samples of the parents (inputs of f).
        :param Y: Target samples; they are cast to int32 before fitting.
        :raises ValueError: If ``is_discrete(Y)`` is False.
        """
        if not is_discrete(Y):
            raise ValueError("Cannot fit a discrete ANM to non-discrete target values!")

        X, Y = shape_into_2d(X, Y)
        Y = Y.astype(np.int32)

        self._prediction_model.fit(X=X, Y=Y)
        # The noise is N = Y - f(X): this is the quantity estimate_noise returns and the one evaluate adds back
        # onto f(X). Fitting on f(X) - Y instead would learn the mirrored noise distribution.
        self._noise_model.fit(Y - self._rounded_prediction(X))

    def evaluate(self, parent_samples: np.ndarray, noise_samples: np.ndarray) -> np.ndarray:
        """Computes f(parent_samples) + noise_samples using the rounded predictions.

        :param parent_samples: Samples of the parents (inputs of f).
        :param noise_samples: Noise values to add to the rounded predictions.
        :return: The rounded predictions plus the given noise values.
        :raises ValueError: If ``is_discrete(noise_samples)`` is False.
        """
        if not is_discrete(noise_samples):
            raise ValueError("Noise values have to be discrete!")

        parent_samples, noise_samples = shape_into_2d(parent_samples, noise_samples)

        return self._rounded_prediction(parent_samples) + noise_samples

    def estimate_noise(self, target_samples: np.ndarray, parent_samples: np.ndarray) -> np.ndarray:
        """Reconstructs the noise as target_samples - f(parent_samples) using the rounded predictions.

        :param target_samples: Observed values of the target.
        :param parent_samples: Corresponding samples of the parents.
        :return: The residuals between the targets and the rounded predictions.
        :raises ValueError: If ``is_discrete(target_samples)`` is False.
        """
        if not is_discrete(target_samples):
            raise ValueError("Target samples have to be discrete!")

        target_samples, parent_samples = shape_into_2d(target_samples, parent_samples)

        return target_samples - self._rounded_prediction(parent_samples)

    def _rounded_prediction(self, X: np.ndarray) -> np.ndarray:
        """Returns the predictions of the underlying model rounded to int32 and shaped into 2D."""
        rounded = np.round(self._prediction_model.predict(X).astype(float)).astype(np.int32)
        # Shape into 2D here so that every caller (not only evaluate) combines the predictions with 2D
        # targets/noise without accidentally broadcasting a 1D prediction vector.
        return shape_into_2d(rounded)

    def clone(self):
        """Returns a new DiscreteAdditiveNoiseModel built from clones of the prediction and noise model."""
        return DiscreteAdditiveNoiseModel(
            prediction_model=self.prediction_model.clone(),
            noise_model=self.noise_model.clone(),
        )

    def __str__(self) -> str:
        return "Discrete " + super().__str__()
265+
266+
221267
class ProbabilityEstimatorModel(ABC):
222268
@abstractmethod
223269
def estimate_probabilities(self, parent_samples: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)