Skip to content

Commit

Permalink
Adjust format based on new formatter version
Browse files Browse the repository at this point in the history
The new black formatter version (24.3.0) requires new changes, which are addressed in this commit.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Apr 12, 2024
1 parent 74e51c3 commit bbf4c9b
Show file tree
Hide file tree
Showing 22 changed files with 59 additions and 92 deletions.
6 changes: 2 additions & 4 deletions dowhy/causal_estimators/causalml.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@


class _CausalmlEstimator(Protocol):
def estimate_ate(self, *args, **kwargs):
...
def estimate_ate(self, *args, **kwargs): ...

def fit_predict(self, *args, **kwargs):
...
def fit_predict(self, *args, **kwargs): ...


logger = logging.getLogger(__name__)
Expand Down
15 changes: 5 additions & 10 deletions dowhy/causal_estimators/econml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,15 @@


class _EconmlEstimator(Protocol):
def fit(self, *args, **kwargs):
...
def fit(self, *args, **kwargs): ...

def effect(self, *args, **kwargs):
...
def effect(self, *args, **kwargs): ...

def effect_interval(self, *args, **kwargs):
...
def effect_interval(self, *args, **kwargs): ...

def effect_inference(self, *args, **kwargs):
...
def effect_inference(self, *args, **kwargs): ...

def shap_values(self, *args, **kwargs):
...
def shap_values(self, *args, **kwargs): ...


class Econml(CausalEstimator):
Expand Down
1 change: 0 additions & 1 deletion dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class CausalGraph:

"""Class for creating and modifying the causal graph.
Accepts a networkx DiGraph, a :py:class:`ProbabilisticCausalModel <dowhy.gcm.ProbabilisticCausalModel`, a graph string (or a text file) in gml format (preferred) or dot format. Graphviz-like attributes can be set for edges and nodes. E.g. style="dashed" as an edge attribute ensures that the edge is drawn with a dashed line.
Expand Down
1 change: 0 additions & 1 deletion dowhy/causal_identifier/identified_estimand.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class IdentifiedEstimand:

"""Class for storing a causal estimand, typically as a result of the identification step."""

def __init__(
Expand Down
1 change: 1 addition & 0 deletions dowhy/causal_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Module containing the main model class for the dowhy package.
"""

import logging
import typing
import warnings
Expand Down
1 change: 0 additions & 1 deletion dowhy/causal_refuter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class SignificanceTestType(Enum):


class CausalRefuter:

"""Base class for different refutation methods.
Subclasses implement specific refutations methods.
Expand Down
1 change: 0 additions & 1 deletion dowhy/causal_refuters/add_unobserved_common_cause.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@


class AddUnobservedCommonCause(CausalRefuter):

"""Add an unobserved confounder for refutation.
AddUnobservedCommonCause class supports three methods:
Expand Down
1 change: 1 addition & 0 deletions dowhy/causal_refuters/overrule/ruleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Overlap in Observational Studies. In S. Chiappa & R. Calandra (Eds.), Proceedings of the Twenty Third International
Conference on Artificial Intelligence and Statistics (Vol. 108, pp. 788–798). PMLR. https://arxiv.org/abs/1907.04138
"""

from typing import Callable, Dict, List, Optional

import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions dowhy/gcm/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __str__(self):
% self._nodes[node]["model_performances"][0][2]
)

for (model, performance, metric_name) in self._nodes[node]["model_performances"]:
for model, performance, metric_name in self._nodes[node]["model_performances"]:
summary_strings.append("%s: %s" % (str(model()).replace("()", ""), str(performance)))

summary_strings.append(
Expand Down Expand Up @@ -344,7 +344,7 @@ def assign_causal_mechanisms(
+ ",N).",
)

for (model, performance, metric_name) in model_performances:
for model, performance, metric_name in model_performances:
auto_assignment_summary.add_model_performance(node, model, performance, metric_name)

return auto_assignment_summary
Expand Down
1 change: 1 addition & 0 deletions dowhy/gcm/falsify.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module provides functionality to falsify a user-given DAG given observed data."""

