allow global metrics on subsets of instances
Signed-off-by: dafnapension <dafnashein@yahoo.com>
dafnapension committed Apr 4, 2024
1 parent 902d8bc commit 9aa1551
Showing 2 changed files with 111 additions and 33 deletions.
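For context, a minimal usage sketch of what this commit enables. It is based on the new test added below; KendallTauMetric, FilterByCondition, MultiStream, and the import paths follow that test file, while the "mars" data here is illustrative:

from src.unitxt.metrics import KendallTauMetric
from src.unitxt.operators import FilterByCondition
from src.unitxt.stream import MultiStream

# two subsets of instances, distinguished by the "planet" field
instances = [
    {"planet": "earth", "prediction": p, "references": r}
    for p, r in zip(["1.0", "3.0", "2.0"], [["-1.0"], ["1.0"], ["0.0"]])
] + [
    {"planet": "mars", "prediction": p, "references": r}
    for p, r in zip(["2.0", "1.0", "3.0"], [["1.0"], ["2.0"], ["3.0"]])
]
multi_stream = MultiStream.from_iterables({"test": instances}, copying=True)

# each filter defines a subset that receives its own global score,
# in addition to the overall (unfiltered) global score
metric = KendallTauMetric(
    instance_filters=[
        FilterByCondition(condition="eq", values={"planet": "earth"}),
        FilterByCondition(condition="eq", values={"planet": "mars"}),
    ]
)
scored = list(metric(multi_stream)["test"])
# per-subset results are nested under prefixed keys, e.g.
# scored[0]["score"]["global"]["kendalltau_b_(planet_eq_mars)"]["score"]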
101 changes: 68 additions & 33 deletions src/unitxt/metrics.py
@@ -25,7 +25,7 @@
StreamingOperator,
StreamInstanceOperator,
)
from .operators import CopyFields
from .operators import CopyFields, FilterByCondition
from .random_utils import get_seed
from .settings_utils import get_settings
from .stream import MultiStream, Stream
@@ -435,42 +435,26 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):

# calculate scores for single instances
process_single_instances = True
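# optional filters: each one selects a subset of the instances, and a separate
# global score is computed over every such subset, in addition to the overall one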
instance_filters: List[FilterByCondition] = OptionalField(
default_factory=lambda: []
)

# flake8: noqa: C901
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
references = []
predictions = []
task_data = []
global_score = {}

instances = []

for instance in stream:
if "score" not in instance:
instance["score"] = {"global": global_score, "instance": {}}
else:
global_score = instance["score"]["global"]

instance_references, instance_prediction = (
instance["references"],
instance["prediction"],
)
references.append(instance_references)
predictions.append(instance_prediction)
instances.append(instance)
instance["score"] = {"global": {}, "instance": {}}

instance_task_data = (
instance["task_data"] if "task_data" in instance else {}
)
task_data.append(instance_task_data)
instance_score = None
# for backward compatibility
no_score_value = np.nan
if self.process_single_instances:
try:
instance_score = self._compute(
[instance_references],
[instance_prediction],
[instance_task_data],
[instance["references"]],
[instance["prediction"]],
[instance["task_data"] if "task_data" in instance else {}],
)
except:
no_score_value = None
@@ -484,17 +468,68 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
instance_score[self.main_score] = no_score_value

instance["score"]["instance"].update(instance_score)
self._validate_references_and_prediction(references, predictions)

result = self._compute(references, predictions, task_data)
instances.append(instance)

global_score.update(result)
global_score = instances[0]["score"]["global"]
# all instances have the same global score

score_name = global_score["score_name"]
confidence_interval = self.compute_global_confidence_intervals(
references, predictions, task_data, score_name
)
global_score.update(confidence_interval)
ms = MultiStream.from_iterables({"tmp": instances})

class WideOpen(FilterByCondition):
def _is_required(self, instance: dict) -> bool:
return True

self.instance_filters.append(WideOpen(condition="eq", values={}))
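# WideOpen keeps every instance, so the overall (unfiltered) global score
# is still computed alongside the per-filter scores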

for instance_filter in self.instance_filters:
if isoftype(instance_filter, WideOpen):
score_prefix = ""
else:
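# build a readable score prefix from the filter definition, e.g.
# values={"planet": "mars"} with condition "eq" becomes "(planet_eq_mars)",
# preceded by the main score name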
filter_values = str(instance_filter.values)
filter_values = filter_values.replace("'", "")
filter_values = re.sub(r"[{\[]", "(", filter_values)
filter_values = re.sub(r"[}\]]", ")", filter_values)
filter_values = re.sub(
r"([ ]*):([ ]*)",
"_" + instance_filter.condition + "_",
filter_values,
)
score_prefix = self.main_score + "_" + filter_values
if score_prefix not in global_score:
global_score[score_prefix] = {}

references = []
predictions = []
task_data = []
for instance in instance_filter(ms)["tmp"]:
references.append(instance["references"])
predictions.append(instance["prediction"])
task_data.append(
instance["task_data"] if "task_data" in instance else {}
)

self._validate_references_and_prediction(references, predictions)
result = self._compute(references, predictions, task_data)
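# results of the pass-through (empty-prefix) filter update the top level of the
# global score, as before; results of a real filter are nested under its prefix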

if score_prefix == "":
global_score.update(result)
else:
global_score[score_prefix].update(result)

score_name = (
global_score["score_name"]
if score_prefix == ""
else global_score[score_prefix]["score_name"]
)

confidence_interval = self.compute_global_confidence_intervals(
references, predictions, task_data, score_name
)
if score_prefix == "":
global_score.update(confidence_interval)
else:
global_score[score_prefix].update(confidence_interval)

for instance in instances:
instance["score"]["global"] = global_score
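With this change all instances still share a single global score dict; with the earth/mars test data added below it is shaped roughly like this (values elided, confidence-interval keys omitted):

{
    "score": ...,
    "score_name": "kendalltau_b",
    "kendalltau_b_(planet_eq_mars)": {"score": ..., "score_name": "kendalltau_b", ...},
    "kendalltau_b_(planet_eq_earth)": {"score": ..., "score_name": "kendalltau_b", ...},
}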
43 changes: 43 additions & 0 deletions tests/library/test_metrics.py
@@ -44,6 +44,8 @@
TokenOverlap,
UnsortedListExactMatch,
)
from src.unitxt.operators import FilterByCondition
from src.unitxt.stream import MultiStream
from src.unitxt.test_utils.metrics import apply_metric
from tests.utils import UnitxtTestCase

@@ -702,6 +704,47 @@ def test_kendalltau(self):
global_target = 0.81649658092772
self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])

predictions2 = ["1.0", "3.0", "2.0"]
references2 = [["-1.0"], ["1.0"], ["0.0"]]
outputs = apply_metric(
metric=metric, predictions=predictions2, references=references2
)
global_target2 = 1.0
self.assertAlmostEqual(global_target2, outputs[0]["score"]["global"]["score"])

instances = [
{"planet": "mars", "prediction": pred, "references": ref}
for pred, ref in zip(predictions, references)
]
instances = instances + [
{"planet": "earth", "prediction": pred, "references": ref}
for pred, ref in zip(predictions2, references2)
]
multi_stream = MultiStream.from_iterables({"test": instances}, copying=True)

fbc1 = FilterByCondition(condition="eq", values={"planet": "mars"})
fbc2 = FilterByCondition(condition="eq", values={"planet": "earth"})
metric = KendallTauMetric(instance_filters=[fbc1, fbc2])
output_instances = list(metric(multi_stream)["test"])
self.assertAlmostEqual(
global_target,
output_instances[0]["score"]["global"]["kendalltau_b_(planet_eq_mars)"][
"score"
],
)
self.assertAlmostEqual(
global_target2,
output_instances[0]["score"]["global"]["kendalltau_b_(planet_eq_earth)"][
"score"
],
)
self.assertEqual(
output_instances[0]["score"]["global"]["score_name"],
output_instances[0]["score"]["global"]["kendalltau_b_(planet_eq_earth)"][
"score_name"
],
)

def test_detector(self):
metric = Detector(model_name="MilaNLProc/bert-base-uncased-ear-misogyny")
predictions = ["I hate women.", "I do not hate women."]
