Skip to content

Commit 103ee83

Browse files
committed
Change default parameter of SVC model in the GCM module
Before, the Support Vector Classifier did not produce probabilities, which are required for different algorithms in the GCM module. This changes the 'probability' parameter to True. Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
1 parent b2e75a7 commit 103ee83

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

dowhy/gcm/ml/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def create_ada_boost_classifier(**kwargs) -> SklearnClassificationModel:
7676

7777

7878
def create_support_vector_classifier(**kwargs) -> SklearnClassificationModel:
79-
return SklearnClassificationModel(SVC(**kwargs))
79+
return SklearnClassificationModel(SVC(**kwargs, probability=True))
8080

8181

8282
def create_knn_classifier(**kwargs) -> SklearnClassificationModel:

tests/gcm/ml/test_classification.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from flaky import flaky
33

44
from dowhy.gcm.ml import create_hist_gradient_boost_classifier, create_polynom_logistic_regression_classifier
5+
from dowhy.gcm.ml.classification import create_support_vector_classifier
56

67

78
@flaky(max_runs=3)
@@ -58,3 +59,9 @@ def _generate_data():
5859
mdl.fit(X_training, Y_training)
5960

6061
assert np.sum(mdl.predict(X_test).reshape(-1) == Y_test) > 950
62+
63+
64+
def test_given_svc_model_then_supports_predict_probabilities():
65+
mdl = create_support_vector_classifier()
66+
mdl.fit(np.random.normal(0, 1, 100), np.random.choice(2, 100).astype(str))
67+
mdl.predict_probabilities(np.random.normal(0, 1, 10))

0 commit comments

Comments
 (0)