Skip to content

Commit

Permalink
fixed CI errors
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed May 18, 2024
1 parent 53e51be commit 0ab2e07
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def metric(sample: List[Dict[str, Any]]):

# resample the instance scores, and then return the global score each time
scores = numpy.apply_along_axis(
lambda x: metric(x),
lambda x: metric(sample=[instances[i] for i in x]),
axis=axis,
arr=arr,
)
Expand Down Expand Up @@ -564,6 +564,8 @@ def compute_stream_score(
def compute_stream_score_version_for_ci(
self, instances: List[Dict[str, Any]], score_name: str
) -> dict:
if score_name is None or not isinstance(score_name, str):
score_name = self.main_score

Check warning on line 568 in src/unitxt/metrics.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/metrics.py#L568

Added line #L568 was not covered by tests
full_score = self.compute_stream_score(instances, [score_name])
return full_score[score_name]

Expand Down Expand Up @@ -882,6 +884,8 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
global_score.update(
self.compute_stream_score(instances=instances, score_names=self.score_names)
)
global_score["score"] = global_score[self.main_score]
global_score["score_name"] = self.main_score

ci_fields = (
list(set(self.ci_scores))
Expand Down
2 changes: 1 addition & 1 deletion tests/library/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_accuracy_max_aggregation(self):
for key, value in global_result.items()
if key in expected_global_result
}
self.assertDictEqual(expected_global_result, global_result)
self.assertDictEqual(global_result, expected_global_result)

instance_targets = [
{"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
Expand Down

0 comments on commit 0ab2e07

Please sign in to comment.