Skip to content

Commit

Permalink
Change default parameter of SVC model in the GCM module
Browse files Browse the repository at this point in the history
Before, the Support Vector Classifier did not produce probabilities, which are required for different algorithms in the GCM module. This changes the 'probability' parameter to True.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Nov 21, 2023
1 parent c88cc83 commit bd23532
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dowhy/gcm/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def create_ada_boost_classifier(**kwargs) -> SklearnClassificationModel:


def create_support_vector_classifier(**kwargs) -> SklearnClassificationModel:
return SklearnClassificationModel(SVC(**kwargs))
return SklearnClassificationModel(SVC(**kwargs, probability=True))


def create_knn_classifier(**kwargs) -> SklearnClassificationModel:
Expand Down
7 changes: 7 additions & 0 deletions tests/gcm/ml/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from flaky import flaky

from dowhy.gcm.ml import create_hist_gradient_boost_classifier, create_polynom_logistic_regression_classifier
from dowhy.gcm.ml.classification import create_support_vector_classifier


@flaky(max_runs=3)
Expand Down Expand Up @@ -58,3 +59,9 @@ def _generate_data():
mdl.fit(X_training, Y_training)

assert np.sum(mdl.predict(X_test).reshape(-1) == Y_test) > 950


def test_given_svc_model_then_supports_predict_probabilities():
mdl = create_support_vector_classifier()
mdl.fit(np.random.normal(0, 1, 100), np.random.choice(2, 100).astype(str))
mdl.predict_probabilities(np.random.normal(0, 1, 10))

0 comments on commit bd23532

Please sign in to comment.