From b04189861c2d61df339ce257c48c8057d4332712 Mon Sep 17 00:00:00 2001
From: Patrick Bloebaum
Date: Thu, 9 Nov 2023 15:42:20 -0800
Subject: [PATCH] Fix issue with linear regressor with fixed parameters

Before, when creating a linear regressor with fixed parameters, these
parameters were overridden when the model was fit to data. Now, the
parameters remain fixed.

Signed-off-by: Patrick Bloebaum
---
 dowhy/gcm/ml/regression.py | 23 +++++++++++++++++------
 tests/gcm/test_ml.py       | 28 +++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/dowhy/gcm/ml/regression.py b/dowhy/gcm/ml/regression.py
index 7e9d1a64d6..2b5c96b216 100644
--- a/dowhy/gcm/ml/regression.py
+++ b/dowhy/gcm/ml/regression.py
@@ -61,14 +61,25 @@ def __str__(self):
         return str(self._sklearn_mdl)
 
 
+class LinearRegressionWithFixedParameter(PredictionModel):
+    def __init__(self, coefficients: np.ndarray, intercept: float):
+        self.coefficients = coefficients
+        self.intercept = intercept
+
+    def fit(self, X: np.ndarray, Y: np.ndarray) -> None:
+        pass
+
+    def predict(self, X: np.ndarray) -> np.ndarray:
+        return (np.dot(shape_into_2d(X), self.coefficients) + self.intercept).reshape(-1, 1)
+
+    def clone(self):
+        return LinearRegressionWithFixedParameter(coefficients=self.coefficients, intercept=self.intercept)
+
+
 def create_linear_regressor_with_given_parameters(
     coefficients: np.ndarray, intercept: float = 0, **kwargs
-) -> SklearnRegressionModel:
-    linear_model = LinearRegression(**kwargs)
-    linear_model.coef_ = coefficients
-    linear_model.intercept_ = intercept
-
-    return SklearnRegressionModel(linear_model)
+) -> LinearRegressionWithFixedParameter:
+    return LinearRegressionWithFixedParameter(np.array(coefficients), intercept)
 
 
 def create_linear_regressor(**kwargs) -> SklearnRegressionModel:
diff --git a/tests/gcm/test_ml.py b/tests/gcm/test_ml.py
index 9f927dd227..e3b856b132 100644
--- a/tests/gcm/test_ml.py
+++ b/tests/gcm/test_ml.py
@@ -1,9 +1,15 @@
 import numpy as np
+import pytest
 from flaky import flaky
 from pytest import approx
 from sklearn.linear_model import LogisticRegression
 
-from dowhy.gcm.ml import SklearnClassificationModel, create_linear_regressor, create_logistic_regression_classifier
+from dowhy.gcm.ml import (
+    SklearnClassificationModel,
+    create_linear_regressor,
+    create_linear_regressor_with_given_parameters,
+    create_logistic_regression_classifier,
+)
 
 
 @flaky(max_runs=5)
@@ -78,3 +84,23 @@ def test_when_cloning_sklearn_classification_model_then_returns_a_cloned_object(
     assert isinstance(cloned_mdl.sklearn_model, LogisticRegression)
     assert mdl != cloned_mdl
     assert cloned_mdl.sklearn_model != logistic_regression_model
+
+
+def test_when_using_linear_regressor_with_given_parameters_then_fit_does_not_override_parameters():
+    mdl = create_linear_regressor_with_given_parameters([1, 2, 3], 4)
+
+    assert mdl.coefficients == pytest.approx([1, 2, 3])
+    assert mdl.intercept == 4
+
+    mdl.fit(np.random.normal(0, 1, (100, 3)), np.arange(100))
+
+    assert mdl.coefficients == pytest.approx([1, 2, 3])
+    assert mdl.intercept == 4
+
+
+def test_when_predict_with_linear_regressor_with_given_parameters_then_returns_expected_results():
+    mdl = create_linear_regressor_with_given_parameters([2], 4)
+
+    assert mdl.predict(np.array([0])) == approx(4)
+    assert mdl.predict(np.array([1])) == approx(6)
+    assert mdl.predict(np.array([0, 1, 2, 3])) == approx([4, 6, 8, 10])
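
Usage sketch (not part of the patch itself): the snippet below illustrates the fixed behavior, assuming this patch is applied and `dowhy` and `numpy` are installed; the coefficient values and random data are purely illustrative.

```python
import numpy as np

from dowhy.gcm.ml import create_linear_regressor_with_given_parameters

# Model with user-specified parameters: y = 2 * x1 - 1 * x2 + 0.5
mdl = create_linear_regressor_with_given_parameters(coefficients=np.array([2.0, -1.0]), intercept=0.5)

# fit() is a no-op for this model, so the given parameters are not overridden.
mdl.fit(np.random.normal(0, 1, (100, 2)), np.random.normal(0, 1, 100))
print(mdl.coefficients, mdl.intercept)  # still [ 2. -1.] and 0.5

# Predictions use the fixed parameters and are returned with shape (n, 1).
print(mdl.predict(np.array([[1.0, 1.0]])))  # [[1.5]]
```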