Skip to content

Commit

Permalink
By default, the GCM fit method now also evaluates the fitted models
Browse files Browse the repository at this point in the history
The models are evaluated based on the KL divergence between the generated distribution and the observed one.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Oct 17, 2023
1 parent 558e6fd commit 2a800bd
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 11 deletions.
8 changes: 4 additions & 4 deletions dowhy/gcm/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def assign_causal_mechanism_node(
causal_model: ProbabilisticCausalModel,
node: str,
based_on: pd.DataFrame,
quality: AssignmentQuality = AssignmentQuality.GOOD,
verbose: bool = True,
quality: AssignmentQuality,
verbose: bool,
) -> None:
if is_root_node(causal_model.graph, node):
causal_model.set_causal_mechanism(node, EmpiricalDistribution())
Expand Down Expand Up @@ -226,11 +226,11 @@ def select_model(
if is_categorical(Y):
add_info_log_msg("The node seems to be categorical. Checking classification models...", verbose)

return find_best_model(list_of_classifier, X, Y, model_selection_splits=model_selection_splits)()
return find_best_model(list_of_classifier, X, Y, model_selection_splits=model_selection_splits, verbose=verbose)()
else:
add_info_log_msg("The node seems to be continuous. Checking regression models....", verbose)

return find_best_model(list_of_regressor, X, Y, model_selection_splits=model_selection_splits)()
return find_best_model(list_of_regressor, X, Y, model_selection_splits=model_selection_splits, verbose=verbose)()


def has_linear_relationship(X: np.ndarray, Y: np.ndarray, max_num_samples: int = 3000) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions dowhy/gcm/confidence_intervals_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def fit_and_compute(
bootstrap_training_data: pd.DataFrame,
bootstrap_data_subset_size_fraction: float = 0.75,
auto_assign_quality: Optional[auto.AssignmentQuality] = None,
auto_assign_verbose: bool = False,
verbose: bool = False,
*args,
**kwargs,
):
Expand All @@ -79,8 +79,8 @@ def fit_and_compute(
:param auto_assign_quality: If a quality is provided, then the existing causal mechanisms in the given causal_model
are overridden by new automatically inferred mechanisms based on the provided
AssignmentQuality. If None is given, the existing assigned mechanisms are used.
:param auto_assign_verbose: If True, the auto assignment logs additional information about the model selection
process.
:param verbose: If True, the auto assignment logs additional information about the model selection process and the
fitting will provide additional information about the model qualities.
:param args: Args passed through verbatim to the causal queries.
:param kwargs: Keyword args passed through verbatim to the causal queries.
:return: A tuple containing (1) the median of causal query results and (2) the confidence intervals.
Expand All @@ -98,10 +98,10 @@ def snapshot():

if auto_assign_quality is not None:
auto.assign_causal_mechanisms(
causal_model_copy, sampled_data, auto_assign_quality, override_models=True, verbose=auto_assign_verbose
causal_model_copy, sampled_data, auto_assign_quality, override_models=True, verbose=verbose
)

fit(causal_model_copy, sampled_data)
fit(causal_model_copy, sampled_data, evaluate_models=verbose)
return f(causal_model_copy, *args, **kwargs)

return snapshot
66 changes: 64 additions & 2 deletions dowhy/gcm/fitting_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Functions in this module should be considered experimental, meaning there might be breaking API changes in the future.
"""
from typing import Any
from typing import Any, Dict, Union

import networkx as nx
import numpy as np
Expand All @@ -16,13 +16,22 @@
validate_causal_dag,
validate_causal_model_assignment,
)
from dowhy.gcm.config import add_info_log_msg
from dowhy.gcm.divergence import auto_estimate_kl_divergence
from dowhy.graph import get_ordered_predecessors, is_root_node


def fit(causal_model: ProbabilisticCausalModel, data: pd.DataFrame):
def fit(causal_model: ProbabilisticCausalModel,
data: pd.DataFrame,
evaluate_models: bool = True,
max_samples_for_evaluation: int = 5000) -> Union[None, Dict[Any, float]]:
"""Learns generative causal models of nodes in the causal graph from data.
:param causal_model: The causal model containing the mechanisms that will be fitted.
:param evaluate_models: If True, the models are evaluated using the KL divergence between generated data and the
given observed data after fitting the models. The results additionally returned as a
dictionary.
:param max_samples_for_evaluation: The maximum number of samples used for evaluating a causal mechanism.
:param data: Observations of nodes in the causal model.
"""
progress_bar = tqdm(
Expand All @@ -43,6 +52,48 @@ def fit(causal_model: ProbabilisticCausalModel, data: pd.DataFrame):

fit_causal_model_of_target(causal_model, node, data)

if evaluate_models:
evaluation_result = {}
add_info_log_msg("----- Evaluating models -----", True)
for node in nx.topological_sort(causal_model.graph):
if is_root_node(causal_model.graph, node):
kl_divergence = auto_estimate_kl_divergence(
causal_model.causal_mechanism(node).draw_samples(min(max_samples_for_evaluation, data.shape[0])),
data[node].to_numpy()[:5000])

add_info_log_msg("- Root node %s has a KL divergence of: %f -- %s" % (
node, kl_divergence, _get_kl_divergence_interpretation_string(kl_divergence)), True)

evaluation_result[node] = kl_divergence
else:
ground_truth = data[get_ordered_predecessors(causal_model.graph, node) + [node]].to_numpy()[
:max_samples_for_evaluation * 2]
inputs_draw = ground_truth[:ground_truth.shape[0] // 2, :-1]

kl_divergence = auto_estimate_kl_divergence(
np.column_stack([inputs_draw, causal_model.causal_mechanism(node).draw_samples(inputs_draw)]),
ground_truth[ground_truth.shape[0] // 2:])
substructure_string = "(" + ', '.join(
get_ordered_predecessors(causal_model.graph, node)) + ") -> " + str(node)
add_info_log_msg("- Causal mechanism for %s has a KL divergence of: %f -- %s"
% (substructure_string, kl_divergence,
_get_kl_divergence_interpretation_string(kl_divergence)),
True)

evaluation_result[substructure_string] = kl_divergence

generated_samples = draw_samples(causal_model, min(max_samples_for_evaluation, data.shape[0]))
kl_divergence = auto_estimate_kl_divergence(generated_samples.to_numpy(),
data[generated_samples.columns].to_numpy())
add_info_log_msg(
"--- Overall: The KL divergence between the join distribution of the generated data based on the fitted "
"models and the training data is: %f -- %s" % (
kl_divergence, _get_kl_divergence_interpretation_string(kl_divergence)), True)
evaluation_result["Overall"] = kl_divergence
add_info_log_msg("----- Finish evaluation -----", True)

return evaluation_result


def fit_causal_model_of_target(
causal_model: ProbabilisticCausalModel, target_node: Any, training_data: pd.DataFrame
Expand Down Expand Up @@ -102,3 +153,14 @@ def draw_samples(causal_model: ProbabilisticCausalModel, num_samples: int) -> pd

def _parent_samples_of(node: Any, scm: ProbabilisticCausalModel, samples: pd.DataFrame) -> np.ndarray:
return samples[get_ordered_predecessors(scm.graph, node)].to_numpy()


def _get_kl_divergence_interpretation_string(kl_divergence: float) -> str:
if kl_divergence < 0.5:
return "The estimated KL divergence indicates an overall good representation of the data distribution."
elif 0.5 <= kl_divergence < 1:
return "The estimated KL divergence might indicate some smaller mismatches between the distributions."
elif 1 <= kl_divergence < 3:
return "The estimated KL divergence indicates some significant mismatches between the distributions. Consider using models that better fit the distribution."
else:
return "The estimated KL divergence indicates a significant mismatches between the distributions. Consider using models that better fit the distribution.."

0 comments on commit 2a800bd

Please sign in to comment.