
Commit

Add tests for metric analyzer
botirk38 committed Sep 3, 2024
1 parent 8fd649e commit a7abdc5
Showing 1 changed file with 73 additions and 72 deletions.
145 changes: 73 additions & 72 deletions tests/unit_tests/huggingface_pipelines/metric_analyzer.py
@@ -1,8 +1,6 @@
from typing import Any, Dict
from unittest.mock import Mock, patch

import pytest
-from datasets import load_metric  # type: ignore

from huggingface_pipelines.metric_analyzer import (
MetricAnalyzerPipeline,
@@ -25,113 +23,101 @@ def sample_config():
@pytest.fixture
def sample_batch():
return {
"text": ["This is a test sentence.", "Another example sentence."],
"reconstructed_text": [
"This is a test sentence.",
"A different example sentence.",
],
"text": ["Hello world", "This is a test"],
"reconstructed_text": ["Hello earth", "This is a quiz"],
}


def test_metric_pipeline_config():
config = MetricPipelineConfig(
metrics=["bleu", "rouge"],
low_score_threshold=0.6,
columns=["text"],
reconstructed_columns=["reconstructed_text"],
output_column_suffix="score",
metrics=["bleu"],
low_score_threshold=0.7,
columns=["col1"],
reconstructed_columns=["rec_col1"],
output_column_suffix="test",
)
assert config.metrics == ["bleu", "rouge"]
assert config.low_score_threshold == 0.6
assert config.columns == ["text"]
assert config.reconstructed_columns == ["reconstructed_text"]
assert config.output_column_suffix == "score"
assert config.metrics == ["bleu"]
assert config.low_score_threshold == 0.7
assert config.columns == ["col1"]
assert config.reconstructed_columns == ["rec_col1"]
assert config.output_column_suffix == "test"


@patch("huggingface_pipelines.metric_analyzer.load_metric")
def test_metric_analyzer_pipeline_init(mock_load_metric, sample_config):
mock_load_metric.return_value = Mock()
@patch("evaluate.load")
def test_metric_analyzer_pipeline_init(mock_load, sample_config):
mock_load.return_value = Mock()
pipeline = MetricAnalyzerPipeline(MetricPipelineConfig(**sample_config))
assert len(pipeline.metrics) == 2
assert "bleu" in pipeline.metrics
assert "rouge" in pipeline.metrics
mock_load_metric.assert_any_call("bleu")
mock_load_metric.assert_any_call("rouge")

mock_load.assert_any_call("bleu")
mock_load.assert_any_call("rouge")

@patch("huggingface_pipelines.metric_analyzer.load_metric")
def test_compute_metric(mock_load_metric, sample_config):
mock_metric = Mock()
mock_metric.compute.return_value = {"score": 0.8}
mock_load_metric.return_value = mock_metric

def test_compute_metric(sample_config):
pipeline = MetricAnalyzerPipeline(MetricPipelineConfig(**sample_config))
result = pipeline.compute_metric(
"bleu", [["This is a reference."]], ["This is a prediction."]
)
pipeline.metrics["bleu"] = Mock()
pipeline.metrics["bleu"].compute.return_value = {"score": 0.8}

result = pipeline.compute_metric(
"bleu", [["Hello", "world"]], ["Hello", "earth"])
assert result == {"score": 0.8}
-    mock_metric.compute.assert_called_once_with(
-        predictions=["This is a prediction."], references=[["This is a reference."]]
+    pipeline.metrics["bleu"].compute.assert_called_once_with(
+        predictions=["Hello", "earth"], references=[["Hello", "world"]]
)


@patch("huggingface_pipelines.metric_analyzer.load_metric")
def test_process_batch(mock_load_metric, sample_config, sample_batch):
mock_metric = Mock()
mock_metric.compute.return_value = {"score": 0.8}
mock_load_metric.return_value = mock_metric

def test_process_batch(sample_config, sample_batch):
pipeline = MetricAnalyzerPipeline(MetricPipelineConfig(**sample_config))
pipeline.compute_metric = Mock(return_value={"score": 0.75})

result = pipeline.process_batch(sample_batch)

assert "text_references" in result
assert "text_predictions" in result
assert "text_bleu_score" in result
assert "text_rouge_score" in result
assert "text_bleu_score_low" in result
assert "text_rouge_score_low" in result
assert result["text_bleu_score"] == [0.8, 0.8]
assert result["text_rouge_score"] == [0.8, 0.8]
assert result["text_bleu_score_low"] == [False, False]
assert result["text_rouge_score_low"] == [False, False]


@patch("huggingface_pipelines.metric_analyzer.load_metric")
def test_process_batch_with_list_input(mock_load_metric, sample_config):
mock_metric = Mock()
mock_metric.compute.return_value = {"score": 0.8}
mock_load_metric.return_value = mock_metric

pipeline = MetricAnalyzerPipeline(MetricPipelineConfig(**sample_config))
batch = {
"text": [["This", "is", "a", "test"], ["Another", "example"]],
"reconstructed_text": [
["This", "is", "a", "test"],
["A", "different", "example"],
],
}
result = pipeline.process_batch(batch)

assert "text_bleu_score" in result
assert "text_rouge_score" in result
assert result["text_bleu_score"] == [0.8, 0.8]
assert result["text_rouge_score"] == [0.8, 0.8]
assert result["text_bleu_score"] == [0.75, 0.75]
assert result["text_bleu_score_low"] == [False, False]


def test_process_batch_mismatch_columns():
config = MetricPipelineConfig(
metrics=["bleu"],
columns=["text1", "text2"],
reconstructed_columns=["reconstructed_text1"],
metrics=["bleu"], columns=["col1", "col2"], reconstructed_columns=["rec_col1"]
)
pipeline = MetricAnalyzerPipeline(config)

with pytest.raises(ValueError, match="Mismatch in number of columns"):
pipeline.process_batch({"text1": ["Test"], "reconstructed_text1": ["Test"]})
pipeline.process_batch({"col1": ["text"], "rec_col1": ["text"]})


+def test_process_batch_list_input(sample_config):
+    config = MetricPipelineConfig(**sample_config)
+    pipeline = MetricAnalyzerPipeline(config)
+    pipeline.compute_metric = Mock(return_value={"score": 0.8})

+    batch = {
+        "text": [["Hello", "world"], ["This", "is", "a", "test"]],
+        "reconstructed_text": [["Hello", "earth"], ["This", "is", "a", "quiz"]],
+    }

+    result = pipeline.process_batch(batch)

+    assert result["text_references"] == [
+        [["Hello", "world"]],
+        [["This", "is", "a", "test"]],
+    ]
+    assert result["text_predictions"] == [
+        ["Hello", "earth"],
+        ["This", "is", "a", "quiz"],
+    ]


def test_metric_analyzer_pipeline_factory(sample_config):
factory = MetricAnalyzerPipelineFactory()
pipeline = factory.create_pipeline(sample_config)

assert isinstance(pipeline, MetricAnalyzerPipeline)
assert pipeline.config.metrics == sample_config["metrics"]
assert pipeline.config.low_score_threshold == sample_config["low_score_threshold"]
@@ -142,5 +128,20 @@ def test_metric_analyzer_pipeline_factory(sample_config):
assert pipeline.config.output_column_suffix == sample_config["output_column_suffix"]


if __name__ == "__main__":
pytest.main()
+@pytest.mark.parametrize(
+    "score,expected",
+    [
+        (0.7, [False, False]),
+        (0.5, [True, True]),
+        (0.6, [False, False]),
+    ],
+)
+def test_low_score_threshold(sample_config, sample_batch, score, expected):
+    pipeline = MetricAnalyzerPipeline(MetricPipelineConfig(**sample_config))
+    pipeline.compute_metric = Mock(return_value={"score": score})

+    result = pipeline.process_batch(sample_batch)

+    assert result["text_bleu_score_low"] == expected
+    assert result["text_rouge_score_low"] == expected

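More generally, the asserts in the new tests expect process_batch to add, for each input column and metric, the columns {column}_references, {column}_predictions, {column}_{metric}_{suffix}, and {column}_{metric}_{suffix}_low. The helper below only reproduces those names for illustration, assuming the collapsed fixture keeps output_column_suffix="score"; it is not part of huggingface_pipelines.

# Illustrative only: builds the output column names the asserts expect for one
# input column and metric, e.g. "text" + "bleu" with the assumed suffix "score".
def expected_output_columns(column: str, metric: str, suffix: str = "score") -> list[str]:
    return [
        f"{column}_references",
        f"{column}_predictions",
        f"{column}_{metric}_{suffix}",       # e.g. "text_bleu_score"
        f"{column}_{metric}_{suffix}_low",   # e.g. "text_bleu_score_low"
    ]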