Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <elronbandel@gmail.com>
  • Loading branch information
elronbandel committed Nov 18, 2024
1 parent 8083c7a commit 2ce70d2
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 18 deletions.
7 changes: 4 additions & 3 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from tqdm import tqdm, trange
from tqdm.asyncio import tqdm_asyncio

from .artifact import Artifact, fetch_artifact
from .artifact import Artifact
from .dataclass import InternalField, NonPositionalField
from .deprecation_utils import deprecation
from .image_operators import data_url_to_image, extract_images
from .logging_utils import get_logger
from .operator import PackageRequirementsMixin
from .operators import ArtifactFetcherMixin
from .settings_utils import get_constants, get_settings

constants = get_constants()
Expand Down Expand Up @@ -343,7 +344,7 @@ class IbmGenAiInferenceEngineParams(Artifact):
typical_p: Optional[float] = None


class GenericInferenceEngine(InferenceEngine):
class GenericInferenceEngine(InferenceEngine, ArtifactFetcherMixin):
default: Optional[str] = None

def prepare_engine(self):
Expand All @@ -359,7 +360,7 @@ def prepare_engine(self):
"\nor passing a similar required engine in the default argument"
)
engine_reference = self.default
self.engine, _ = fetch_artifact(engine_reference)
self.engine = self.get_artifact(engine_reference)

def get_engine_id(self):
return "generic_inference_engine"
Expand Down
6 changes: 3 additions & 3 deletions src/unitxt/llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Any, Dict, List, Literal, Optional

from .api import infer
from .artifact import fetch_artifact
from .dataclass import Field
from .formats import Format, SystemFormat
from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
from .metrics import BulkInstanceMetric
from .operator import SequentialOperator
from .operators import ArtifactFetcherMixin
from .settings_utils import get_settings
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .templates import Template
Expand Down Expand Up @@ -122,7 +122,7 @@ def get_metric_results_from_prediction_outputs(
pass


class LLMAsJudge(LLMAsJudgeBase):
class LLMAsJudge(LLMAsJudgeBase, ArtifactFetcherMixin):
"""LLM-as-judge-based metric class for evaluating correctness of generated predictions.
This class uses the source prompt given to the generator and the generator's predictions to evaluate
Expand Down Expand Up @@ -156,7 +156,7 @@ def _get_input_instances(self, task_data: List[Dict]) -> List:
instances = []
for task_data_instance in task_data:
template = task_data_instance["metadata"]["template"]
template, _ = fetch_artifact(template)
template = self.get_artifact(template)
instance = SequentialOperator(
steps=[template, "formats.empty"]
).process_instance(
Expand Down
8 changes: 4 additions & 4 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from scipy.stats import bootstrap
from scipy.stats._warnings_errors import DegenerateDataWarning

from .artifact import Artifact, fetch_artifact
from .artifact import Artifact
from .dataclass import (
AbstractField,
InternalField,
Expand All @@ -37,7 +37,7 @@
StreamingOperator,
StreamOperator,
)
from .operators import Copy, Set
from .operators import ArtifactFetcherMixin, Copy, Set
from .random_utils import get_seed
from .settings_utils import get_settings
from .stream import MultiStream, Stream
Expand Down Expand Up @@ -4812,7 +4812,7 @@ def _prepare_instances_for_model(self, texts: List[str]):
return processed_stream.to_dataset()["test"]


class MetricsEnsemble(InstanceMetric):
class MetricsEnsemble(InstanceMetric, ArtifactFetcherMixin):
"""Metrics Ensemble class for creating ensemble of given metrics.
Attributes:
Expand All @@ -4836,7 +4836,7 @@ def get_prefix_name(self, i):

def prepare(self):
super().prepare()
self.metrics = [fetch_artifact(metric)[0] for metric in self.metrics]
self.metrics = [self.get_artifact(metric) for metric in self.metrics]
for i, metric in enumerate(self.metrics):
metric.score_prefix = self.get_prefix_name(i)
if self.weights is None:
Expand Down
6 changes: 3 additions & 3 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,10 +1039,10 @@ class ArtifactFetcherMixin:

@classmethod
def get_artifact(cls, artifact_identifier: str) -> Artifact:
if artifact_identifier not in cls._artifacts_cache:
if str(artifact_identifier) not in cls._artifacts_cache:
artifact, catalog = fetch_artifact(artifact_identifier)
cls._artifacts_cache[artifact_identifier] = artifact
return shallow_copy(cls._artifacts_cache[artifact_identifier])
cls._artifacts_cache[str(artifact_identifier)] = artifact
return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])


class ApplyOperatorsField(InstanceOperator):
Expand Down
10 changes: 5 additions & 5 deletions src/unitxt/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from .artifact import fetch_artifact
from .deprecation_utils import deprecation
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .logging_utils import get_logger
from .operator import InstanceOperator
from .operators import ArtifactFetcherMixin
from .settings_utils import get_constants
from .type_utils import (
Type,
Expand Down Expand Up @@ -35,7 +35,7 @@ def parse_string_types_instead_of_actual_objects(obj):
return parse_type_string(obj)


class Task(InstanceOperator):
class Task(InstanceOperator, ArtifactFetcherMixin):
"""Task packs the different instance fields into dictionaries by their roles in the task.
Attributes:
Expand Down Expand Up @@ -184,10 +184,10 @@ def process_data_before_dump(self, data):
data["prediction_type"] = to_type_string(data["prediction_type"])
return data

@staticmethod
@classmethod
@lru_cache(maxsize=None)
def get_metric_prediction_type(metric_id: str):
metric = fetch_artifact(metric_id)[0]
def get_metric_prediction_type(cls, metric_id: str):
metric = cls.get_artifact(metric_id)
return metric.prediction_type

def check_metrics_type(self) -> None:
Expand Down

0 comments on commit 2ce70d2

Please sign in to comment.