diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index 29bb10e6f9..11c7d0b22c 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -25,7 +25,7 @@
     StreamingOperator,
     StreamInstanceOperator,
 )
-from .operators import CopyFields
+from .operators import CopyFields, FilterByCondition
 from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
@@ -435,42 +435,26 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
     # calculate scores for single instances
     process_single_instances = True
 
+    instance_filters: List[FilterByCondition] = OptionalField(
+        default_factory=lambda: []
+    )
+    # flake8: noqa: C901
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
-        references = []
-        predictions = []
-        task_data = []
-        global_score = {}
-
         instances = []
-
         for instance in stream:
             if "score" not in instance:
-                instance["score"] = {"global": global_score, "instance": {}}
-            else:
-                global_score = instance["score"]["global"]
-
-            instance_references, instance_prediction = (
-                instance["references"],
-                instance["prediction"],
-            )
-            references.append(instance_references)
-            predictions.append(instance_prediction)
-            instances.append(instance)
+                instance["score"] = {"global": {}, "instance": {}}
 
-            instance_task_data = (
-                instance["task_data"] if "task_data" in instance else {}
-            )
-            task_data.append(instance_task_data)
             instance_score = None
 
             # for backward compatibility
             no_score_value = np.nan
             if self.process_single_instances:
                 try:
                     instance_score = self._compute(
-                        [instance_references],
-                        [instance_prediction],
-                        [instance_task_data],
+                        [instance["references"]],
+                        [instance["prediction"]],
+                        [instance["task_data"] if "task_data" in instance else {}],
                     )
                 except:
                     no_score_value = None
@@ -484,17 +468,68 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
                     instance_score[self.main_score] = no_score_value
 
             instance["score"]["instance"].update(instance_score)
 
-        self._validate_references_and_prediction(references, predictions)
-        result = self._compute(references, predictions, task_data)
+            instances.append(instance)
 
-        global_score.update(result)
+        global_score = instances[0]["score"]["global"]
+        # all instances have the same global score
 
-        score_name = global_score["score_name"]
-        confidence_interval = self.compute_global_confidence_intervals(
-            references, predictions, task_data, score_name
-        )
-        global_score.update(confidence_interval)
+        ms = MultiStream.from_iterables({"tmp": instances})
+
+        class WideOpen(FilterByCondition):
+            def _is_required(self, instance: dict) -> bool:
+                return True
+
+        self.instance_filters.append(WideOpen(condition="eq", values={}))
+
+        for instance_filter in self.instance_filters:
+            if isoftype(instance_filter, WideOpen):
+                score_prefix = ""
+            else:
+                filter_values = str(instance_filter.values)
+                filter_values = filter_values.replace("'", "")
+                filter_values = re.sub(r"[{\[]", "(", filter_values)
+                filter_values = re.sub(r"[}\]]", ")", filter_values)
+                filter_values = re.sub(
+                    r"([ ]*):([ ]*)",
+                    "_" + instance_filter.condition + "_",
+                    filter_values,
+                )
+                score_prefix = self.main_score + "_" + filter_values
+                if score_prefix not in global_score:
+                    global_score[score_prefix] = {}
+
+            references = []
+            predictions = []
+            task_data = []
+            for instance in instance_filter(ms)["tmp"]:
+                references.append(instance["references"])
+                predictions.append(instance["prediction"])
+                task_data.append(
+                    instance["task_data"] if "task_data" in instance else {}
+                )
+
+            self._validate_references_and_prediction(references, predictions)
+            result = self._compute(references, predictions, task_data)
+
+            if score_prefix == "":
+                global_score.update(result)
+            else:
+                global_score[score_prefix].update(result)
+
+            score_name = (
+                global_score["score_name"]
+                if score_prefix == ""
+                else global_score[score_prefix]["score_name"]
+            )
+
+            confidence_interval = self.compute_global_confidence_intervals(
+                references, predictions, task_data, score_name
+            )
+            if score_prefix == "":
+                global_score.update(confidence_interval)
+            else:
+                global_score[score_prefix].update(confidence_interval)
 
         for instance in instances:
             instance["score"]["global"] = global_score
diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py
index ace37905b7..06f93bf9eb 100644
--- a/tests/library/test_metrics.py
+++ b/tests/library/test_metrics.py
@@ -44,6 +44,8 @@
     TokenOverlap,
     UnsortedListExactMatch,
 )
+from src.unitxt.operators import FilterByCondition
+from src.unitxt.stream import MultiStream
 from src.unitxt.test_utils.metrics import apply_metric
 from tests.utils import UnitxtTestCase
 
@@ -702,6 +704,47 @@ def test_kendalltau(self):
         global_target = 0.81649658092772
         self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
 
+        predictions2 = ["1.0", "3.0", "2.0"]
+        references2 = [["-1.0"], ["1.0"], ["0.0"]]
+        outputs = apply_metric(
+            metric=metric, predictions=predictions2, references=references2
+        )
+        global_target2 = 1.0
+        self.assertAlmostEqual(global_target2, outputs[0]["score"]["global"]["score"])
+
+        instances = [
+            {"planet": "mars", "prediction": pred, "references": ref}
+            for pred, ref in zip(predictions, references)
+        ]
+        instances = instances + [
+            {"planet": "earth", "prediction": pred, "references": ref}
+            for pred, ref in zip(predictions2, references2)
+        ]
+        multi_stream = MultiStream.from_iterables({"test": instances}, copying=True)
+
+        fbc1 = FilterByCondition(condition="eq", values={"planet": "mars"})
+        fbc2 = FilterByCondition(condition="eq", values={"planet": "earth"})
+        metric = KendallTauMetric(instance_filters=[fbc1, fbc2])
+        output_instances = list(metric(multi_stream)["test"])
+        self.assertAlmostEqual(
+            global_target,
+            output_instances[0]["score"]["global"]["kendalltau_b_(planet_eq_mars)"][
+                "score"
+            ],
+        )
+        self.assertAlmostEqual(
+            global_target2,
+            output_instances[0]["score"]["global"]["kendalltau_b_(planet_eq_earth)"][
+                "score"
+            ],
+        )
+        self.assertEqual(
+            output_instances[0]["score"]["global"]["score_name"],
+            output_instances[0]["score"]["global"]["kendalltau_b_(planet_eq_earth)"][
+                "score_name"
+            ],
+        )
+
     def test_detector(self):
         metric = Detector(model_name="MilaNLProc/bert-base-uncased-ear-misogyny")
         predictions = ["I hate women.", "I do not hate women."]
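
A minimal usage sketch of the new instance_filters argument, mirroring the added test_kendalltau case. It assumes the src.unitxt import paths used in the tests and that KendallTauMetric keeps kendalltau_b as its main score; the instance values below are illustrative only.

from src.unitxt.metrics import KendallTauMetric
from src.unitxt.operators import FilterByCondition
from src.unitxt.stream import MultiStream

# Illustrative instances tagged with a grouping field ("planet"); values are made up.
instances = [
    {"planet": "mars", "prediction": "1.0", "references": ["2.0"]},
    {"planet": "mars", "prediction": "2.0", "references": ["1.0"]},
    {"planet": "mars", "prediction": "3.0", "references": ["3.0"]},
    {"planet": "earth", "prediction": "1.0", "references": ["-1.0"]},
    {"planet": "earth", "prediction": "3.0", "references": ["1.0"]},
    {"planet": "earth", "prediction": "2.0", "references": ["0.0"]},
]
multi_stream = MultiStream.from_iterables({"test": instances}, copying=True)

# One FilterByCondition per subgroup; the metric computes a separate global score
# (and confidence interval) for each filtered subset, plus the usual unfiltered score.
metric = KendallTauMetric(
    instance_filters=[
        FilterByCondition(condition="eq", values={"planet": "mars"}),
        FilterByCondition(condition="eq", values={"planet": "earth"}),
    ]
)

scored = list(metric(multi_stream)["test"])
global_scores = scored[0]["score"]["global"]

# Filtered results are keyed by main_score plus the rendered filter condition.
print(global_scores["kendalltau_b_(planet_eq_mars)"]["score"])
print(global_scores["kendalltau_b_(planet_eq_earth)"]["score"])
print(global_scores["score"])  # unfiltered score over all instances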