From 0ab2e078dc5a2fc68632e18f552346174d979425 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Fri, 17 May 2024 12:38:11 +0300 Subject: [PATCH] fixed CI errors Signed-off-by: dafnapension --- src/unitxt/metrics.py | 6 +++++- tests/library/test_metrics.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 3207f7c239..0631edd530 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -427,7 +427,7 @@ def metric(sample: List[Dict[str, Any]]): # resample the instance scores, and then return the global score each time scores = numpy.apply_along_axis( - lambda x: metric(x), + lambda x: metric(sample=[instances[i] for i in x]), axis=axis, arr=arr, ) @@ -564,6 +564,8 @@ def compute_stream_score( def compute_stream_score_version_for_ci( self, instances: List[Dict[str, Any]], score_name: str ) -> dict: + if score_name is None or not isinstance(score_name, str): + score_name = self.main_score full_score = self.compute_stream_score(instances, [score_name]) return full_score[score_name] @@ -882,6 +884,8 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato global_score.update( self.compute_stream_score(instances=instances, score_names=self.score_names) ) + global_score["score"] = global_score[self.main_score] + global_score["score_name"] = self.main_score ci_fields = ( list(set(self.ci_scores)) diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index 9fcd4127cc..211c075c6b 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -215,7 +215,7 @@ def test_accuracy_max_aggregation(self): for key, value in global_result.items() if key in expected_global_result } - self.assertDictEqual(expected_global_result, global_result) + self.assertDictEqual(global_result, expected_global_result) instance_targets = [ {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},