From 72986a859db309b1afef92a542955569669bd8ea Mon Sep 17 00:00:00 2001 From: Patrick Bloebaum Date: Fri, 8 Dec 2023 12:00:35 -0800 Subject: [PATCH] Fix bug in networkx plot function with 0 error strenghts Signed-off-by: Patrick Bloebaum --- dowhy/utils/networkx_plotting.py | 2 +- dowhy/utils/plotting.py | 5 +++++ tests/utils/test_plotting.py | 6 ++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/dowhy/utils/networkx_plotting.py b/dowhy/utils/networkx_plotting.py index efafa45472..06f97c89d9 100644 --- a/dowhy/utils/networkx_plotting.py +++ b/dowhy/utils/networkx_plotting.py @@ -32,7 +32,7 @@ def plot_causal_graph_networkx( if (source, target) not in causal_strengths: causal_strengths[(source, target)] = strength - if strength is not None: + if causal_strengths[(source, target)] is not None: max_strength = max(max_strength, abs(causal_strengths[(source, target)])) if (source, target) not in colors: diff --git a/dowhy/utils/plotting.py b/dowhy/utils/plotting.py index 93d37eaf79..2f67001cc8 100644 --- a/dowhy/utils/plotting.py +++ b/dowhy/utils/plotting.py @@ -176,6 +176,11 @@ def bar_plot( def _calc_arrow_width(strength: float, max_strength: float): + if max_strength == 0: + return 4.1 + elif max_strength < 0: + raise ValueError("Got a negative strength! The strength needs to be positive.") + return 0.1 + 4.0 * float(abs(strength)) / float(max_strength) diff --git a/tests/utils/test_plotting.py b/tests/utils/test_plotting.py index 880e34b56c..ccc6d75d2f 100644 --- a/tests/utils/test_plotting.py +++ b/tests/utils/test_plotting.py @@ -1,9 +1,11 @@ import networkx as nx import numpy as np import pandas as pd +import pytest from _pytest.python_api import approx from dowhy.utils import plot, plot_adjacency_matrix +from dowhy.utils.networkx_plotting import plot_causal_graph_networkx from dowhy.utils.plotting import _calc_arrow_width, bar_plot @@ -48,6 +50,10 @@ def test_calc_arrow_width(): assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01) assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01) assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01) + assert _calc_arrow_width(100, max_strength=0) == 4.1 + + with pytest.raises(ValueError): + _calc_arrow_width(100, max_strength=-1) def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error():