Skip to content

Commit 36bdc47

Browse files
committed
Small refactor to comparison plots for easy inheritance in premium package
Small refactor to comparison plots for easy inheritance in premium package
1 parent e8a3bbf commit 36bdc47

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

nannyml/plots/blueprints/comparisons.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -633,10 +633,11 @@ def render_metric_display_name(metric_display_name: Union[str, Tuple]):
633633
class ResultCompareMixin:
634634
def compare(self, other: Result):
635635
return ResultComparison(
636-
self, other, title=self._get_title(other), plot_kwargs=_get_plot_kwargs(self, other) # type: ignore
636+
self, other, title=self.get_title(other), plot_kwargs=_get_plot_kwargs(self, other) # type: ignore
637637
)
638638

639-
def _get_title(self, other: Result):
639+
@property
640+
def titles(self) -> Dict[type, str]:
640641
from nannyml.data_quality.missing.result import Result as MissingValueResult
641642
from nannyml.data_quality.unseen.result import Result as UnseenValuesResult
642643
from nannyml.drift.multivariate.data_reconstruction import Result as DataReconstructionDriftResult
@@ -649,7 +650,7 @@ def _get_title(self, other: Result):
649650
from nannyml.stats.std import Result as StatsStdResult
650651
from nannyml.stats.sum import Result as StatsSumResult
651652

652-
_result_title_names: Dict[type, Any] = {
653+
_titles: Dict[type, Any] = {
653654
UnivariateDriftResult: "Univariate drift",
654655
DataReconstructionDriftResult: "Multivariate drift",
655656
RealizedPerformanceResult: "Realized performance",
@@ -663,7 +664,10 @@ def _get_title(self, other: Result):
663664
StatsSumResult: "Statistics, Sum",
664665
}
665666

666-
return f"<b>{_result_title_names[type(self)]}</b> vs. <b>{_result_title_names[type(other)]}</b>"
667+
return _titles
668+
669+
def get_title(self, other: Result):
670+
return f"<b>{self.titles[type(self)]}</b> vs. <b>{self.titles[type(other)]}</b>"
667671

668672

669673
def _get_plot_kwargs(result: Result, other: Result) -> Dict[str, Any]:

tests/performance_estimation/CBPE/test_cbpe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,10 @@ def test_cbpe_defaults_to_isotonic_calibrator_when_none_given(): # noqa: D103
227227

228228
def test_cbpe_uses_custom_calibrator_when_provided(): # noqa: D103
229229
class TestCalibrator(Calibrator):
230-
def fit(self, y_pred_proba: np.ndarray, y_true: np.ndarray):
230+
def fit(self, y_pred_proba: np.ndarray, y_true: np.ndarray, *args, **kwargs):
231231
pass
232232

233-
def calibrate(self, y_pred_proba: np.ndarray):
233+
def calibrate(self, y_pred_proba: np.ndarray, *args, **kwargs):
234234
pass
235235

236236
estimator = CBPE(

0 commit comments

Comments
 (0)