Skip to content

Commit

Permalink
- Add metrics AUC, Roc curve
Browse files Browse the repository at this point in the history
  • Loading branch information
maycuatroi committed Mar 21, 2024
1 parent fcb109c commit 91ae5a5
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 7 deletions.
15 changes: 9 additions & 6 deletions evo_science/entities/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from .accuracy import Accuracy
from .auc import AUC
from .base_metric import BaseMetric
from .slope import Slope
from .rmse import RMSE
from .mae import MAE
from .error_ave import ErrorAve
from .r_squared import RSquared
from .error_ave import ErrorAve
from .error_std import ErrorStd
from .accuracy import Accuracy
from .precision import Precision
from .f1_score import F1Score
from .mae import MAE
from .precision import Precision
from .r_squared import RSquared
from .recall import Recall
from .rmse import RMSE
from .roc_curve import RocCurve
from .slope import Slope
36 changes: 36 additions & 0 deletions evo_science/entities/metrics/auc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
from sklearn import metrics

from evo_science.entities.metrics.base_binary_classify_metric import (
BaseBinaryClassifyMetric,
)


class AUC(BaseBinaryClassifyMetric):
name = "Area Under the Curve"

def _on_init(self, threshold=0.5, plot=False, **kwargs):
self.threshold = threshold
self.plot = plot

def _calculate_np(self, y_true, y_pred):
auc = metrics.roc_auc_score(y_true, y_pred)
if self.plot:
fpr, tpr, _ = metrics.roc_curve(y_true, y_pred)
plt.figure()
plt.plot(
fpr,
tpr,
color="darkorange",
lw=2,
label="ROC curve (area = %0.2f)" % auc,
)
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic Curve")
plt.legend(loc="lower right")
plt.show()
return auc
1 change: 1 addition & 0 deletions evo_science/entities/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


class Recall(BaseBinaryClassifyMetric):
name = "Recall"

def _calculate_np(self, y_true: np.array, y_pred: np.array):
y_pred = self._binary_threshold(y_pred)
Expand Down
26 changes: 26 additions & 0 deletions evo_science/entities/metrics/roc_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve

from evo_science.entities.metrics.base_binary_classify_metric import (
BaseBinaryClassifyMetric,
)


class RocCurve(BaseBinaryClassifyMetric):
name = "Roc Curve"

def _on_init(self, threshold=0.5, plot=False, **kwargs):
self.threshold = threshold
self.plot = plot

def _calculate_np(self, y_true: np.array, y_pred: np.array):
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
if self.plot:
plt.plot(fpr, tpr)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.show()

return fpr.mean(), tpr.mean(), thresholds.mean()
13 changes: 12 additions & 1 deletion examples/titanic_survival_prediction/linear_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,18 @@ def example_lr_model():

model.fit(x=x, y=y)
model.evaluate(
x=x, y=y, metrics=[Slope, ErrorStd, Accuracy(threshold=0.5), Precision, F1Score]
x=x,
y=y,
metrics=[
Slope,
ErrorStd,
Accuracy(threshold=0.5),
Precision,
Recall,
F1Score,
RocCurve,
AUC,
],
)
model.calculate_coefficients(x=x)

Expand Down

0 comments on commit 91ae5a5

Please sign in to comment.