From 0e0f4bd022e41a21c4bfa323dafb9d3c171e14f0 Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Tue, 3 Sep 2024 18:57:40 +0100
Subject: [PATCH] Fix linting issues

---
 huggingface_pipelines/metric_analyzer.py             | 10 ++++++----
 .../huggingface_pipelines/metric_analyzer.py         |  4 +---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/huggingface_pipelines/metric_analyzer.py b/huggingface_pipelines/metric_analyzer.py
index d94cbc7..df1ce06 100644
--- a/huggingface_pipelines/metric_analyzer.py
+++ b/huggingface_pipelines/metric_analyzer.py
@@ -4,7 +4,11 @@
 import evaluate  # type: ignore
 
-from huggingface_pipelines.pipeline import Pipeline, PipelineConfig, PipelineFactory  # type: ignore
+from huggingface_pipelines.pipeline import (  # type: ignore
+    Pipeline,
+    PipelineConfig,
+    PipelineFactory,
+)
 
 logger = logging.getLogger(__name__)
@@ -89,8 +93,7 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         if isinstance(original_data[0], list):
             original_data = [" ".join(item) for item in original_data]
         if isinstance(reconstructed_data[0], list):
-            reconstructed_data = [" ".join(item)
-                                  for item in reconstructed_data]
+            reconstructed_data = [" ".join(item) for item in reconstructed_data]
 
         references = [[ref.split()] for ref in original_data]
         predictions = [pred.split() for pred in reconstructed_data]
@@ -125,4 +128,3 @@ class MetricAnalyzerPipelineFactory(PipelineFactory):
     def create_pipeline(self, config: Dict[str, Any]) -> Pipeline:
         pipeline_config = MetricPipelineConfig(**config)
         return MetricAnalyzerPipeline(pipeline_config)
-
diff --git a/tests/unit_tests/huggingface_pipelines/metric_analyzer.py b/tests/unit_tests/huggingface_pipelines/metric_analyzer.py
index 2c1baea..517dde1 100644
--- a/tests/unit_tests/huggingface_pipelines/metric_analyzer.py
+++ b/tests/unit_tests/huggingface_pipelines/metric_analyzer.py
@@ -57,8 +57,7 @@ def test_compute_metric(sample_config):
     pipeline.metrics["bleu"] = Mock()
     pipeline.metrics["bleu"].compute.return_value = {"score": 0.8}
 
-    result = pipeline.compute_metric(
-        "bleu", [["Hello", "world"]], ["Hello", "earth"])
+    result = pipeline.compute_metric("bleu", [["Hello", "world"]], ["Hello", "earth"])
 
     assert result == {"score": 0.8}
     pipeline.metrics["bleu"].compute.assert_called_once_with(
         predictions=["Hello", "earth"], references=[["Hello", "world"]]
     )
@@ -144,4 +143,3 @@ def test_low_score_threshold(sample_config, sample_batch, score, expected):
 
     assert result["text_bleu_score_low"] == expected
     assert result["text_rouge_score_low"] == expected
-
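
Reviewer note: the process_batch hunk above only reflows the BLEU input preparation; behaviour is unchanged. A minimal, self-contained sketch of that preparation step follows. The helper name prepare_bleu_inputs and the sample sentences are illustrative only, not taken from this repository.

from typing import Any, List, Tuple

def prepare_bleu_inputs(
    original_data: List[Any], reconstructed_data: List[Any]
) -> Tuple[List[List[List[str]]], List[List[str]]]:
    # If the batch holds token lists, join them back into plain strings first,
    # mirroring the " ".join(...) lines in the hunk above.
    if isinstance(original_data[0], list):
        original_data = [" ".join(item) for item in original_data]
    if isinstance(reconstructed_data[0], list):
        reconstructed_data = [" ".join(item) for item in reconstructed_data]

    # BLEU-style metrics take tokenized text: each reference is wrapped in an
    # extra list so that multiple references per prediction are possible.
    references = [[ref.split()] for ref in original_data]
    predictions = [pred.split() for pred in reconstructed_data]
    return references, predictions

refs, preds = prepare_bleu_inputs(["Hello world"], [["Hello", "earth"]])
# refs  == [[["Hello", "world"]]]
# preds == [["Hello", "earth"]]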