|
1 | 1 | import numpy as np
|
| 2 | +import pytest |
2 | 3 | from flaky import flaky
|
3 | 4 | from pytest import approx
|
4 | 5 | from sklearn.linear_model import LogisticRegression
|
5 | 6 |
|
6 |
| -from dowhy.gcm.ml import SklearnClassificationModel, create_linear_regressor, create_logistic_regression_classifier |
| 7 | +from dowhy.gcm.ml import ( |
| 8 | + SklearnClassificationModel, |
| 9 | + create_linear_regressor, |
| 10 | + create_linear_regressor_with_given_parameters, |
| 11 | + create_logistic_regression_classifier, |
| 12 | +) |
7 | 13 |
|
8 | 14 |
|
9 | 15 | @flaky(max_runs=5)
|
@@ -78,3 +84,23 @@ def test_when_cloning_sklearn_classification_model_then_returns_a_cloned_object(
|
78 | 84 | assert isinstance(cloned_mdl.sklearn_model, LogisticRegression)
|
79 | 85 | assert mdl != cloned_mdl
|
80 | 86 | assert cloned_mdl.sklearn_model != logistic_regression_model
|
| 87 | + |
| 88 | + |
| 89 | +def test_when_using_linear_regressor_with_given_parameters_then_fit_does_not_override_parameters(): |
| 90 | + mdl = create_linear_regressor_with_given_parameters([1, 2, 3], 4) |
| 91 | + |
| 92 | + assert mdl.coefficients == pytest.approx([1, 2, 3]) |
| 93 | + assert mdl.intercept == 4 |
| 94 | + |
| 95 | + mdl.fit(np.random.normal(0, 1, (100, 3)), np.arange(100)) |
| 96 | + |
| 97 | + assert mdl.coefficients == pytest.approx([1, 2, 3]) |
| 98 | + assert mdl.intercept == 4 |
| 99 | + |
| 100 | + |
| 101 | +def test_when_predict_with_linear_regressor_with_given_parameters_then_returns_expected_results(): |
| 102 | + mdl = create_linear_regressor_with_given_parameters([2], 4) |
| 103 | + |
| 104 | + assert mdl.predict(np.array([0])) == approx(4) |
| 105 | + assert mdl.predict(np.array([1])) == approx(6) |
| 106 | + assert mdl.predict(np.array([0, 1, 2, 3])) == approx([4, 6, 8, 10]) |
0 commit comments