Skip to content

Commit

Permalink
Fix bug in networkx plot function with 0 error strenghts
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Dec 11, 2023
1 parent 4f31734 commit 72986a8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dowhy/utils/networkx_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def plot_causal_graph_networkx(
if (source, target) not in causal_strengths:
causal_strengths[(source, target)] = strength

if strength is not None:
if causal_strengths[(source, target)] is not None:
max_strength = max(max_strength, abs(causal_strengths[(source, target)]))

if (source, target) not in colors:
Expand Down
5 changes: 5 additions & 0 deletions dowhy/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ def bar_plot(


def _calc_arrow_width(strength: float, max_strength: float):
if max_strength == 0:
return 4.1
elif max_strength < 0:
raise ValueError("Got a negative strength! The strength needs to be positive.")

return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)


Expand Down
6 changes: 6 additions & 0 deletions tests/utils/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import networkx as nx
import numpy as np
import pandas as pd
import pytest
from _pytest.python_api import approx

from dowhy.utils import plot, plot_adjacency_matrix
from dowhy.utils.networkx_plotting import plot_causal_graph_networkx
from dowhy.utils.plotting import _calc_arrow_width, bar_plot


Expand Down Expand Up @@ -48,6 +50,10 @@ def test_calc_arrow_width():
assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01)
assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01)
assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01)
assert _calc_arrow_width(100, max_strength=0) == 4.1

with pytest.raises(ValueError):
_calc_arrow_width(100, max_strength=-1)


def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error():
Expand Down

0 comments on commit 72986a8

Please sign in to comment.