diff --git a/src/krisi/evaluate/group.py b/src/krisi/evaluate/group.py index e67456b3..1ddc8573 100644 --- a/src/krisi/evaluate/group.py +++ b/src/krisi/evaluate/group.py @@ -198,7 +198,4 @@ def evaluate_over_time( ) def __str__(self) -> str: - return "\n".join([metric.__str__() for metric in self.metrics]) - - def __repr__(self) -> str: - return "\n".join([metric.__repr__() for metric in self.metrics]) + return " - ".join([metric.key for metric in self.metrics]) diff --git a/src/krisi/evaluate/metric.py b/src/krisi/evaluate/metric.py index ffd4d6c9..00a78865 100644 --- a/src/krisi/evaluate/metric.py +++ b/src/krisi/evaluate/metric.py @@ -20,7 +20,7 @@ TargetsDS, WeightsDS, ) -from krisi.report.console import print_metric +from krisi.report.console import get_metric_string from krisi.report.type import InteractiveFigure, PlotDefinition, plotly_interactive from krisi.utils.iterable_helpers import ( check_iterable_with_number, @@ -112,12 +112,11 @@ def __setitem__(self, key: str, item: Any) -> None: def __getitem__(self, key: str) -> Any: return getattr(self, key, "Unknown Field") - def __str__(self) -> str: - return print_metric(self) + def __str__(self, repr: bool = False) -> str: + return get_metric_string(self, repr=repr) def __repr__(self) -> str: - print(print_metric(self, repr=True)) - return super().__repr__() + return super().__repr__()[:-1] + f" - {self.__str__(True)}>" def _evaluation(self, *args, **kwargs) -> Metric: if self.calculation == Calculation.rolling: diff --git a/src/krisi/evaluate/scorecard.py b/src/krisi/evaluate/scorecard.py index 5a0929d1..520d95ea 100644 --- a/src/krisi/evaluate/scorecard.py +++ b/src/krisi/evaluate/scorecard.py @@ -630,6 +630,13 @@ def generate_report( ) report.generate_launch() + def __repr__(self): + md = self.__dict__["metadata"] + return ( + super().__repr__()[:-1] + + f"\n{md.model_name:>40s} | {md.dataset_name} \n{md.project_name:>40s} | {self.dataset_type.value}\n\n{get_minimal_summary(self, dataframe=False)}>" + ) + def get_rolling_diagrams(obj: "ScoreCard") -> List[List[InteractiveFigure]]: return [ diff --git a/src/krisi/report/console.py b/src/krisi/report/console.py index 6f119388..fb5fd7b1 100644 --- a/src/krisi/report/console.py +++ b/src/krisi/report/console.py @@ -21,7 +21,7 @@ from krisi.utils.printing import bold -def print_metric(obj: "Metric", repr: bool = False) -> str: +def get_metric_string(obj: "Metric", repr: bool = False) -> str: if repr: return f"{obj.result} | {obj.name}" hyperparams = ""