-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Major code refactor to unify quasi experiment classes #381
Changes from 67 commits
df1989f
d36e289
0945fc4
d0d3bc3
f87577b
e82325d
4ccea2c
119f749
841cca4
e5a07d9
7b1e600
48cf9b5
5c2b103
b23b373
0435456
ac2389d
c6ae453
ff5122f
291dc47
4d10175
4442d5b
3757199
6c4e43c
ae7c405
a85ccfb
00c1290
5b5ccd2
c3df3eb
2fb344a
fd74658
41bf080
c77de98
ea2b859
b0b4539
b0dc8c6
0af9bfb
28f3b07
e84c199
b0eabff
1b26499
133ee3b
5c07e4d
224ec84
100784f
d6e058c
dab5824
4b8141d
47d4479
70e58f8
c080fa9
f915f77
6822c61
1cdf7c2
fbc4c94
fa640b8
4edc6d5
38b0e68
6a9214b
0cf4f46
60cfb2a
121fe46
cc62438
644cf6b
dede64a
02dacb2
3f9763a
01ce582
6f6fade
66f22aa
77b0b2b
9bc3d25
e0b0847
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you interested in the imports like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm unclear on this comment. See a response below #381 (comment) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The PyMC Labs Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright 2024 The PyMC Labs Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Base class for quasi experimental designs. | ||
""" | ||
|
||
from abc import abstractmethod | ||
|
||
from sklearn.base import RegressorMixin | ||
|
||
from causalpy.pymc_models import PyMCModel | ||
from causalpy.skl_models import create_causalpy_compatible_class | ||
|
||
|
||
class BaseExperiment: | ||
"""Base class for quasi experimental designs.""" | ||
wd60622 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
supports_bayes: bool | ||
supports_ols: bool | ||
|
||
def __init__(self, model=None): | ||
# Ensure we've made any provided Scikit Learn model (as identified as being type | ||
# RegressorMixin) compatible with CausalPy by appending our custom methods. | ||
if isinstance(model, RegressorMixin): | ||
model = create_causalpy_compatible_class(model) | ||
|
||
if model is not None: | ||
self.model = model | ||
|
||
if isinstance(self.model, PyMCModel) and not self.supports_bayes: | ||
raise ValueError("Bayesian models not supported.") | ||
|
||
if isinstance(self.model, RegressorMixin) and not self.supports_ols: | ||
raise ValueError("OLS models not supported.") | ||
|
||
if self.model is None: | ||
raise ValueError("model not set or passed.") | ||
|
||
@property | ||
def idata(self): | ||
"""Return the InferenceData object of the model. Only relevant for PyMC models.""" | ||
return self.model.idata | ||
|
||
def print_coefficients(self, round_to=None): | ||
"""Ask the model to print its coefficients.""" | ||
self.model.print_coefficients(self.labels, round_to) | ||
|
||
def plot(self, *args, **kwargs) -> tuple: | ||
"""Plot the model. | ||
|
||
Internally, this function dispatches to either `bayesian_plot` or `ols_plot` | ||
depending on the model type. | ||
""" | ||
if isinstance(self.model, PyMCModel): | ||
return self.bayesian_plot(*args, **kwargs) | ||
elif isinstance(self.model, RegressorMixin): | ||
return self.ols_plot(*args, **kwargs) | ||
else: | ||
raise ValueError("Unsupported model type") | ||
|
||
@abstractmethod | ||
def bayesian_plot(self, *args, **kwargs): | ||
"""Abstract method for plotting the model.""" | ||
raise NotImplementedError("bayesian_plot method not yet implemented") | ||
|
||
@abstractmethod | ||
def ols_plot(self, *args, **kwargs): | ||
"""Abstract method for plotting the model.""" | ||
raise NotImplementedError("ols_plot method not yet implemented") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - did a global find/replace