From 68de329255639222135d97f3c24cbe0c7e6eba06 Mon Sep 17 00:00:00 2001 From: Patrick Bloebaum Date: Wed, 29 Nov 2023 17:06:49 -0800 Subject: [PATCH] Fix issue in falsify method when no tests were performed It now does not raise a division by zero error anymore. Other changes: - Add new parameter indicating whether the method requires data for all nodes in the graph or also allows a subset of data. - If no tests were performed, the summary now returns "Cannot be evaluated". Signed-off-by: Patrick Bloebaum --- dowhy/gcm/falsify.py | 24 +++++++++++++++++++----- dowhy/gcm/model_evaluation.py | 1 + tests/gcm/test_falsify.py | 19 +++++++++++++++++-- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/dowhy/gcm/falsify.py b/dowhy/gcm/falsify.py index 3a794c348a..704fff2f78 100644 --- a/dowhy/gcm/falsify.py +++ b/dowhy/gcm/falsify.py @@ -497,6 +497,9 @@ def __repr__(self): def _can_evaluate(self): can_evaluate = True + if self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0: + return False + for m in (FalsifyConst.VALIDATE_LMC, FalsifyConst.VALIDATE_TPA): if m not in self.summary: can_evaluate = False @@ -517,6 +520,7 @@ def falsify_graph( n_jobs: Optional[int] = None, plot_histogram: bool = False, plot_kwargs: Optional[Dict] = None, + allow_data_subset: bool = True, ) -> EvaluationResult: """ Falsify a given DAG using observational data. @@ -563,14 +567,24 @@ def falsify_graph( :param n_jobs: Number of jobs to use for parallel execution of (conditional) independence tests. :param plot_histogram: Plot histogram of results from permutation baseline. :param plot_kwargs: Additional plot arguments to be passed to plot_evaluation_results. + :param allow_data_subset: If True, performs the evaluation even if data is only available for a subset of nodes. + If False, raises an error if not all nodes have data available. :return: EvaluationResult """ + if not allow_data_subset and not set([str(node) for node in causal_graph.nodes]).issubset( + set([str(col) for col in data.columns]) + ): + raise ValueError( + "Did not find data for all nodes of the given graph! Make sure that the node names coincide " + "with the column names in the data." + ) + n_jobs = config.default_n_jobs if n_jobs is None else n_jobs show_progress_bar = config.show_progress_bars if show_progress_bar is None else show_progress_bar p_values_memory = _PValuesMemory() if n_permutations is None: - n_permutations = int(1 / significance_level) if not plot_histogram else -1 + n_permutations = int(1 / significance_level) if not plot_kwargs: plot_kwargs = {} @@ -624,10 +638,10 @@ def falsify_graph( summary[m][FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS] summary[m][FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS] summary[m][FalsifyConst.F_PERM_VIOLATIONS] = [ - perm[FalsifyConst.N_VIOLATIONS] / perm[FalsifyConst.N_TESTS] for perm in summary_perm[m] + perm[FalsifyConst.N_VIOLATIONS] / max(1, perm[FalsifyConst.N_TESTS]) for perm in summary_perm[m] ] - summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = ( - summary[m][FalsifyConst.GIVEN_VIOLATIONS] / summary[m][FalsifyConst.N_TESTS] + summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = summary[m][FalsifyConst.GIVEN_VIOLATIONS] / max( + 1, summary[m][FalsifyConst.N_TESTS] ) summary[m][FalsifyConst.P_VALUE] = sum( [ @@ -653,7 +667,7 @@ def falsify_graph( significance_level=significance_level, suggestions={m: summary_given[m] for m in summary_given if m not in validation_methods}, ) - if plot_histogram: + if plot_histogram and result.can_evaluate: plot_evaluation_results(result, **plot_kwargs) return result diff --git a/dowhy/gcm/model_evaluation.py b/dowhy/gcm/model_evaluation.py index 1a1d62612f..aa9d1c3a31 100644 --- a/dowhy/gcm/model_evaluation.py +++ b/dowhy/gcm/model_evaluation.py @@ -401,6 +401,7 @@ def evaluate_causal_model( independence_test=config.independence_test_falsify, conditional_independence_test=config.conditional_independence_test_falsify, n_jobs=config.n_jobs, + allow_data_subset=False, ) return evaluation_result diff --git a/tests/gcm/test_falsify.py b/tests/gcm/test_falsify.py index 350e3b16f4..a5a5d026f2 100644 --- a/tests/gcm/test_falsify.py +++ b/tests/gcm/test_falsify.py @@ -1,16 +1,16 @@ -# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. - from functools import partial import networkx as nx import numpy as np import pandas as pd +import pytest from flaky import flaky from dowhy.datasets import generate_random_graph from dowhy.gcm.falsify import ( FalsifyConst, _PermuteNodes, + falsify_graph, run_validations, validate_cm, validate_lmc, @@ -360,3 +360,18 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 0 assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 2 + + +def test_given_data_of_only_a_subset_of_nodes_when_falsify_and_set_allow_data_subset_to_false_then_raise_error(): + with pytest.raises(ValueError): + falsify_graph( + nx.DiGraph([("X", "Y"), ("Y", "Z")]), + pd.DataFrame({"X": np.random.random(10), "Y": np.random.random(10)}), + allow_data_subset=False, + ) + + +def test_given_no_data_when_falsify_then_does_not_raise_error_but_cannot_evaluate(): + assert not falsify_graph(nx.DiGraph([("X", "Y"), ("Y", "Z")]), pd.DataFrame({})).can_evaluate + + assert not falsify_graph(nx.DiGraph([("X", "Y"), ("Y", "Z")]), pd.DataFrame({"X": []})).can_evaluate