Skip to content

Commit 21fba9f

Browse files
committed
Fix issue with linear regressor with fixed parameters
Previously, when a linear regressor was created with fixed parameters, those parameters were overridden as soon as the model was fit to data. Now, the parameters remain fixed. Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
1 parent b9ae10b commit 21fba9f

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

dowhy/gcm/ml/regression.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,25 @@ def __str__(self):
6161
return str(self._sklearn_mdl)
6262

6363

64+
class LinearRegressionWithFixedParameter(PredictionModel):
    """Linear prediction model whose parameters are given up front and never learned.

    Predictions are computed as ``np.dot(X, coefficients) + intercept``. Calling
    :meth:`fit` leaves ``coefficients`` and ``intercept`` untouched.
    """

    def __init__(self, coefficients: np.ndarray, intercept: float):
        """Store the fixed parameters.

        :param coefficients: Weights that are multiplied with the (2D-shaped) input via ``np.dot``.
        :param intercept: Constant offset added to every prediction.
        """
        self.coefficients = coefficients
        self.intercept = intercept

    def fit(self, X: np.ndarray, Y: np.ndarray) -> None:
        """Do nothing: the parameters are fixed, so the training data is ignored."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the linear predictions for ``X`` as a column vector.

        :param X: Input samples; reshaped with ``shape_into_2d`` before the dot product.
        :return: Array of shape ``(-1, 1)`` holding ``np.dot(X, coefficients) + intercept``.
        """
        linear_term = np.dot(shape_into_2d(X), self.coefficients)
        return (linear_term + self.intercept).reshape(-1, 1)

    def clone(self) -> "LinearRegressionWithFixedParameter":
        """Return a new model with the same parameters.

        The coefficients are copied so that modifying the clone's array in place
        cannot silently change the parameters of this model (and vice versa).
        """
        return LinearRegressionWithFixedParameter(
            coefficients=np.array(self.coefficients), intercept=self.intercept
        )
77+
78+
6479
def create_linear_regressor_with_given_parameters(
    coefficients: np.ndarray, intercept: float = 0, **kwargs
) -> LinearRegressionWithFixedParameter:
    """Create a linear regressor whose parameters stay fixed, even when ``fit`` is called.

    :param coefficients: Coefficients of the linear model; any array-like accepted by ``np.array``.
    :param intercept: Constant offset of the linear model. Defaults to 0.
    :param kwargs: Accepted so that existing callers keep working; the values are
                   not used by the fixed-parameter model.
    :return: A ``LinearRegressionWithFixedParameter`` holding the given parameters.
    """
    # np.array converts plain sequences (e.g. lists) and copies the input, so the
    # model does not share its coefficients with the caller's object.
    return LinearRegressionWithFixedParameter(np.array(coefficients), intercept)
7283

7384

7485
def create_linear_regressor(**kwargs) -> SklearnRegressionModel:

tests/gcm/test_ml.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import numpy as np
2+
import pytest
23
from flaky import flaky
34
from pytest import approx
45
from sklearn.linear_model import LogisticRegression
56

6-
from dowhy.gcm.ml import SklearnClassificationModel, create_linear_regressor, create_logistic_regression_classifier
7+
from dowhy.gcm.ml import (
8+
SklearnClassificationModel,
9+
create_linear_regressor,
10+
create_linear_regressor_with_given_parameters,
11+
create_logistic_regression_classifier,
12+
)
713

814

915
@flaky(max_runs=5)
@@ -78,3 +84,23 @@ def test_when_cloning_sklearn_classification_model_then_returns_a_cloned_object(
7884
assert isinstance(cloned_mdl.sklearn_model, LogisticRegression)
7985
assert mdl != cloned_mdl
8086
assert cloned_mdl.sklearn_model != logistic_regression_model
87+
88+
89+
def test_when_using_linear_regressor_with_given_parameters_then_fit_does_not_override_parameters():
    """Fitting a regressor created with given parameters must leave those parameters unchanged."""
    mdl = create_linear_regressor_with_given_parameters([1, 2, 3], 4)

    # Parameters directly after construction.
    assert mdl.coefficients == pytest.approx([1, 2, 3])
    assert mdl.intercept == 4

    mdl.fit(np.random.normal(0, 1, (100, 3)), np.arange(100))

    # Same parameters after fitting on unrelated data.
    assert mdl.coefficients == pytest.approx([1, 2, 3])
    assert mdl.intercept == 4
99+
100+
101+
def test_when_predict_with_linear_regressor_with_given_parameters_then_returns_expected_results():
    """Predictions must follow ``2 * x + 4`` for coefficient 2 and intercept 4."""
    mdl = create_linear_regressor_with_given_parameters([2], 4)

    assert mdl.predict(np.array([0])) == approx(4)
    assert mdl.predict(np.array([1])) == approx(6)
    assert mdl.predict(np.array([0, 1, 2, 3])) == approx([4, 6, 8, 10])

0 commit comments

Comments
 (0)