
Commit

moved filtering out of the aggregator, into the metric itself
Signed-off-by: dafnapension <dafnashein@yahoo.com>
dafnapension committed May 31, 2024
1 parent 24e354f commit f251d6a
Showing 2 changed files with 96 additions and 109 deletions.
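In essence, the commit replaces the dedicated FilteredAggregator with a filter_by_condition field on the metric itself, applied inside aggregate_one_group before delegating to a plain aggregator. A minimal, self-contained sketch of the idea (hypothetical simplified classes and helpers, not the actual unitxt API):

from typing import Any, Callable, Dict, List, Optional

import numpy as np


def average_scores(instances: List[Dict[str, Any]], score_name: str) -> float:
    # Mean of the per-instance scores for one score name; nan for an empty list.
    if not instances:
        return float("nan")
    return float(np.mean([inst["score"]["instance"][score_name] for inst in instances]))


# Before: the condition was bundled with the aggregation in a FilteredAggregator.
class SketchFilteredAggregator:
    def __init__(self, condition: Callable[[Dict[str, Any]], bool]):
        self.condition = condition

    def aggregate_one_group(self, instances: List[Dict[str, Any]], score_name: str) -> float:
        kept = [inst for inst in instances if self.condition(inst)]
        return average_scores(kept, score_name)


# After: the metric owns the (optional) filter; the aggregator stays a plain reducer.
class SketchMetric:
    def __init__(self, filter_by_condition: Optional[Callable[[Dict[str, Any]], bool]] = None):
        self.filter_by_condition = filter_by_condition

    def aggregate_one_group(self, instances: List[Dict[str, Any]], score_names: List[str]) -> dict:
        if self.filter_by_condition is not None:
            instances = [inst for inst in instances if self.filter_by_condition(inst)]
        return {name: average_scores(instances, name) for name in score_names}

Either way the per-group result is a dict of named scores; the difference is only where the filtering decision lives.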
189 changes: 80 additions & 109 deletions src/unitxt/metrics.py
@@ -45,9 +45,6 @@
warnings.filterwarnings("ignore", category=DegenerateDataWarning)


warnings.filterwarnings("ignore", category=DegenerateDataWarning)


def abstract_factory():
return {}

@@ -240,6 +237,11 @@ class MetricWithConfidenceInterval(Metric):
confidence_level: float = 0.95
ci_scores: List[str] = None

# Whether to filter the instances before employing the aggregation for the global score.
# Useful when the user is interested in only a subset of the instances, defined
# by a condition evaluated on each instance.
filter_by_condition: FilterByCondition = None

# whether to split the instances into groups, aggregate over each,
# and use the (non-weighted) average of the groups' global scores
# as the whole stream's global score.
@@ -464,10 +466,11 @@ def metric(sample: List[Dict[str, Any]]):
# aggregate over one group, which can be the whole stream when split_to_groups_by is None, for metric evaluation.
# It takes into account filtering, and splitting into control and comparison, but does not assume further
# splitting of the input instances into groups by split_to_groups_by
# returns a dictionary of named scores
@abstractmethod
def aggregate_one_group(
self, instances: List[Dict[str, Any]], score_names: List[str]
) -> Any:
) -> dict:
pass

# This does deal with split_to_groups_by, when it is not None. Returned is a dict whose keys
@@ -522,25 +525,21 @@ def average_groups_global_scores(self, groups_global_scores: dict) -> dict:
result = defaultdict(list)

fields_to_average = set()
for _, group_global_score in groups_global_scores.items():
if isinstance(group_global_score, dict):
# the score of the current group is a dict and not nan
for k, v in group_global_score.items():
if isinstance(v, str):
result[k] = v
elif isinstance(v, float):
fields_to_average.add(k)
result[k].append(v)
else:
assert isoftype(
v, List[float]
), f"unexpected type of score {v} in group's score field {k}"
result[k].append(np.nanmean(v))
fields_to_average.add(k)
else:
assert np.isnan(
group_global_score
), "group global score should be either a dict or np.nan"
for group_name in sorted(groups_global_scores.keys()):
group_global_score = groups_global_scores[group_name]
# the score of the current group is a dict
for k, v in group_global_score.items():
if isinstance(v, str):
result[k] = v
elif isinstance(v, float):
fields_to_average.add(k)
result[k].append(v)
else:
assert isoftype(
v, List[float]
), f"unexpected type of score {v} in group's score field {k}"
result[k].append(np.nanmean(v))
fields_to_average.add(k)

for k, v in result.items():
if k in fields_to_average:
@@ -552,8 +551,7 @@ def average_groups_global_scores(self, groups_global_scores: dict) -> dict:
def compute_stream_score(
self, instances: List[Dict[str, Any]], score_names: List[str]
) -> dict:
if score_names is None:
score_names = [self.main_score]
assert score_names is not None
groups_global_scores = self.aggregate_stream_to_groups_scores(
instances=instances, score_names=score_names
)
@@ -566,8 +564,7 @@ def compute_stream_score(
def compute_stream_score_version_for_ci(
self, instances: List[Dict[str, Any]], score_name: str
) -> dict:
if score_name is None or not isinstance(score_name, str):
score_name = self.main_score
assert score_name is not None and isinstance(score_name, str)
full_score = self.compute_stream_score(instances, [score_name])
return full_score[score_name]

@@ -605,8 +602,6 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
# calculate scores for single instances
process_single_instances = True

# whether to filter instances before aggregating over them
filter_by_condition: FilterByCondition = None
# generate a score that compares the groups' scores of two subsets of the input stream: group 'control' and group 'comparison'
control_comparison: Dict[Literal["control", "comparison"], FilterByCondition] = None
control_comparison_floats_calculator: ControlComparisonFloatsReducer = Field(
@@ -725,7 +720,7 @@ def compute(

def aggregate_one_group(
self, instances: List[Dict[str, Any]], score_names: List[str]
) -> Any:
) -> dict:
# for Global metric, only self.main_score counts
if self.filter_by_condition is not None:
filtered_instances = [
@@ -767,14 +762,14 @@ def aggregate_one_group(
class Aggregator(Artifact):
@abstractmethod
# one group can also be the whole stream
def aggregate_one_group(
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
) -> Any:
pass


class AverageItemsAggregator(Aggregator):
def aggregate_one_group(
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
) -> float:
return MetricWithConfidenceInterval.average_item_scores(
@@ -783,41 +778,22 @@ def aggregate_one_group(


class MaxItemsAggregator(Aggregator):
def aggregate_one_group(
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
) -> float:
return MetricWithConfidenceInterval.max_item_scores(
instances=instances, score_name=score_name
)


# filter instances before aggregating over them
class FilteredAggregator(Aggregator):
filter_by_condition: FilterByCondition

def aggregate_one_group(
self, instances: List[Dict[str, Any]], score_name: str
) -> float:
filtered_instances = [
instance
for instance in instances
if self.filter_by_condition._is_required(instance)
]
if len(filtered_instances) == 0:
return np.nan
return MetricWithConfidenceInterval.average_item_scores(
instances=filtered_instances, score_name=score_name
)


# generate a score that compares the groups' scores of two subsets of the input stream: group 'control' and group 'comparison'
class ControlComparisonAggregator(Aggregator):
control_comparison: Dict[Literal["control", "comparison"], FilterByCondition]
control_comparison_floats_calculator: ControlComparisonFloatsReducer = Field(
default_factory=lambda: PerformanceDropRateFloatsReducer()
)

def aggregate_one_group(
def aggregate_one_group_no_named(
self, instances: List[Dict[str, Any]], score_name: str
) -> float:
pair_of_groups = {
@@ -850,7 +826,6 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
aggregating_function_name: str = "mean"
split_to_groups_by = None

# no split to groups, no filtering, or control-comparison, for now.
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
if self.main_score is None:
self.main_score = "f1"
@@ -919,16 +894,20 @@ def compute(
def aggregate_one_group(
self, instances: List[Dict[str, Any]], score_names: List[str]
) -> dict:
gs = {}
for score_name in score_names:
gs.update(
{
score_name: self.aggregator.aggregate_one_group(
instances=instances, score_name=score_name
)
}
if self.filter_by_condition is not None:
filtered_instances = [
instance
for instance in instances
if self.filter_by_condition._is_required(instance)
]
instances = filtered_instances

return {
score_name: self.aggregator.aggregate_one_group_no_named(
instances=instances, score_name=score_name
)
return gs
for score_name in score_names
}


class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -973,12 +952,11 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
# How to yield one score, float, for each score_name in score_names, from a list of instances: either the whole stream or one group.
# For InstanceMetric, this aggregation is over the instance scores, already sitting in each instance, in subfield
# instance["score"]["instance"], which is a dict mapping score_name to (instance) score value.
# Typically, aggregating is to be overridden by the subclasses. If None, and not set by subclasses, then for InstanceMetric -
# the defaults set are:
# aggregating_function_name: "mean",
# aggregating_function: MetricWithConfidenceInterval.average_item_scores,
# }
# Typically, aggregating is to be overridden by the subclasses.
aggregator: Aggregator = Field(default_factory=lambda: AverageItemsAggregator())

# This name is used to prefix the score_name.
# Its use is to be expanded to global metric and bulk instance too.
aggregating_function_name: str = "mean"

reference_field: str = NonPositionalField(default="references")
@@ -1104,16 +1082,20 @@ def compute_instance_scores(
def aggregate_one_group(
self, instances: List[Dict[str, Any]], score_names: List[str]
) -> dict:
gs = {}
for score_name in score_names:
gs.update(
{
score_name: self.aggregator.aggregate_one_group(
instances=instances, score_name=score_name
)
}
if self.filter_by_condition is not None:
filtered_instances = [
instance
for instance in instances
if self.filter_by_condition._is_required(instance)
]
instances = filtered_instances

return {
score_name: self.aggregator.aggregate_one_group_no_named(
instances=instances, score_name=score_name
)
return gs
for score_name in score_names
}

@abstractmethod
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
@@ -3277,69 +3259,58 @@ class FixedGroupMeanStringContainment(StringContainment):

# take only the (fixed) group mean of baseline or other (paraphrases) scores
class FixedGroupMeanBaselineAccuracy(Accuracy):
filter_by_condition = Field(
default_factory=lambda: FilterByCondition(
values={"task_data/variant_type": ["original"]}, condition="in"
)
)
split_to_groups_by = Field(
default_factory=lambda: SplitByValue(fields=["task_data/group_id"])
)
ci_samples_from_groups_scores = True
aggregating_function_name = "mean_baseline"

aggregator = Field(
default_factory=lambda: FilteredAggregator(
filter_by_condition=FilterByCondition(
values={"task_data/variant_type": ["original"]}, condition="in"
)
)
)


class FixedGroupMeanParaphraseAccuracy(Accuracy):
filter_by_condition = Field(
default_factory=lambda: FilterByCondition(
values={"task_data/variant_type": ["paraphrase"]}, condition="in"
)
)
split_to_groups_by = Field(
default_factory=lambda: SplitByValue(fields=["task_data/group_id"])
)
ci_samples_from_groups_scores = True
aggregating_function_name = "mean_paraphrase"

aggregator = Field(
default_factory=lambda: FilteredAggregator(
filter_by_condition=FilterByCondition(
values={"task_data/variant_type": ["paraphrase"]}, condition="in"
)
)
)
aggregating_function_name = "mean_paraphrase"


# same as above but using StringContainment
class FixedGroupMeanBaselineStringContainment(StringContainment):
filter_by_condition = Field(
default_factory=lambda: FilterByCondition(
values={"task_data/variant_type": ["original"]}, condition="in"
)
)
split_to_groups_by = Field(
default_factory=lambda: SplitByValue(fields=["task_data/group_id"])
)
ci_samples_from_groups_scores = True
aggregating_function_name = "mean_baseline"

aggregator = Field(
default_factory=lambda: FilteredAggregator(
filter_by_condition=FilterByCondition(
values={"task_data/variant_type": ["original"]}, condition="in"
)
)
)


class FixedGroupMeanParaphraseStringContainment(StringContainment):
filter_by_condition = Field(
default_factory=lambda: FilterByCondition(
values={"task_data/variant_type": ["paraphrase"]}, condition="in"
)
)
split_to_groups_by = Field(
default_factory=lambda: SplitByValue(fields=["task_data/group_id"])
)
ci_samples_from_groups_scores = True
aggregating_function_name = "mean_paraphrase"

aggregator = Field(
default_factory=lambda: FilteredAggregator(
filter_by_condition=FilterByCondition(
values={"task_data/variant_type": ["paraphrase"]}, condition="in"
)
)
)


# using PDR
class FixedGroupPDRParaphraseAccuracy(Accuracy):
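The FixedGroupMean* metrics above now declare the filter directly on the metric (filter_by_condition) next to split_to_groups_by, instead of wrapping it in a FilteredAggregator. A rough, self-contained sketch of the combined behaviour — split into groups, filter within each group, average per group, then average the group means — using hypothetical helpers rather than the real FilterByCondition/SplitByValue operators:

from collections import defaultdict
from typing import Any, Dict, List

import numpy as np


def split_by_value(instances: List[Dict[str, Any]], field: str) -> Dict[Any, List[Dict[str, Any]]]:
    # Hypothetical stand-in for SplitByValue: group instances by a task_data field.
    groups: Dict[Any, List[Dict[str, Any]]] = defaultdict(list)
    for inst in instances:
        groups[inst["task_data"][field]].append(inst)
    return groups


def fixed_group_mean_baseline(instances: List[Dict[str, Any]], score_name: str) -> float:
    # Keep only the "original" variant inside each group (the metric-level filter),
    # average that group's scores, then take the non-weighted mean of the group means.
    group_means = []
    for _, group in sorted(split_by_value(instances, "group_id").items()):
        kept = [i for i in group if i["task_data"]["variant_type"] == "original"]
        group_means.append(
            float("nan") if not kept
            else float(np.mean([i["score"]["instance"][score_name] for i in kept]))
        )
    return float(np.nanmean(group_means))


instances = [
    {"task_data": {"group_id": g, "variant_type": v}, "score": {"instance": {"accuracy": s}}}
    for g, v, s in [("g1", "original", 1.0), ("g1", "paraphrase", 0.0),
                    ("g2", "original", 0.0), ("g2", "paraphrase", 1.0)]
]
print(fixed_group_mean_baseline(instances, "accuracy"))  # prints 0.5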
16 changes: 16 additions & 0 deletions tests/library/test_metrics.py
@@ -50,6 +50,7 @@
TokenOverlap,
UnsortedListExactMatch,
)
from unitxt.operators import SplitByValue
from unitxt.test_utils.metrics import apply_metric

from tests.utils import UnitxtTestCase
@@ -1284,6 +1285,21 @@ def test_grouped_instance_metric_confidence_interval(self):
},
)

global_metric_to_test_on_groups = Rouge()
global_metric_to_test_on_groups.split_to_groups_by = SplitByValue(
fields=["task_data/group_id"]
)
global_metric_to_test_on_groups.ci_samples_from_groups_scores = True
global_metric_to_test_on_groups.ci_scores = [
global_metric_to_test_on_groups.main_score
]
global_metric_to_test_on_groups.prefix = ""
self._test_grouped_instance_confidence_interval(
metric=global_metric_to_test_on_groups,
expected_ci_low=0.15308065714331093,
expected_ci_high=0.7666666666666666,
)

def _test_grouped_instance_confidence_interval(
self,
metric,
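The added test turns on ci_samples_from_groups_scores, so the confidence interval is bootstrapped from the per-group global scores rather than from per-instance scores. A simplified percentile-bootstrap sketch of that idea (hypothetical helper, not the unitxt implementation, which relies on scipy's bootstrap machinery — hence the DegenerateDataWarning filter at the top of metrics.py):

import numpy as np


def percentile_bootstrap_ci(group_scores, confidence_level=0.95, n_resamples=1000, seed=0):
    # Resample the per-group scores with replacement and take the mean of each resample;
    # the CI bounds are the corresponding quantiles of the resampled means.
    rng = np.random.default_rng(seed)
    scores = np.asarray(group_scores, dtype=float)
    resampled_means = [
        float(np.nanmean(rng.choice(scores, size=scores.size, replace=True)))
        for _ in range(n_resamples)
    ]
    tail = (1.0 - confidence_level) / 2.0
    return float(np.quantile(resampled_means, tail)), float(np.quantile(resampled_means, 1.0 - tail))


print(percentile_bootstrap_ci([0.2, 0.5, 0.8, 0.4]))  # prints a (ci_low, ci_high) pair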
