Skip to content

Commit

Permalink
Fix issue with auto assignment with imbalanced classes
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Nov 15, 2023
1 parent eb88735 commit b9ae10b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
26 changes: 20 additions & 6 deletions dowhy/gcm/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn import metrics
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import KFold, train_test_split
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

from dowhy.gcm import config
Expand Down Expand Up @@ -343,7 +343,7 @@ def find_best_model(
X: np.ndarray,
Y: np.ndarray,
metric: Optional[Callable[[np.ndarray, np.ndarray], float]] = None,
max_samples_per_split: int = 10000,
max_samples_per_split: int = 20000,
model_selection_splits: int = 5,
n_jobs: Optional[int] = None,
) -> Tuple[Callable[[], PredictionModel], List[Tuple[Callable[[], PredictionModel], float, str]]]:
Expand All @@ -370,16 +370,27 @@ def find_best_model(
labelBinarizer = MultiLabelBinarizer()
labelBinarizer.fit(Y)

kfolds = list(KFold(n_splits=model_selection_splits, shuffle=True).split(range(X.shape[0])))
if is_classification_problem:
if len(np.unique(Y)) == 1:
raise ValueError(
"The given target samples have only one class! To fit a classification model, there "
"should be at least two classes."
)
kfolds = list(StratifiedKFold(n_splits=model_selection_splits, shuffle=True).split(X, Y))
else:
kfolds = list(KFold(n_splits=model_selection_splits, shuffle=True).split(range(X.shape[0])))

def estimate_average_score(prediction_model_factory: Callable[[], PredictionModel], random_seed: int) -> float:
set_random_seed(random_seed)

average_result = 0
average_result = []

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConvergenceWarning)
for train_indices, test_indices in kfolds:
if is_classification_problem and len(np.unique(Y[train_indices[:max_samples_per_split]])) == 1:
continue

model_instance = prediction_model_factory()
model_instance.fit(X[train_indices[:max_samples_per_split]], Y[train_indices[:max_samples_per_split]])

Expand All @@ -389,9 +400,12 @@ def estimate_average_score(prediction_model_factory: Callable[[], PredictionMode
y_true = labelBinarizer.transform(y_true)
y_pred = labelBinarizer.transform(y_pred)

average_result += metric(y_true, y_pred)
average_result.append(metric(y_true, y_pred))

return average_result / model_selection_splits
if len(average_result) == 0:
return float("inf")
else:
return float(np.mean(average_result))

random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(prediction_model_factories))
average_metric_scores = Parallel(n_jobs=n_jobs)(
Expand Down
17 changes: 16 additions & 1 deletion tests/gcm/test_auto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import networkx as nx
import numpy as np
import pandas as pd
import pytest
from _pytest.python_api import approx
from flaky import flaky
from pytest import mark
Expand All @@ -9,7 +10,7 @@
from sklearn.naive_bayes import GaussianNB
from sklearn.pipeline import Pipeline

from dowhy.gcm import ProbabilisticCausalModel, draw_samples, fit
from dowhy.gcm import ProbabilisticCausalModel, StructuralCausalModel, draw_samples, fit
from dowhy.gcm.auto import AssignmentQuality, assign_causal_mechanisms, has_linear_relationship


Expand Down Expand Up @@ -431,3 +432,17 @@ def test_given_categorical_data_when_print_auto_summary_then_returns_expected_fo
"Based on the type of causal mechanism, the model with the lowest metric value represents the best choice."
in summary_string
)


def test_given_imbalanced_classes_when_auto_assign_mechanism_then_handles_as_expected():
X = np.random.normal(0, 1, 1000)
Y = np.array(["OneClass"] * 1000)

with pytest.raises(ValueError):
assign_causal_mechanisms(StructuralCausalModel(nx.DiGraph([("X", "Y")])), pd.DataFrame({"X": X, "Y": Y}))

# Having at least one sample from the second class should not raise an error.
X = np.append(X, 0)
Y = np.append(Y, "RareClass")

assign_causal_mechanisms(StructuralCausalModel(nx.DiGraph([("X", "Y")])), pd.DataFrame({"X": X, "Y": Y}))

0 comments on commit b9ae10b

Please sign in to comment.