-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #431 from MannLabs/add-two-step-classifier
Add two step classifier
- Loading branch information
Showing
7 changed files
with
693 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import logging | ||
|
||
import numpy as np | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
from alphadia.fdrexperimental import Classifier | ||
|
||
# Module-scoped logger: records carry this module's name, so they can be
# filtered or traced per module while still propagating to the root handlers.
logger = logging.getLogger(__name__)
|
||
|
||
class LogisticRegressionClassifier(Classifier):
    """Logistic regression classifier operating on standardized features.

    Input features are passed through a ``StandardScaler`` before they reach
    the ``LogisticRegression`` model, both during fitting and prediction.
    """

    def __init__(self) -> None:
        """Binary classifier using a logistic regression model."""
        self.scaler = StandardScaler()
        self.model = LogisticRegression()
        self._fitted = False

    @property
    def fitted(self) -> bool:
        """Whether the classifier has been fitted or loaded from a fitted state."""
        return self._fitted

    def fit(self, x: np.ndarray, y: np.ndarray) -> None:
        """Fit the scaler and the classifier to the data.

        Parameters
        ----------
        x : np.array, dtype=float
            Training data of shape (n_samples, n_features).

        y : np.array, dtype=int
            Target values of shape (n_samples,).

        """
        # The scaler is fitted on the training data only; predict and
        # predict_proba reuse the statistics learned here.
        x_scaled = self.scaler.fit_transform(x)
        self.model.fit(x_scaled, y)
        self._fitted = True

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Predict the class of the data.

        Parameters
        ----------
        x : np.array, dtype=float
            Data of shape (n_samples, n_features).

        Returns
        -------
        y : np.array
            Predicted class labels of shape (n_samples,).

        """
        x_scaled = self.scaler.transform(x)
        return self.model.predict(x_scaled)

    def predict_proba(self, x: np.ndarray) -> np.ndarray:
        """Predict the class probabilities of the data.

        Parameters
        ----------
        x : np.array, dtype=float
            Data of shape (n_samples, n_features).

        Returns
        -------
        y : np.array, dtype=float
            Predicted class probabilities of shape (n_samples, n_classes).

        """
        x_scaled = self.scaler.transform(x)
        return self.model.predict_proba(x_scaled)

    def to_state_dict(self) -> dict:
        """Return the state of the classifier as a dictionary.

        Returns
        -------
        dict : dict
            Dictionary containing the state of the classifier. It always
            holds ``"_fitted"``; the scaler and model parameters are only
            included when the classifier has been fitted.

        """
        state_dict = {"_fitted": self._fitted}

        if self._fitted:
            state_dict.update(
                {
                    "scaler_mean": self.scaler.mean_,
                    "scaler_var": self.scaler.var_,
                    "scaler_scale": self.scaler.scale_,
                    "scaler_n_samples_seen": self.scaler.n_samples_seen_,
                    "model_coef": self.model.coef_,
                    "model_intercept": self.model.intercept_,
                    "model_classes": self.model.classes_,
                    # Duplicates "_fitted"; kept so the dictionary layout
                    # stays unchanged for existing consumers.
                    "is_fitted": self._fitted,
                }
            )

        return state_dict

    def from_state_dict(self, state_dict: dict) -> None:
        """Load the state of the classifier from a dictionary.

        The instance is only modified after every required entry has been
        read, so a missing key leaves the previous state untouched.

        Parameters
        ----------
        state_dict : dict
            Dictionary containing the state of the classifier, as produced
            by ``to_state_dict``.

        Raises
        ------
        KeyError
            If ``"_fitted"`` is missing, or if the state is marked as fitted
            and any of the scaler or model entries is missing.

        """
        fitted = state_dict["_fitted"]

        # Always start from fresh estimators so that loading an unfitted state
        # never leaves parameters from an earlier fit behind.
        scaler = StandardScaler()
        model = LogisticRegression()

        if fitted:
            scaler.mean_ = np.array(state_dict["scaler_mean"])
            scaler.var_ = np.array(state_dict["scaler_var"])
            scaler.scale_ = np.array(state_dict["scaler_scale"])
            scaler.n_samples_seen_ = np.array(state_dict["scaler_n_samples_seen"])

            model.coef_ = np.array(state_dict["model_coef"])
            model.intercept_ = np.array(state_dict["model_intercept"])
            model.classes_ = np.array(state_dict["model_classes"])

        # Commit everything at once, after all lookups have succeeded.
        self.scaler = scaler
        self.model = model
        self._fitted = fitted
Oops, something went wrong.