Skip to content

Commit

Permalink
Fix linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
botirk38 committed Sep 3, 2024
1 parent a7abdc5 commit 0e0f4bd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 6 additions & 4 deletions huggingface_pipelines/metric_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

import evaluate # type: ignore

from huggingface_pipelines.pipeline import Pipeline, PipelineConfig, PipelineFactory # type: ignore
from huggingface_pipelines.pipeline import ( # type: ignore
Pipeline,
PipelineConfig,
PipelineFactory,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,8 +93,7 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(original_data[0], list):
original_data = [" ".join(item) for item in original_data]
if isinstance(reconstructed_data[0], list):
reconstructed_data = [" ".join(item)
for item in reconstructed_data]
reconstructed_data = [" ".join(item) for item in reconstructed_data]

references = [[ref.split()] for ref in original_data]
predictions = [pred.split() for pred in reconstructed_data]
Expand Down Expand Up @@ -125,4 +128,3 @@ class MetricAnalyzerPipelineFactory(PipelineFactory):
def create_pipeline(self, config: Dict[str, Any]) -> Pipeline:
    """Build a MetricAnalyzerPipeline from a plain config mapping.

    The raw ``config`` dict is validated/structured by unpacking it into
    ``MetricPipelineConfig`` (so unknown or missing keys surface there),
    and the resulting config object is handed to the pipeline constructor.

    :param config: keyword arguments accepted by ``MetricPipelineConfig``
        — exact schema is defined elsewhere in the project; not visible here.
    :return: a ``MetricAnalyzerPipeline`` instance (a ``Pipeline`` subtype).
    """
    # NOTE(review): **config will raise TypeError on unexpected keys —
    # presumably the intended validation point; confirm with callers.
    pipeline_config = MetricPipelineConfig(**config)
    return MetricAnalyzerPipeline(pipeline_config)

4 changes: 1 addition & 3 deletions tests/unit_tests/huggingface_pipelines/metric_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def test_compute_metric(sample_config):
pipeline.metrics["bleu"] = Mock()
pipeline.metrics["bleu"].compute.return_value = {"score": 0.8}

result = pipeline.compute_metric(
"bleu", [["Hello", "world"]], ["Hello", "earth"])
result = pipeline.compute_metric("bleu", [["Hello", "world"]], ["Hello", "earth"])
assert result == {"score": 0.8}
pipeline.metrics["bleu"].compute.assert_called_once_with(
predictions=["Hello", "earth"], references=[["Hello", "world"]]
Expand Down Expand Up @@ -144,4 +143,3 @@ def test_low_score_threshold(sample_config, sample_batch, score, expected):

assert result["text_bleu_score_low"] == expected
assert result["text_rouge_score_low"] == expected

0 comments on commit 0e0f4bd

Please sign in to comment.