Skip to content

Commit

Permalink
fixed CI errors
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed May 31, 2024
1 parent 70d8788 commit 24e354f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def metric(sample: List[Dict[str, Any]]):

# resample the instance scores, and then return the global score each time
scores = numpy.apply_along_axis(
lambda x: metric(x),
lambda x: metric(sample=[instances[i] for i in x]),
axis=axis,
arr=arr,
)
Expand Down Expand Up @@ -566,6 +566,8 @@ def compute_stream_score(
def compute_stream_score_version_for_ci(
self, instances: List[Dict[str, Any]], score_name: str
) -> dict:
if score_name is None or not isinstance(score_name, str):
score_name = self.main_score
full_score = self.compute_stream_score(instances, [score_name])
return full_score[score_name]

Expand Down Expand Up @@ -886,6 +888,8 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
global_score.update(
self.compute_stream_score(instances=instances, score_names=self.score_names)
)
global_score["score"] = global_score[self.main_score]
global_score["score_name"] = self.main_score

ci_fields = (
list(set(self.ci_scores))
Expand Down
2 changes: 1 addition & 1 deletion tests/library/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_accuracy_max_aggregation(self):
for key, value in global_result.items()
if key in expected_global_result
}
self.assertDictEqual(expected_global_result, global_result)
self.assertDictEqual(global_result, expected_global_result)

instance_targets = [
{"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
Expand Down

0 comments on commit 24e354f

Please sign in to comment.