diff --git a/dowhy/gcm/ml/classification.py b/dowhy/gcm/ml/classification.py index d4eaac39d0..d3b82413a8 100644 --- a/dowhy/gcm/ml/classification.py +++ b/dowhy/gcm/ml/classification.py @@ -76,7 +76,7 @@ def create_ada_boost_classifier(**kwargs) -> SklearnClassificationModel: def create_support_vector_classifier(**kwargs) -> SklearnClassificationModel: - return SklearnClassificationModel(SVC(**kwargs)) + return SklearnClassificationModel(SVC(**kwargs, probability=True)) def create_knn_classifier(**kwargs) -> SklearnClassificationModel: diff --git a/tests/gcm/ml/test_classification.py b/tests/gcm/ml/test_classification.py index 8a1fc5207b..7208512a1c 100644 --- a/tests/gcm/ml/test_classification.py +++ b/tests/gcm/ml/test_classification.py @@ -2,6 +2,7 @@ from flaky import flaky from dowhy.gcm.ml import create_hist_gradient_boost_classifier, create_polynom_logistic_regression_classifier +from dowhy.gcm.ml.classification import create_support_vector_classifier @flaky(max_runs=3) @@ -58,3 +59,9 @@ def _generate_data(): mdl.fit(X_training, Y_training) assert np.sum(mdl.predict(X_test).reshape(-1) == Y_test) > 950 + + +def test_given_svc_model_then_supports_predict_probabilities(): + mdl = create_support_vector_classifier() + mdl.fit(np.random.normal(0, 1, 100), np.random.choice(2, 100).astype(str)) + mdl.predict_probabilities(np.random.normal(0, 1, 10))