Skip to content

Commit

Permalink
Fix issue with linear regressor with fixed parameters
Browse files Browse the repository at this point in the history
Previously, when a linear regressor was created with fixed parameters, these parameters were overridden when the model was fit to data. Now, the parameters remain fixed.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Nov 9, 2023
1 parent 041d5ab commit b041898
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
23 changes: 17 additions & 6 deletions dowhy/gcm/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,25 @@ def __str__(self):
return str(self._sklearn_mdl)


class LinearRegressionWithFixedParameter(PredictionModel):
    """Linear prediction model whose coefficients and intercept are never fitted.

    The parameters are supplied at construction time; ``fit`` is a no-op, so
    they stay exactly as given.
    """

    def __init__(self, coefficients: np.ndarray, intercept: float):
        """Store the fixed parameters.

        :param coefficients: Weights applied to the input features.
        :param intercept: Constant offset added to every prediction.
        """
        # Take a private copy so that later mutation of the caller's array (or
        # of a clone's array) cannot silently change the "fixed" parameters.
        self.coefficients = np.array(coefficients)
        self.intercept = intercept

    def fit(self, X: np.ndarray, Y: np.ndarray) -> None:
        """Do nothing: the parameters are fixed and must not be re-estimated."""
        pass

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return ``X . coefficients + intercept`` as a column of shape (-1, 1).

        ``X`` is passed through ``shape_into_2d`` before the dot product.
        """
        return (np.dot(shape_into_2d(X), self.coefficients) + self.intercept).reshape(-1, 1)

    def clone(self):
        """Return a new model with the same parameters.

        The constructor copies the coefficient array, so the clone does not
        share it with this instance.
        """
        return LinearRegressionWithFixedParameter(coefficients=self.coefficients, intercept=self.intercept)


def create_linear_regressor_with_given_parameters(
    coefficients: np.ndarray, intercept: float = 0, **kwargs
) -> LinearRegressionWithFixedParameter:
    """Create a linear regressor whose parameters stay fixed when ``fit`` is called.

    :param coefficients: Weights of the linear model; converted with ``np.array``,
        so a plain sequence such as ``[1, 2, 3]`` is accepted.
    :param intercept: Constant offset of the linear model. Defaults to 0.
    :param kwargs: Still accepted so existing callers keep working, but not
        used by this implementation.
    :return: A LinearRegressionWithFixedParameter holding the given parameters.
    """
    return LinearRegressionWithFixedParameter(np.array(coefficients), intercept)


def create_linear_regressor(**kwargs) -> SklearnRegressionModel:
Expand Down
28 changes: 27 additions & 1 deletion tests/gcm/test_ml.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import numpy as np
import pytest
from flaky import flaky
from pytest import approx
from sklearn.linear_model import LogisticRegression

from dowhy.gcm.ml import SklearnClassificationModel, create_linear_regressor, create_logistic_regression_classifier
from dowhy.gcm.ml import (
SklearnClassificationModel,
create_linear_regressor,
create_linear_regressor_with_given_parameters,
create_logistic_regression_classifier,
)


@flaky(max_runs=5)
Expand Down Expand Up @@ -78,3 +84,23 @@ def test_when_cloning_sklearn_classification_model_then_returns_a_cloned_object(
assert isinstance(cloned_mdl.sklearn_model, LogisticRegression)
assert mdl != cloned_mdl
assert cloned_mdl.sklearn_model != logistic_regression_model


def test_when_using_linear_regressor_with_given_parameters_then_fit_does_not_override_parameters():
    """Fitting must leave the given coefficients and intercept untouched."""
    model = create_linear_regressor_with_given_parameters([1, 2, 3], 4)

    def assert_parameters_are_as_given():
        # Same two checks are made before and after fitting.
        assert model.coefficients == pytest.approx([1, 2, 3])
        assert model.intercept == 4

    assert_parameters_are_as_given()

    features = np.random.normal(0, 1, (100, 3))
    targets = np.arange(100)
    model.fit(features, targets)

    assert_parameters_are_as_given()


def test_when_predict_with_linear_regressor_with_given_parameters_then_returns_expected_results():
    """Predictions must follow y = 2 * x + 4 for the given fixed parameters."""
    model = create_linear_regressor_with_given_parameters([2], 4)

    # (input, expected prediction) pairs, checked in order.
    cases = (
        (np.array([0]), approx(4)),
        (np.array([1]), approx(6)),
        (np.array([0, 1, 2, 3]), approx([4, 6, 8, 10])),
    )

    for inputs, expected in cases:
        assert model.predict(inputs) == expected

0 comments on commit b041898

Please sign in to comment.