aggregators grow to multiple score names
Signed-off-by: dafnapension <dafnashein@yahoo.com>
dafnapension committed May 31, 2024
1 parent f251d6a commit 269ea31
Showing 1 changed file with 51 additions and 55 deletions.
106 changes: 51 additions & 55 deletions src/unitxt/metrics.py
@@ -7,7 +7,6 @@
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import field
from operator import itemgetter
from statistics import mean
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple

@@ -341,7 +340,7 @@ def score_based_confidence_interval(
# if aggregation_func is None, we simply take the mean of the resampled instance scores
# otherwise, the aggregation_func needs to be applied AFTER resampling the instances;
# that is, re-form the groups, calculate the function, and take the mean of the group scores
aggregation_func = self.average_item_scores
aggregation_func = AverageItemsAggregator.aggregate_one_group_score_names
for score_name in score_names:
# If all computed instance level scores are the same, there is no point in computing
# confidence intervals. So skip to the next score.
@@ -354,8 +353,8 @@ def statistic(arr, axis, score_name=score_name):
# iterate over the rows and compute the metric on each resampling
scores = numpy.apply_along_axis(
lambda resampled_instances: aggregation_func(
resampled_instances, score_name
),
instances=resampled_instances, score_names=score_names
)[score_name],
axis=axis,
arr=arr,
)
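
After this change, the aggregation function used inside the bootstrap no longer returns a single float for one score name: it returns a dict keyed by score name, and the resampling statistic indexes into that dict. A minimal sketch of the new contract, with made-up instances that follow the {"score": {"instance": {...}}} layout used elsewhere in this diff (assuming the classes in this file are importable as unitxt.metrics):

from unitxt.metrics import AverageItemsAggregator

# two toy instances with two instance-level scores each
instances = [
    {"score": {"instance": {"f1": 0.5, "accuracy": 1.0}}},
    {"score": {"instance": {"f1": 0.7, "accuracy": 0.0}}},
]

# the aggregator returns one value per requested score name ...
per_score = AverageItemsAggregator.aggregate_one_group_score_names(
    instances=instances, score_names=["f1", "accuracy"]
)  # expected: {"f1": 0.6, "accuracy": 0.5}

# ... and the bootstrap statistic above keeps only the score it is currently resampling
f1_for_this_resample = per_score["f1"]
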
@@ -420,9 +419,9 @@ def statistic(arr, axis):
# iterate over the rows and compute the metric on each resampling
def metric(sample: List[Dict[str, Any]]):
try:
return self.compute_stream_score_version_for_ci(
instances=sample, score_name=score_name
)
return self.compute_stream_score(
instances=sample, score_names=[score_name]
)[score_name]
except Exception as e:
# this happens in edge cases, for example, when the sampling creates a
# sample where all strings are empty and this fails bleu.
@@ -559,15 +558,6 @@ def compute_stream_score(
groups_global_scores=groups_global_scores
)

# variation for ci - score_based_confidence_interval, that returns a single float, in analogy with
# average_item_scores
def compute_stream_score_version_for_ci(
self, instances: List[Dict[str, Any]], score_name: str
) -> dict:
assert score_name is not None and isinstance(score_name, str)
full_score = self.compute_stream_score(instances, [score_name])
return full_score[score_name]

def ci_from_groups_global_scores(
self, groups_global_scores: dict, score_prefix: str = ""
) -> dict:
@@ -583,7 +573,7 @@ def ci_from_groups_global_scores(
instances=to_sample_from,
score_names=list(set(self.ci_scores)),
ci_score_prefix=score_prefix,
aggregation_func=self.average_item_scores,
aggregation_func=AverageItemsAggregator.aggregate_one_group_score_names,
)
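
ci_from_groups_global_scores reuses the same resampling machinery, but the "instances" it samples from are the groups' global scores dressed up to look like instance scores. That construction is not visible in this hunk; a hypothetical sketch of what to_sample_from might hold (the variable name comes from the call above, the wrapping shape is an assumption based on the {"score": {"instance": ...}} layout used by the aggregators):

# Hypothetical sketch only: each group's global score dict is wrapped so that the
# generic per-instance averaging (AverageItemsAggregator) can be applied to groups.
groups_global_scores = {
    "group_a": {"f1": 0.60, "accuracy": 0.50},
    "group_b": {"f1": 0.80, "accuracy": 0.75},
}
to_sample_from = [
    {"score": {"instance": group_score}}
    for group_score in groups_global_scores.values()
]
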


@@ -761,29 +751,36 @@ def aggregate_one_group(

class Aggregator(Artifact):
@abstractmethod
# one group can also be the whole stream
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
) -> Any:
def aggregate_one_group_score_names(
self, instances: List[Dict[str, Any]], score_names: List[str]
):
pass


class AverageItemsAggregator(Aggregator):
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
@staticmethod
def aggregate_one_group_score_names(
instances: List[Dict[str, Any]], score_names: List[str]
) -> float:
return MetricWithConfidenceInterval.average_item_scores(
instances=instances, score_name=score_name
)
return {
score_name: MetricWithConfidenceInterval.average_item_scores(
instances=instances, score_name=score_name
)
for score_name in score_names
}


class MaxItemsAggregator(Aggregator):
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
@staticmethod
def aggregate_one_group_score_names(
instances: List[Dict[str, Any]], score_names: List[str]
) -> float:
return MetricWithConfidenceInterval.max_item_scores(
instances=instances, score_name=score_name
)
return {
score_name: MetricWithConfidenceInterval.max_item_scores(
instances=instances, score_name=score_name
)
for score_name in score_names
}
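
Because the abstract method now receives the whole list of score names and returns a dict, a new aggregator only has to implement one method. A hypothetical subclass as an illustration (MedianItemsAggregator is not part of this commit; it assumes subclassing works the same way as for the two aggregators above):

from statistics import median
from typing import Any, Dict, List

from unitxt.metrics import Aggregator


class MedianItemsAggregator(Aggregator):
    # hypothetical: per-group median instead of mean, following the new interface
    @staticmethod
    def aggregate_one_group_score_names(
        instances: List[Dict[str, Any]], score_names: List[str]
    ) -> Dict[str, float]:
        return {
            score_name: median(
                instance["score"]["instance"][score_name] for instance in instances
            )
            for score_name in score_names
        }
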


# generate a score that compares the groups' scores of two subsets of the input stream: group 'control' and group 'comparison'
@@ -793,8 +790,8 @@ class ControlComparisonAggregator(Aggregator):
default_factory=lambda: PerformanceDropRateFloatsReducer()
)

def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
def aggregate_one_group_score_names(
self, instances: List[Dict[str, Any]], score_names: List[str]
) -> float:
pair_of_groups = {
side: [
@@ -804,14 +801,19 @@ def aggregate_one_group_no_named(
]
for side in ["control", "comparison"]
}
pair_of_floats = {
side: [
instance["score"]["instance"][score_name]
for instance in pair_of_groups[side]
]
for side in ["control", "comparison"]
}
return self.control_comparison_floats_calculator.reduce_floats(pair_of_floats)
to_return = {}
for score_name in score_names:
pair_of_floats = {
side: [
instance["score"]["instance"][score_name]
for instance in pair_of_groups[side]
]
for side in ["control", "comparison"]
}
to_return[
score_name
] = self.control_comparison_floats_calculator.reduce_floats(pair_of_floats)
return to_return
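
The comparison aggregator now loops over every requested score name, collecting one list of floats per side and reducing each control/comparison pair to a single number. The reducer itself (PerformanceDropRateFloatsReducer) is not shown in this diff; the function below is only an illustrative stand-in for what such a reduce_floats step could compute:

from statistics import mean

def illustrative_drop_rate(pair_of_floats):
    # assumed formula, not taken from this commit: relative drop from control to comparison
    control = mean(pair_of_floats["control"])
    comparison = mean(pair_of_floats["comparison"])
    return float("nan") if control == 0 else 1.0 - comparison / control

illustrative_drop_rate({"control": [0.8, 0.6], "comparison": [0.5, 0.2]})  # -> 0.5
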


class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -875,7 +877,7 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
instances=instances,
score_names=ci_fields,
ci_score_prefix="",
aggregation_func=self.compute_stream_score_version_for_ci,
aggregation_func=self.compute_stream_score,
)

global_score.update(confidence_interval)
@@ -902,12 +904,9 @@ def aggregate_one_group(
]
instances = filtered_instances

return {
score_name: self.aggregator.aggregate_one_group_no_named(
instances=instances, score_name=score_name
)
for score_name in score_names
}
return self.aggregator.aggregate_one_group_score_names(
instances=instances, score_names=score_names
)


class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -1021,7 +1020,7 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
instances=instances,
score_names=list(set(self.ci_scores)),
ci_score_prefix=self.prefix,
aggregation_func=self.compute_stream_score_version_for_ci,
aggregation_func=self.compute_stream_score,
)
else:
# dress the individual groups's score like instance scores: for each group generate
@@ -1090,12 +1089,9 @@ def aggregate_one_group(
]
instances = filtered_instances

return {
score_name: self.aggregator.aggregate_one_group_no_named(
instances=instances, score_name=score_name
)
for score_name in score_names
}
return self.aggregator.aggregate_one_group_score_names(
instances=instances, score_names=score_names
)

@abstractmethod
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
