Skip to content

Commit

Permalink
Fix issue in falsify method when no tests were performed
Browse files Browse the repository at this point in the history
It now does not raise a division by zero error anymore. Other changes:
- Add new parameter indicating whether the method requires data for all nodes in the graph or also allows a subset of data.
- If no tests were performed, the summary now returns "Cannot be evaluated".

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Nov 30, 2023
1 parent c41cefc commit 68de329
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
24 changes: 19 additions & 5 deletions dowhy/gcm/falsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ def __repr__(self):

def _can_evaluate(self):
can_evaluate = True
if self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0:
return False

for m in (FalsifyConst.VALIDATE_LMC, FalsifyConst.VALIDATE_TPA):
if m not in self.summary:
can_evaluate = False
Expand All @@ -517,6 +520,7 @@ def falsify_graph(
n_jobs: Optional[int] = None,
plot_histogram: bool = False,
plot_kwargs: Optional[Dict] = None,
allow_data_subset: bool = True,
) -> EvaluationResult:
"""
Falsify a given DAG using observational data.
Expand Down Expand Up @@ -563,14 +567,24 @@ def falsify_graph(
:param n_jobs: Number of jobs to use for parallel execution of (conditional) independence tests.
:param plot_histogram: Plot histogram of results from permutation baseline.
:param plot_kwargs: Additional plot arguments to be passed to plot_evaluation_results.
:param allow_data_subset: If True, performs the evaluation even if data is only available for a subset of nodes.
If False, raises an error if not all nodes have data available.
:return: EvaluationResult
"""
if not allow_data_subset and not set([str(node) for node in causal_graph.nodes]).issubset(
set([str(col) for col in data.columns])
):
raise ValueError(
"Did not find data for all nodes of the given graph! Make sure that the node names coincide "
"with the column names in the data."
)

n_jobs = config.default_n_jobs if n_jobs is None else n_jobs
show_progress_bar = config.show_progress_bars if show_progress_bar is None else show_progress_bar
p_values_memory = _PValuesMemory()

if n_permutations is None:
n_permutations = int(1 / significance_level) if not plot_histogram else -1
n_permutations = int(1 / significance_level)

if not plot_kwargs:
plot_kwargs = {}
Expand Down Expand Up @@ -624,10 +638,10 @@ def falsify_graph(
summary[m][FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS]
summary[m][FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS]
summary[m][FalsifyConst.F_PERM_VIOLATIONS] = [
perm[FalsifyConst.N_VIOLATIONS] / perm[FalsifyConst.N_TESTS] for perm in summary_perm[m]
perm[FalsifyConst.N_VIOLATIONS] / max(1, perm[FalsifyConst.N_TESTS]) for perm in summary_perm[m]
]
summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = (
summary[m][FalsifyConst.GIVEN_VIOLATIONS] / summary[m][FalsifyConst.N_TESTS]
summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = summary[m][FalsifyConst.GIVEN_VIOLATIONS] / max(
1, summary[m][FalsifyConst.N_TESTS]
)
summary[m][FalsifyConst.P_VALUE] = sum(
[
Expand All @@ -653,7 +667,7 @@ def falsify_graph(
significance_level=significance_level,
suggestions={m: summary_given[m] for m in summary_given if m not in validation_methods},
)
if plot_histogram:
if plot_histogram and result.can_evaluate:
plot_evaluation_results(result, **plot_kwargs)
return result

Expand Down
1 change: 1 addition & 0 deletions dowhy/gcm/model_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def evaluate_causal_model(
independence_test=config.independence_test_falsify,
conditional_independence_test=config.conditional_independence_test_falsify,
n_jobs=config.n_jobs,
allow_data_subset=False,
)

return evaluation_result
Expand Down
19 changes: 17 additions & 2 deletions tests/gcm/test_falsify.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.

from functools import partial

import networkx as nx
import numpy as np
import pandas as pd
import pytest
from flaky import flaky

from dowhy.datasets import generate_random_graph
from dowhy.gcm.falsify import (
FalsifyConst,
_PermuteNodes,
falsify_graph,
run_validations,
validate_cm,
validate_lmc,
Expand Down Expand Up @@ -360,3 +360,18 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol

assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 2


def test_given_data_of_only_a_subset_of_nodes_when_falsify_and_set_allow_data_subset_to_false_then_raise_error():
with pytest.raises(ValueError):
falsify_graph(
nx.DiGraph([("X", "Y"), ("Y", "Z")]),
pd.DataFrame({"X": np.random.random(10), "Y": np.random.random(10)}),
allow_data_subset=False,
)


def test_given_no_data_when_falsify_then_does_not_raise_error_but_cannot_evaluate():
assert not falsify_graph(nx.DiGraph([("X", "Y"), ("Y", "Z")]), pd.DataFrame({})).can_evaluate

assert not falsify_graph(nx.DiGraph([("X", "Y"), ("Y", "Z")]), pd.DataFrame({"X": []})).can_evaluate

0 comments on commit 68de329

Please sign in to comment.