generated from maycuatroi/python-project-template
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fcb109c
commit 91ae5a5
Showing
5 changed files
with
84 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
from .accuracy import Accuracy | ||
from .auc import AUC | ||
from .base_metric import BaseMetric | ||
from .slope import Slope | ||
from .rmse import RMSE | ||
from .mae import MAE | ||
from .error_ave import ErrorAve | ||
from .r_squared import RSquared | ||
from .error_ave import ErrorAve | ||
from .error_std import ErrorStd | ||
from .accuracy import Accuracy | ||
from .precision import Precision | ||
from .f1_score import F1Score | ||
from .mae import MAE | ||
from .precision import Precision | ||
from .r_squared import RSquared | ||
from .recall import Recall | ||
from .rmse import RMSE | ||
from .roc_curve import RocCurve | ||
from .slope import Slope |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import matplotlib.pyplot as plt | ||
from sklearn import metrics | ||
|
||
from evo_science.entities.metrics.base_binary_classify_metric import ( | ||
BaseBinaryClassifyMetric, | ||
) | ||
|
||
|
||
class AUC(BaseBinaryClassifyMetric): | ||
name = "Area Under the Curve" | ||
|
||
def _on_init(self, threshold=0.5, plot=False, **kwargs): | ||
self.threshold = threshold | ||
self.plot = plot | ||
|
||
def _calculate_np(self, y_true, y_pred): | ||
auc = metrics.roc_auc_score(y_true, y_pred) | ||
if self.plot: | ||
fpr, tpr, _ = metrics.roc_curve(y_true, y_pred) | ||
plt.figure() | ||
plt.plot( | ||
fpr, | ||
tpr, | ||
color="darkorange", | ||
lw=2, | ||
label="ROC curve (area = %0.2f)" % auc, | ||
) | ||
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--") | ||
plt.xlim([0.0, 1.0]) | ||
plt.ylim([0.0, 1.05]) | ||
plt.xlabel("False Positive Rate") | ||
plt.ylabel("True Positive Rate") | ||
plt.title("Receiver Operating Characteristic Curve") | ||
plt.legend(loc="lower right") | ||
plt.show() | ||
return auc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from sklearn.metrics import roc_curve | ||
|
||
from evo_science.entities.metrics.base_binary_classify_metric import ( | ||
BaseBinaryClassifyMetric, | ||
) | ||
|
||
|
||
class RocCurve(BaseBinaryClassifyMetric): | ||
name = "Roc Curve" | ||
|
||
def _on_init(self, threshold=0.5, plot=False, **kwargs): | ||
self.threshold = threshold | ||
self.plot = plot | ||
|
||
def _calculate_np(self, y_true: np.array, y_pred: np.array): | ||
fpr, tpr, thresholds = roc_curve(y_true, y_pred) | ||
if self.plot: | ||
plt.plot(fpr, tpr) | ||
plt.xlabel("False Positive Rate") | ||
plt.ylabel("True Positive Rate") | ||
plt.title("ROC Curve") | ||
plt.show() | ||
|
||
return fpr.mean(), tpr.mean(), thresholds.mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters