From bd23532342e917f538ff67f9a5837b05ef021248 Mon Sep 17 00:00:00 2001 From: Patrick Bloebaum Date: Tue, 17 Oct 2023 14:53:49 -0700 Subject: [PATCH] Change default parameter of SVC model in the GCM module Before, the Support Vector Classifier did not produce probabilities, which are required for different algorithms in the GCM module. This changes the 'probability' parameter to True. Signed-off-by: Patrick Bloebaum --- dowhy/gcm/ml/classification.py | 2 +- tests/gcm/ml/test_classification.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dowhy/gcm/ml/classification.py b/dowhy/gcm/ml/classification.py index d4eaac39d0..d3b82413a8 100644 --- a/dowhy/gcm/ml/classification.py +++ b/dowhy/gcm/ml/classification.py @@ -76,7 +76,7 @@ def create_ada_boost_classifier(**kwargs) -> SklearnClassificationModel: def create_support_vector_classifier(**kwargs) -> SklearnClassificationModel: - return SklearnClassificationModel(SVC(**kwargs)) + return SklearnClassificationModel(SVC(**kwargs, probability=True)) def create_knn_classifier(**kwargs) -> SklearnClassificationModel: diff --git a/tests/gcm/ml/test_classification.py b/tests/gcm/ml/test_classification.py index 8a1fc5207b..7208512a1c 100644 --- a/tests/gcm/ml/test_classification.py +++ b/tests/gcm/ml/test_classification.py @@ -2,6 +2,7 @@ from flaky import flaky from dowhy.gcm.ml import create_hist_gradient_boost_classifier, create_polynom_logistic_regression_classifier +from dowhy.gcm.ml.classification import create_support_vector_classifier @flaky(max_runs=3) @@ -58,3 +59,9 @@ def _generate_data(): mdl.fit(X_training, Y_training) assert np.sum(mdl.predict(X_test).reshape(-1) == Y_test) > 950 + + +def test_given_svc_model_then_supports_predict_probabilities(): + mdl = create_support_vector_classifier() + mdl.fit(np.random.normal(0, 1, 100), np.random.choice(2, 100).astype(str)) + mdl.predict_probabilities(np.random.normal(0, 1, 10))