Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue in falsify method when no tests were performed #1089

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions dowhy/gcm/falsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ def __repr__(self):

def _can_evaluate(self):
can_evaluate = True
if self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0:
return False

for m in (FalsifyConst.VALIDATE_LMC, FalsifyConst.VALIDATE_TPA):
if m not in self.summary:
can_evaluate = False
Expand All @@ -517,6 +520,7 @@ def falsify_graph(
n_jobs: Optional[int] = None,
plot_histogram: bool = False,
plot_kwargs: Optional[Dict] = None,
allow_data_subset: bool = True,
) -> EvaluationResult:
"""
Falsify a given DAG using observational data.
Expand Down Expand Up @@ -563,14 +567,24 @@ def falsify_graph(
:param n_jobs: Number of jobs to use for parallel execution of (conditional) independence tests.
:param plot_histogram: Plot histogram of results from permutation baseline.
:param plot_kwargs: Additional plot arguments to be passed to plot_evaluation_results.
:param allow_data_subset: If True, performs the evaluation even if data is only available for a subset of nodes.
If False, raises an error if not all nodes have data available.
:return: EvaluationResult
"""
if not allow_data_subset and not set([str(node) for node in causal_graph.nodes]).issubset(
set([str(col) for col in data.columns])
):
raise ValueError(
"Did not find data for all nodes of the given graph! Make sure that the node names coincide "
"with the column names in the data."
)

n_jobs = config.default_n_jobs if n_jobs is None else n_jobs
show_progress_bar = config.show_progress_bars if show_progress_bar is None else show_progress_bar
p_values_memory = _PValuesMemory()

if n_permutations is None:
n_permutations = int(1 / significance_level) if not plot_histogram else -1
n_permutations = int(1 / significance_level)

if not plot_kwargs:
plot_kwargs = {}
Expand Down Expand Up @@ -624,10 +638,10 @@ def falsify_graph(
summary[m][FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS]
summary[m][FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS]
summary[m][FalsifyConst.F_PERM_VIOLATIONS] = [
perm[FalsifyConst.N_VIOLATIONS] / perm[FalsifyConst.N_TESTS] for perm in summary_perm[m]
perm[FalsifyConst.N_VIOLATIONS] / max(1, perm[FalsifyConst.N_TESTS]) for perm in summary_perm[m]
]
summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = (
summary[m][FalsifyConst.GIVEN_VIOLATIONS] / summary[m][FalsifyConst.N_TESTS]
summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = summary[m][FalsifyConst.GIVEN_VIOLATIONS] / max(
1, summary[m][FalsifyConst.N_TESTS]
)
summary[m][FalsifyConst.P_VALUE] = sum(
[
Expand All @@ -653,7 +667,7 @@ def falsify_graph(
significance_level=significance_level,
suggestions={m: summary_given[m] for m in summary_given if m not in validation_methods},
)
if plot_histogram:
if plot_histogram and result.can_evaluate:
plot_evaluation_results(result, **plot_kwargs)
return result

Expand Down
1 change: 1 addition & 0 deletions dowhy/gcm/model_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def evaluate_causal_model(
independence_test=config.independence_test_falsify,
conditional_independence_test=config.conditional_independence_test_falsify,
n_jobs=config.n_jobs,
allow_data_subset=False,
)

return evaluation_result
Expand Down
19 changes: 17 additions & 2 deletions tests/gcm/test_falsify.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.

from functools import partial

import networkx as nx
import numpy as np
import pandas as pd
import pytest
from flaky import flaky

from dowhy.datasets import generate_random_graph
from dowhy.gcm.falsify import (
FalsifyConst,
_PermuteNodes,
falsify_graph,
run_validations,
validate_cm,
validate_lmc,
Expand Down Expand Up @@ -360,3 +360,18 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol

assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 2


def test_given_data_of_only_a_subset_of_nodes_when_falsify_and_set_allow_data_subset_to_false_then_raise_error():
    # With allow_data_subset=False, falsify_graph must reject data that lacks
    # columns for some graph nodes (here, node "Z" has no data column).
    # Pin the expected message so an unrelated ValueError cannot make this pass.
    with pytest.raises(ValueError, match="Did not find data for all nodes"):
        falsify_graph(
            nx.DiGraph([("X", "Y"), ("Y", "Z")]),
            pd.DataFrame({"X": np.random.random(10), "Y": np.random.random(10)}),
            allow_data_subset=False,
        )


def test_given_no_data_when_falsify_then_does_not_raise_error_but_cannot_evaluate():
    # With allow_data_subset left at its default, missing data must not raise;
    # the returned result simply reports that it cannot be evaluated.
    graph = nx.DiGraph([("X", "Y"), ("Y", "Z")])

    result_without_data = falsify_graph(graph, pd.DataFrame({}))
    assert not result_without_data.can_evaluate

    result_with_partial_data = falsify_graph(graph, pd.DataFrame({"X": []}))
    assert not result_with_partial_data.can_evaluate