diff --git a/examples/evaluate_benchmark.py b/examples/evaluate_benchmark.py
index dca439f16..e92b7e309 100644
--- a/examples/evaluate_benchmark.py
+++ b/examples/evaluate_benchmark.py
@@ -48,9 +48,8 @@
 
 # Infere using flan t5 base using HF API
-model_name = "google/flan-t5-base"
 inference_model = HFPipelineBasedInferenceEngine(
-    model_name=model_name, max_new_tokens=32
+    model_name="google/flan-t5-base", max_new_tokens=32
 )
 
 predictions = inference_model.infer(test_dataset)
 
diff --git a/examples/evaluate_different_demo_selections.py b/examples/evaluate_different_demo_selections.py
index 61d7d6837..d494b89f4 100644
--- a/examples/evaluate_different_demo_selections.py
+++ b/examples/evaluate_different_demo_selections.py
@@ -32,14 +32,13 @@
         num_demos=num_demos,
         demos_pool_size=50,
         loader_limit=200,
-        max_test_instances=100,
+        max_test_instances=10,
         sampler=demo_sampler,
+        split="test",
     )
 
-    test_dataset = dataset["test"]
-
-    predictions = inference_model.infer(test_dataset)
-    evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)
+    predictions = inference_model.infer(dataset)
+    evaluated_dataset = evaluate(predictions=predictions, data=dataset)
 
     logger.info(
         f"Sample input and output for sampler {demo_sampler} and num_demos '{num_demos}':"
diff --git a/examples/evaluate_image_text_to_text.py b/examples/evaluate_image_text_to_text.py
index a4f0dc6c1..1edbe02e6 100644
--- a/examples/evaluate_image_text_to_text.py
+++ b/examples/evaluate_image_text_to_text.py
@@ -6,19 +6,19 @@
 with settings.context(
     disable_hf_datasets_cache=False,
 ):
-    inference_model = HFLlavaInferenceEngine(
-        model_name="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=32
-    )
-
     dataset = load_dataset(
         card="cards.doc_vqa.lmms_eval",
         template="templates.qa.with_context.title",
         format="formats.chat_api",
-        loader_limit=300,
+        loader_limit=10,
         augmentor="augmentors.image.grey_scale",
         split="test",
     )
 
+    inference_model = HFLlavaInferenceEngine(
+        model_name="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=32
+    )
+
     predictions = inference_model.infer(dataset)
 
     evaluated_dataset = evaluate(predictions=predictions, data=dataset)
diff --git a/prepare/tasks/qa/multiple_choice/tasks.py b/prepare/tasks/qa/multiple_choice/tasks.py
index cf8952e29..8fe83cad0 100644
--- a/prepare/tasks/qa/multiple_choice/tasks.py
+++ b/prepare/tasks/qa/multiple_choice/tasks.py
@@ -14,6 +14,7 @@
         },
         reference_fields={"answer": Union[int, str], "choices": List[str]},
         prediction_type=str,
+        augmentable_inputs=["context", "question"],
         metrics=["metrics.accuracy"],
     ),
     "tasks.qa.multiple_choice.with_context",
@@ -26,6 +27,7 @@
         input_fields={"topic": str, "question": str, "choices": List[str]},
         reference_fields={"answer": Union[int, str], "choices": List[str]},
         prediction_type=str,
+        augmentable_inputs=["topic", "question"],
         metrics=["metrics.accuracy"],
     ),
     "tasks.qa.multiple_choice.with_topic",
@@ -37,6 +39,7 @@
         input_fields={"question": str, "choices": List[str]},
         reference_fields={"answer": Union[int, str], "choices": List[str]},
         prediction_type=str,
+        augmentable_inputs=["question"],
         metrics=["metrics.accuracy"],
     ),
     "tasks.qa.multiple_choice.open",
@@ -54,6 +57,7 @@
         },
         reference_fields={"answer": Union[int, str], "choices": List[str]},
         prediction_type=str,
+        augmentable_inputs=["context", "question"],
         metrics=["metrics.accuracy"],
     ),
     "tasks.qa.multiple_choice.with_context.with_topic",
diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/open.json b/src/unitxt/catalog/tasks/qa/multiple_choice/open.json
index 1cd21924d..a6422737f 100644
--- a/src/unitxt/catalog/tasks/qa/multiple_choice/open.json
+++ b/src/unitxt/catalog/tasks/qa/multiple_choice/open.json
@@ -9,6 +9,9 @@
         "choices": "List[str]"
     },
     "prediction_type": "str",
+    "augmentable_inputs": [
+        "question"
+    ],
     "metrics": [
         "metrics.accuracy"
     ]
diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json
index a22346778..be5de61a5 100644
--- a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json
+++ b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json
@@ -11,6 +11,10 @@
         "choices": "List[str]"
     },
     "prediction_type": "str",
+    "augmentable_inputs": [
+        "context",
+        "question"
+    ],
     "metrics": [
         "metrics.accuracy"
     ]
diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json
index 97895cc35..5bb4cbb29 100644
--- a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json
+++ b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json
@@ -12,6 +12,10 @@
         "choices": "List[str]"
     },
     "prediction_type": "str",
+    "augmentable_inputs": [
+        "context",
+        "question"
+    ],
     "metrics": [
         "metrics.accuracy"
     ]
diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json b/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json
index 24e86e13a..da7184663 100644
--- a/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json
+++ b/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json
@@ -10,6 +10,10 @@
         "choices": "List[str]"
     },
     "prediction_type": "str",
+    "augmentable_inputs": [
+        "topic",
+        "question"
+    ],
     "metrics": [
         "metrics.accuracy"
     ]
diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py
index a8dbe46ba..6982b2c7c 100644
--- a/src/unitxt/standard.py
+++ b/src/unitxt/standard.py
@@ -3,12 +3,11 @@
 from .artifact import fetch_artifact
 from .augmentors import (
     Augmentor,
-    NullAugmentor,
-    TaskInputsAugmentor,
 )
 from .card import TaskCard
 from .collections_operators import GetLength
 from .dataclass import Field, InternalField, NonPositionalField, OptionalField
+from .error_utils import UnitxtError
 from .formats import Format, SystemFormat
 from .logging_utils import get_logger
 from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
@@ -69,9 +68,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     demos_field: str = "demos"
     sampler: Sampler = None
 
-    augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(
-        default_factory=NullAugmentor
-    )
+    augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)
 
     steps: List[StreamingOperator] = InternalField(default_factory=list)
 
@@ -308,11 +305,19 @@ def reset_pipeline(self):
         self.processing.steps.append(self.task)
 
-        if not isinstance(self.augmentor, list):
-            self.augmentor = [self.augmentor]
+        if self.augmentor is not None:
+            if (
+                self.card.task.augmentable_inputs is None
+                or len(self.card.task.augmentable_inputs) == 0
+            ):
+                raise UnitxtError(
+                    f"You specified an augmentor in the recipe but got a task without augmentable_inputs: {self.task}"
+                )
+
+            if not isinstance(self.augmentor, list):
+                self.augmentor = [self.augmentor]
 
-        for augmentor in self.augmentor:
-            if isinstance(augmentor, TaskInputsAugmentor):
+            for augmentor in self.augmentor:
                 augmentor.set_fields(self.card.task.augmentable_inputs)
                 self.processing.steps.append(augmentor)
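For reference, a task opts into augmentation by declaring `augmentable_inputs`, as the `prepare/tasks` changes above do. A minimal sketch mirroring `tasks.qa.multiple_choice.open` (the `from unitxt.task import Task` import path is assumed; catalog registration via `add_to_catalog` is omitted):

```python
from typing import List, Union

from unitxt.task import Task

# Only "question" may be rewritten by an augmentor; the choices and
# reference fields stay untouched.
task = Task(
    input_fields={"question": str, "choices": List[str]},
    reference_fields={"answer": Union[int, str], "choices": List[str]},
    prediction_type=str,
    augmentable_inputs=["question"],
    metrics=["metrics.accuracy"],
)
```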
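On the recipe side, the augmentor is passed to `load_dataset` and is now validated against the task's `augmentable_inputs`. A sketch with the arguments copied from the doc_vqa example above:

```python
from unitxt import load_dataset

# cards.doc_vqa.lmms_eval resolves to a task that declares
# augmentable_inputs, so the grey-scale augmentor is applied to those
# fields. Passing an augmentor for a task without augmentable_inputs now
# raises UnitxtError instead of being silently dropped.
dataset = load_dataset(
    card="cards.doc_vqa.lmms_eval",
    template="templates.qa.with_context.title",
    format="formats.chat_api",
    loader_limit=10,
    augmentor="augmentors.image.grey_scale",
    split="test",
)
```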
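The examples also switch from indexing `dataset["test"]` to requesting `split="test"` directly from `load_dataset`. A sketch of the resulting flow (the card name and `template_card_index` are illustrative, not part of this diff):

```python
from unitxt import evaluate, load_dataset
from unitxt.inference import HFPipelineBasedInferenceEngine

inference_model = HFPipelineBasedInferenceEngine(
    model_name="google/flan-t5-base", max_new_tokens=32
)

# split="test" returns the test split directly, so no dataset["test"]
# indexing is needed before inference and evaluation.
dataset = load_dataset(
    card="cards.sst2",  # illustrative card, not part of this diff
    template_card_index=0,
    loader_limit=200,
    max_test_instances=10,
    split="test",
)

predictions = inference_model.infer(dataset)
evaluated_dataset = evaluate(predictions=predictions, data=dataset)
```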