import warnings
from dataclasses import dataclass, field
from enum import Enum, auto
Expand Down
7 changes: 4 additions & 3 deletions dowhy/gcm/feature_relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
be blackbox prediction models, it is also possible to explain causal mechanisms with respect to the direct parents.
In these cases, it would be possible to incorporate the noise to represent the part of the generation process that
cannot be explained by the parents."""

from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -258,9 +259,9 @@ def single_sample_set_function(subset: np.ndarray) -> Union[np.ndarray, float]:
baseline_samples=baseline_samples,
baseline_feature_indices=np.arange(0, feature_samples.shape[1])[subset == 1],
return_averaged_results=False,
feature_perturbation="randomize_columns_jointly"
if randomize_features_jointly
else "randomize_columns_independently",
feature_perturbation=(
"randomize_columns_jointly" if randomize_features_jointly else "randomize_columns_independently"
),
max_batch_size=max_batch_size,
)

Expand Down
1 change: 1 addition & 0 deletions dowhy/gcm/independence_test/regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Regression based (conditional) independence test. Testing independence via regression, i.e. if a variable has
information about another variable, then they are dependent.
"""

from typing import Callable, List, Optional, Union

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions dowhy/gcm/influence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module provides functions to estimate causal influences."""

import logging
import warnings
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union, cast
Expand Down
2 changes: 1 addition & 1 deletion dowhy/gcm/shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def parallel_job(input_subset: Tuple[int], parallel_random_seed: int) -> Union[f
)

subset_to_result_map = {}
for (subset, result) in zip(evaluation_subsets, subset_results):
for subset, result in zip(evaluation_subsets, subset_results):
subset_to_result_map[subset] = result

return subset_to_result_map
Expand Down
1 change: 1 addition & 0 deletions dowhy/gcm/validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains a method to reject the causal graph and validate causal mechanisms such as post non-linear models."""

from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down
1 change: 1 addition & 0 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module defines the fundamental interfaces and functions related to causal graphs."""

import itertools
import logging
import re
Expand Down
1 change: 0 additions & 1 deletion dowhy/interpreters/confounder_distribution_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def discrete_dist_plot(labels, not_treated_counts, treated_counts, ax, title, va
ax.legend()

def interpret(self, data: pd.DataFrame):

"""
Shows distribution changes for confounding variables before and after applying inverse propensity weights.
"""
Expand Down
12 changes: 2 additions & 10 deletions dowhy/utils/dgps/linear_dgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,10 @@ def generate_data(self, sample_size):
def generation_process(self):
self.weights["confounder=>treatment"] = self.generate_weights((len(self.confounder), len(self.treatment)))
self.weights["confounder=>treatment"][0,] = (
self.weights["confounder=>treatment"][
0,
]
+ 100
self.weights["confounder=>treatment"][0,] + 100
) # increasing weight of the first confounder
self.weights["confounder=>outcome"] = self.generate_weights((len(self.confounder), len(self.outcome)))
self.weights["confounder=>outcome"][0,] = (
self.weights["confounder=>outcome"][
0,
]
+ 100
)
self.weights["confounder=>outcome"][0,] = self.weights["confounder=>outcome"][0,] + 100
self.weights["effect_modifier=>outcome"] = self.generate_weights((len(self.effect_modifier), len(self.outcome)))
self.weights["treatment=>outcome"] = self.generate_weights((len(self.treatment), len(self.outcome)))

Expand Down
4 changes: 2 additions & 2 deletions dowhy/utils/graphviz_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def plot_causal_graph_graphviz(
layout_prog = "dot"

max_strength = 0.0
for (source, target, strength) in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
for source, target, strength in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
if (source, target) not in causal_strengths:
causal_strengths[(source, target)] = strength
if causal_strengths[(source, target)] is not None:
Expand All @@ -46,7 +46,7 @@ def plot_causal_graph_graphviz(
else:
pygraphviz_graph.add_node(node)

for (source, target) in causal_graph.edges():
for source, target in causal_graph.edges():
causal_strength = causal_strengths[(source, target)]
color = colors[(source, target)]
if causal_strength is not None:
Expand Down
2 changes: 1 addition & 1 deletion dowhy/utils/networkx_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def plot_causal_graph_networkx(
colors = deepcopy(colors)

max_strength = 0.0
for (source, target, strength) in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
for source, target, strength in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
if (source, target) not in causal_strengths:
causal_strengths[(source, target)] = strength

Expand Down
Loading

0 comments on commit bbf4c9b

Please sign in to comment.