Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <elronbandel@gmail.com>
  • Loading branch information
elronbandel committed Nov 18, 2024
1 parent f202c3a commit bd8e176
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 21 deletions.
3 changes: 1 addition & 2 deletions examples/evaluate_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@


# Infere using flan t5 base using HF API
model_name = "google/flan-t5-base"
inference_model = HFPipelineBasedInferenceEngine(
model_name=model_name, max_new_tokens=32
model_name="google/flan-t5-base", max_new_tokens=32
)

predictions = inference_model.infer(test_dataset)
Expand Down
9 changes: 4 additions & 5 deletions examples/evaluate_different_demo_selections.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
num_demos=num_demos,
demos_pool_size=50,
loader_limit=200,
max_test_instances=100,
max_test_instances=10,
sampler=demo_sampler,
split="test",
)

test_dataset = dataset["test"]

predictions = inference_model.infer(test_dataset)
evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)
predictions = inference_model.infer(dataset)
evaluated_dataset = evaluate(predictions=predictions, data=dataset)

logger.info(
f"Sample input and output for sampler {demo_sampler} and num_demos '{num_demos}':"
Expand Down
10 changes: 5 additions & 5 deletions examples/evaluate_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
with settings.context(
disable_hf_datasets_cache=False,
):
inference_model = HFLlavaInferenceEngine(
model_name="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=32
)

dataset = load_dataset(
card="cards.doc_vqa.lmms_eval",
template="templates.qa.with_context.title",
format="formats.chat_api",
loader_limit=300,
loader_limit=10,
augmentor="augmentors.image.grey_scale",
split="test",
)

inference_model = HFLlavaInferenceEngine(
model_name="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=32
)

predictions = inference_model.infer(dataset)
evaluated_dataset = evaluate(predictions=predictions, data=dataset)

Expand Down
4 changes: 4 additions & 0 deletions prepare/tasks/qa/multiple_choice/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
},
reference_fields={"answer": Union[int, str], "choices": List[str]},
prediction_type=str,
augmentable_inputs=["context", "question"],
metrics=["metrics.accuracy"],
),
"tasks.qa.multiple_choice.with_context",
Expand All @@ -26,6 +27,7 @@
input_fields={"topic": str, "question": str, "choices": List[str]},
reference_fields={"answer": Union[int, str], "choices": List[str]},
prediction_type=str,
augmentable_inputs=["topic", "question"],
metrics=["metrics.accuracy"],
),
"tasks.qa.multiple_choice.with_topic",
Expand All @@ -37,6 +39,7 @@
input_fields={"question": str, "choices": List[str]},
reference_fields={"answer": Union[int, str], "choices": List[str]},
prediction_type=str,
augmentable_inputs=["question"],
metrics=["metrics.accuracy"],
),
"tasks.qa.multiple_choice.open",
Expand All @@ -54,6 +57,7 @@
},
reference_fields={"answer": Union[int, str], "choices": List[str]},
prediction_type=str,
augmentable_inputs=["context", "question"],
metrics=["metrics.accuracy"],
),
"tasks.qa.multiple_choice.with_context.with_topic",
Expand Down
3 changes: 3 additions & 0 deletions src/unitxt/catalog/tasks/qa/multiple_choice/open.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
"choices": "List[str]"
},
"prediction_type": "str",
"augmentable_inputs": [
"question"
],
"metrics": [
"metrics.accuracy"
]
Expand Down
4 changes: 4 additions & 0 deletions src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
"choices": "List[str]"
},
"prediction_type": "str",
"augmentable_inputs": [
"context",
"question"
],
"metrics": [
"metrics.accuracy"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
"choices": "List[str]"
},
"prediction_type": "str",
"augmentable_inputs": [
"context",
"question"
],
"metrics": [
"metrics.accuracy"
]
Expand Down
4 changes: 4 additions & 0 deletions src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
"choices": "List[str]"
},
"prediction_type": "str",
"augmentable_inputs": [
"topic",
"question"
],
"metrics": [
"metrics.accuracy"
]
Expand Down
23 changes: 14 additions & 9 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from .artifact import fetch_artifact
from .augmentors import (
Augmentor,
NullAugmentor,
TaskInputsAugmentor,
)
from .card import TaskCard
from .collections_operators import GetLength
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
from .error_utils import UnitxtError
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
Expand Down Expand Up @@ -69,9 +68,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
demos_field: str = "demos"
sampler: Sampler = None

augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(
default_factory=NullAugmentor
)
augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)

steps: List[StreamingOperator] = InternalField(default_factory=list)

Expand Down Expand Up @@ -308,11 +305,19 @@ def reset_pipeline(self):

self.processing.steps.append(self.task)

if not isinstance(self.augmentor, list):
self.augmentor = [self.augmentor]
if self.augmentor is not None:
if (
self.card.task.augmentable_inputs is None
or len(self.task.augmentable_inputs) == 0
):
raise UnitxtError(
f"You specified augmentor in the recipe but the got task without augmentable_inputs: {self.task}"
)

if not isinstance(self.augmentor, list):
self.augmentor = [self.augmentor]

for augmentor in self.augmentor:
if isinstance(augmentor, TaskInputsAugmentor):
for augmentor in self.augmentor:
augmentor.set_fields(self.card.task.augmentable_inputs)
self.processing.steps.append(augmentor)

Expand Down

0 comments on commit bd8e176

Please sign in to comment.