From 2ce70d26367a1c0e02ebb0cb4d83b3482a5c7ce9 Mon Sep 17 00:00:00 2001
From: elronbandel
Date: Mon, 18 Nov 2024 18:12:42 +0200
Subject: [PATCH] Fix

Signed-off-by: elronbandel
---
 src/unitxt/inference.py    |  7 ++++---
 src/unitxt/llm_as_judge.py |  6 +++---
 src/unitxt/metrics.py      |  8 ++++----
 src/unitxt/operators.py    |  6 +++---
 src/unitxt/task.py         | 10 +++++-----
 5 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index cf73c2736..23382f8e1 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -15,12 +15,13 @@
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
 
-from .artifact import Artifact, fetch_artifact
+from .artifact import Artifact
 from .dataclass import InternalField, NonPositionalField
 from .deprecation_utils import deprecation
 from .image_operators import data_url_to_image, extract_images
 from .logging_utils import get_logger
 from .operator import PackageRequirementsMixin
+from .operators import ArtifactFetcherMixin
 from .settings_utils import get_constants, get_settings
 
 constants = get_constants()
@@ -343,7 +344,7 @@ class IbmGenAiInferenceEngineParams(Artifact):
     typical_p: Optional[float] = None
 
 
-class GenericInferenceEngine(InferenceEngine):
+class GenericInferenceEngine(InferenceEngine, ArtifactFetcherMixin):
     default: Optional[str] = None
 
     def prepare_engine(self):
@@ -359,7 +360,7 @@ def prepare_engine(self):
                 "\nor passing a similar required engine in the default argument"
             )
         engine_reference = self.default
-        self.engine, _ = fetch_artifact(engine_reference)
+        self.engine = self.get_artifact(engine_reference)
 
     def get_engine_id(self):
         return "generic_inference_engine"
diff --git a/src/unitxt/llm_as_judge.py b/src/unitxt/llm_as_judge.py
index 171417e94..225a284cc 100644
--- a/src/unitxt/llm_as_judge.py
+++ b/src/unitxt/llm_as_judge.py
@@ -2,12 +2,12 @@
 from typing import Any, Dict, List, Literal, Optional
 
 from .api import infer
-from .artifact import fetch_artifact
 from .dataclass import Field
 from .formats import Format, SystemFormat
 from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
 from .metrics import BulkInstanceMetric
 from .operator import SequentialOperator
+from .operators import ArtifactFetcherMixin
 from .settings_utils import get_settings
 from .system_prompts import EmptySystemPrompt, SystemPrompt
 from .templates import Template
@@ -122,7 +122,7 @@ def get_metric_results_from_prediction_outputs(
         pass
 
 
-class LLMAsJudge(LLMAsJudgeBase):
+class LLMAsJudge(LLMAsJudgeBase, ArtifactFetcherMixin):
     """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
 
     This class uses the source prompt given to the generator and the generator's predictions to evaluate
@@ -156,7 +156,7 @@ def _get_input_instances(self, task_data: List[Dict]) -> List:
         instances = []
         for task_data_instance in task_data:
             template = task_data_instance["metadata"]["template"]
-            template, _ = fetch_artifact(template)
+            template = self.get_artifact(template)
             instance = SequentialOperator(
                 steps=[template, "formats.empty"]
             ).process_instance(
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index ab65a0dc6..d127ab5ef 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -17,7 +17,7 @@
 from scipy.stats import bootstrap
 from scipy.stats._warnings_errors import DegenerateDataWarning
 
-from .artifact import Artifact, fetch_artifact
+from .artifact import Artifact
 from .dataclass import (
     AbstractField,
     InternalField,
@@ -37,7 +37,7 @@
     StreamingOperator,
     StreamOperator,
 )
-from .operators import Copy, Set
+from .operators import ArtifactFetcherMixin, Copy, Set
 from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
@@ -4812,7 +4812,7 @@ def _prepare_instances_for_model(self, texts: List[str]):
         return processed_stream.to_dataset()["test"]
 
 
-class MetricsEnsemble(InstanceMetric):
+class MetricsEnsemble(InstanceMetric, ArtifactFetcherMixin):
     """Metrics Ensemble class for creating ensemble of given metrics.
 
     Attributes:
@@ -4836,7 +4836,7 @@ def get_prefix_name(self, i):
 
     def prepare(self):
         super().prepare()
-        self.metrics = [fetch_artifact(metric)[0] for metric in self.metrics]
+        self.metrics = [self.get_artifact(metric) for metric in self.metrics]
         for i, metric in enumerate(self.metrics):
             metric.score_prefix = self.get_prefix_name(i)
         if self.weights is None:
diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py
index 8edf8c32b..7229ceac5 100644
--- a/src/unitxt/operators.py
+++ b/src/unitxt/operators.py
@@ -1039,10 +1039,10 @@ class ArtifactFetcherMixin:
 
     @classmethod
     def get_artifact(cls, artifact_identifier: str) -> Artifact:
-        if artifact_identifier not in cls._artifacts_cache:
+        if str(artifact_identifier) not in cls._artifacts_cache:
             artifact, catalog = fetch_artifact(artifact_identifier)
-            cls._artifacts_cache[artifact_identifier] = artifact
-        return shallow_copy(cls._artifacts_cache[artifact_identifier])
+            cls._artifacts_cache[str(artifact_identifier)] = artifact
+        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])
 
 
 class ApplyOperatorsField(InstanceOperator):
diff --git a/src/unitxt/task.py b/src/unitxt/task.py
index eb3851c37..1092bd7bc 100644
--- a/src/unitxt/task.py
+++ b/src/unitxt/task.py
@@ -2,11 +2,11 @@
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Union
 
-from .artifact import fetch_artifact
 from .deprecation_utils import deprecation
 from .error_utils import Documentation, UnitxtError, UnitxtWarning
 from .logging_utils import get_logger
 from .operator import InstanceOperator
+from .operators import ArtifactFetcherMixin
 from .settings_utils import get_constants
 from .type_utils import (
     Type,
@@ -35,7 +35,7 @@ def parse_string_types_instead_of_actual_objects(obj):
     return parse_type_string(obj)
 
 
-class Task(InstanceOperator):
+class Task(InstanceOperator, ArtifactFetcherMixin):
     """Task packs the different instance fields into dictionaries by their roles in the task.
 
     Attributes:
@@ -184,10 +184,10 @@ def process_data_before_dump(self, data):
             data["prediction_type"] = to_type_string(data["prediction_type"])
         return data
 
-    @staticmethod
+    @classmethod
     @lru_cache(maxsize=None)
-    def get_metric_prediction_type(metric_id: str):
-        metric = fetch_artifact(metric_id)[0]
+    def get_metric_prediction_type(cls, metric_id: str):
+        metric = cls.get_artifact(metric_id)
        return metric.prediction_type
 
     def check_metrics_type(self) -> None:
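
Note: the snippet below is a minimal, self-contained sketch of the caching pattern this patch standardizes on, namely a class-level artifact cache keyed by the stringified identifier, with callers receiving a shallow copy of the cached artifact. It is illustrative only: load_artifact and shallow_copy are simplified stand-ins for unitxt's fetch_artifact and copy helpers, not the library's actual implementations, and MetricsLikeOperator is a hypothetical consumer.

# Minimal sketch of the caching pattern above, using simplified stand-ins.
import copy
from typing import Any, Dict


def load_artifact(identifier: Any) -> Any:
    """Stand-in for unitxt's fetch_artifact(); pretends loading is expensive."""
    return {"loaded_from": str(identifier)}


def shallow_copy(obj: Any) -> Any:
    """Stand-in for unitxt's shallow-copy helper."""
    return copy.copy(obj)


class ArtifactFetcherMixin:
    # One process-wide cache shared by every class that mixes this in.
    _artifacts_cache: Dict[str, Any] = {}

    @classmethod
    def get_artifact(cls, artifact_identifier: Any) -> Any:
        # Key by str() so non-string identifiers can also serve as cache keys.
        key = str(artifact_identifier)
        if key not in cls._artifacts_cache:
            cls._artifacts_cache[key] = load_artifact(artifact_identifier)
        # Hand back a shallow copy so callers can mutate their own instance
        # without touching the cached original.
        return shallow_copy(cls._artifacts_cache[key])


class MetricsLikeOperator(ArtifactFetcherMixin):
    """Hypothetical consumer, standing in for Task, MetricsEnsemble, etc."""


if __name__ == "__main__":
    a = MetricsLikeOperator.get_artifact("metrics.accuracy")
    b = MetricsLikeOperator.get_artifact("metrics.accuracy")
    print(a == b, a is not b)  # True True: same cached content, independent copies

Keying the cache by str(artifact_identifier), as the operators.py hunk does, presumably allows non-string identifiers (for example, an already-instantiated artifact or a dict specification) to be cached without tripping over unhashable keys, while the shallow copy keeps per-instance mutations (such as the score_prefix assigned in MetricsEnsemble.prepare) from leaking back into the shared cache.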