From 5257600f1d3f03bdcb9045e0bbcb56e3aaec1b2d Mon Sep 17 00:00:00 2001 From: Agus Date: Wed, 15 Jan 2025 12:28:19 +0100 Subject: [PATCH] Image Language Models and `ImageGeneration` task (#1060) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../image_generation_gallery.md | 10 + docs/api/models/image_generation/index.md | 7 + docs/api/pipeline/typing.md | 3 - docs/api/step/typing.md | 3 - docs/api/task/image_task.md | 7 + docs/api/task/task_gallery.md | 1 + docs/api/task/typing.md | 3 - docs/api/typing.md | 8 + .../how_to_guides/advanced/distiset.md | 27 ++ .../advanced/pipeline_requirements.md | 2 +- .../advanced/structured_generation.md | 4 +- .../basic/step/generator_step.md | 8 +- .../how_to_guides/basic/step/global_step.md | 6 +- .../how_to_guides/basic/step/index.md | 8 +- .../basic/task/generator_task.md | 5 +- .../how_to_guides/basic/task/image_task.md | 104 ++++++++ .../how_to_guides/basic/task/index.md | 5 +- .../examples/image_generation.md | 108 ++++++++ docs/sections/pipeline_samples/index.md | 8 + .../tutorials/clean_existing_dataset.ipynb | 2 +- examples/image_generation.py | 42 +++ mkdocs.yml | 9 +- pyproject.toml | 1 + src/distilabel/distiset.py | 54 +++- src/distilabel/llms.py | 2 +- src/distilabel/mixins/runtime_parameters.py | 87 ++++++ src/distilabel/models/__init__.py | 14 +- .../models/base_clients/__init__.py | 20 ++ .../base_clients/inference_endpoints.py | 154 +++++++++++ src/distilabel/models/base_clients/openai.py | 122 +++++++++ src/distilabel/models/embeddings/base.py | 4 +- .../models/image_generation/__init__.py | 29 ++ .../models/image_generation/base.py | 247 ++++++++++++++++++ .../image_generation/huggingface/__init__.py | 14 + .../huggingface/inference_endpoints.py | 106 ++++++++ .../models/image_generation/openai.py | 129 +++++++++ .../models/image_generation/utils.py | 31 +++ src/distilabel/models/llms/__init__.py | 2 +- src/distilabel/models/llms/anthropic.py | 9 +- src/distilabel/models/llms/base.py | 90 +------ src/distilabel/models/llms/cohere.py | 6 +- src/distilabel/models/llms/groq.py | 6 +- .../llms/huggingface/inference_endpoints.py | 152 ++--------- .../models/llms/huggingface/transformers.py | 9 +- src/distilabel/models/llms/litellm.py | 7 +- src/distilabel/models/llms/llamacpp.py | 9 +- src/distilabel/models/llms/mistral.py | 6 +- src/distilabel/models/llms/mlx.py | 5 +- src/distilabel/models/llms/moa.py | 5 +- src/distilabel/models/llms/ollama.py | 10 +- src/distilabel/models/llms/openai.py | 108 +------- src/distilabel/models/llms/typing.py | 62 ----- src/distilabel/models/llms/utils.py | 4 +- src/distilabel/models/llms/vertexai.py | 5 +- src/distilabel/models/llms/vllm.py | 19 +- src/distilabel/models/mixins/magpie.py | 2 +- src/distilabel/pipeline/base.py | 4 +- src/distilabel/pipeline/local.py | 2 +- src/distilabel/pipeline/ray.py | 2 +- .../pipeline/routing_batch_function.py | 2 +- src/distilabel/pipeline/step_wrapper.py | 2 +- src/distilabel/steps/__init__.py | 2 +- src/distilabel/steps/argilla/base.py | 2 +- src/distilabel/steps/argilla/preference.py | 2 +- .../steps/argilla/text_generation.py | 2 +- src/distilabel/steps/base.py | 6 +- src/distilabel/steps/clustering/dbscan.py | 2 +- .../steps/clustering/text_clustering.py | 2 +- src/distilabel/steps/clustering/umap.py | 2 +- src/distilabel/steps/columns/combine.py | 2 +- 
src/distilabel/steps/columns/expand.py | 2 +- src/distilabel/steps/columns/group.py | 2 +- src/distilabel/steps/columns/keep.py | 2 +- src/distilabel/steps/columns/merge.py | 2 +- src/distilabel/steps/decorator.py | 2 +- .../steps/embeddings/embedding_generation.py | 2 +- .../steps/embeddings/nearest_neighbour.py | 2 +- src/distilabel/steps/filtering/embedding.py | 2 +- src/distilabel/steps/filtering/minhash.py | 2 +- .../steps/formatting/conversation.py | 2 +- src/distilabel/steps/formatting/dpo.py | 2 +- src/distilabel/steps/formatting/sft.py | 2 +- src/distilabel/steps/generators/data.py | 2 +- .../steps/generators/huggingface.py | 2 +- src/distilabel/steps/globals/huggingface.py | 2 +- src/distilabel/steps/reward_model.py | 3 +- src/distilabel/steps/tasks/__init__.py | 10 +- .../steps/tasks/apigen/execution_checker.py | 2 +- .../steps/tasks/apigen/generator.py | 3 +- .../steps/tasks/apigen/semantic_checker.py | 3 +- src/distilabel/steps/tasks/apigen/utils.py | 2 +- .../steps/tasks/argilla_labeller.py | 3 +- src/distilabel/steps/tasks/base.py | 107 +++++++- src/distilabel/steps/tasks/clair.py | 3 +- .../steps/tasks/complexity_scorer.py | 2 +- src/distilabel/steps/tasks/decorator.py | 4 +- .../steps/tasks/evol_instruct/base.py | 5 +- .../steps/tasks/evol_instruct/generator.py | 4 +- .../steps/tasks/evol_quality/base.py | 4 +- .../steps/tasks/generate_embeddings.py | 3 +- src/distilabel/steps/tasks/genstruct.py | 2 +- .../steps/tasks/image_generation.py | 188 +++++++++++++ .../steps/tasks/improving_text_embeddings.py | 3 +- .../tasks/instruction_backtranslation.py | 2 +- src/distilabel/steps/tasks/magpie/base.py | 4 +- .../steps/tasks/magpie/generator.py | 3 +- .../steps/tasks/math_shepherd/completer.py | 4 +- .../steps/tasks/math_shepherd/generator.py | 3 +- .../steps/tasks/math_shepherd/utils.py | 2 +- src/distilabel/steps/tasks/pair_rm.py | 2 +- src/distilabel/steps/tasks/prometheus_eval.py | 2 +- src/distilabel/steps/tasks/quality_scorer.py | 2 +- src/distilabel/steps/tasks/self_instruct.py | 2 +- .../steps/tasks/sentence_transformers.py | 2 +- .../steps/tasks/structured_generation.py | 2 +- .../tasks/structured_outputs/outlines.py | 2 +- .../steps/tasks/text_classification.py | 2 +- src/distilabel/steps/tasks/text_generation.py | 3 +- .../steps/tasks/text_generation_with_image.py | 3 +- src/distilabel/steps/tasks/ultrafeedback.py | 2 +- src/distilabel/steps/tasks/urial.py | 3 +- .../{typing.py => typing/__init__.py} | 44 +++- src/distilabel/typing/base.py | 46 ++++ .../tasks/typing.py => typing/models.py} | 72 +++-- .../typing.py => typing/pipeline.py} | 3 +- .../{steps/typing.py => typing/steps.py} | 0 .../utils/export_components_info.py | 22 ++ .../utils/mkdocs/components_gallery.py | 68 ++++- .../templates/components-gallery/index.md | 8 + .../integration/test_dataset_without_step.py | 2 +- tests/integration/test_embedding_dedup.py | 2 +- tests/integration/test_load_stages.py | 2 +- tests/integration/test_multiple_replicas.py | 2 +- .../test_offline_batch_generation.py | 3 +- tests/integration/test_pipe_llms.py | 2 +- tests/integration/test_pipe_simple.py | 2 +- tests/integration/test_ray_pipeline.py | 2 +- .../test_routing_batch_function.py | 2 +- tests/unit/conftest.py | 27 +- .../unit/models/image_generation/__init__.py | 14 + .../image_generation/huggingface/__init__.py | 14 + .../huggingface/test_inference_endpoints.py | 59 +++++ .../models/image_generation/test_openai.py | 105 ++++++++ .../mixins/test_cuda_device_placement.py | 2 +- tests/unit/pipeline/test_base.py | 2 +- 
tests/unit/pipeline/test_dag.py | 2 +- tests/unit/pipeline/utils.py | 2 +- tests/unit/steps/argilla/test_base.py | 2 +- .../steps/clustering/test_text_clustering.py | 3 +- .../unit/steps/tasks/apigen/test_generator.py | 6 +- .../tasks/math_shepherd/test_completer.py | 2 +- .../tasks/math_shepherd/test_generator.py | 2 +- .../tasks/structured_outputs/test_outlines.py | 2 +- .../unit/steps/tasks/test_argilla_labeller.py | 2 +- .../unit/steps/tasks/test_image_generation.py | 55 ++++ .../tasks/test_improving_text_embeddings.py | 3 +- .../tasks/test_instruction_backtranslation.py | 3 +- .../steps/tasks/test_structured_generation.py | 3 +- .../steps/tasks/test_text_classification.py | 3 +- tests/unit/steps/tasks/test_ultrafeedback.py | 3 +- tests/unit/steps/test_base.py | 2 +- tests/unit/steps/test_decorator.py | 2 +- tests/unit/test_distiset.py | 56 ++++ tests/unit/utils/test_requirements.py | 2 +- 164 files changed, 2398 insertions(+), 632 deletions(-) create mode 100644 docs/api/models/image_generation/image_generation_gallery.md create mode 100644 docs/api/models/image_generation/index.md delete mode 100644 docs/api/pipeline/typing.md delete mode 100644 docs/api/step/typing.md create mode 100644 docs/api/task/image_task.md delete mode 100644 docs/api/task/typing.md create mode 100644 docs/api/typing.md create mode 100644 docs/sections/how_to_guides/basic/task/image_task.md create mode 100644 docs/sections/pipeline_samples/examples/image_generation.md create mode 100644 examples/image_generation.py create mode 100644 src/distilabel/models/base_clients/__init__.py create mode 100644 src/distilabel/models/base_clients/inference_endpoints.py create mode 100644 src/distilabel/models/base_clients/openai.py create mode 100644 src/distilabel/models/image_generation/__init__.py create mode 100644 src/distilabel/models/image_generation/base.py create mode 100644 src/distilabel/models/image_generation/huggingface/__init__.py create mode 100644 src/distilabel/models/image_generation/huggingface/inference_endpoints.py create mode 100644 src/distilabel/models/image_generation/openai.py create mode 100644 src/distilabel/models/image_generation/utils.py delete mode 100644 src/distilabel/models/llms/typing.py create mode 100644 src/distilabel/steps/tasks/image_generation.py rename src/distilabel/{typing.py => typing/__init__.py} (72%) create mode 100644 src/distilabel/typing/base.py rename src/distilabel/{steps/tasks/typing.py => typing/models.py} (66%) rename src/distilabel/{pipeline/typing.py => typing/pipeline.py} (98%) rename src/distilabel/{steps/typing.py => typing/steps.py} (100%) create mode 100644 tests/unit/models/image_generation/__init__.py create mode 100644 tests/unit/models/image_generation/huggingface/__init__.py create mode 100644 tests/unit/models/image_generation/huggingface/test_inference_endpoints.py create mode 100644 tests/unit/models/image_generation/test_openai.py create mode 100644 tests/unit/steps/tasks/test_image_generation.py diff --git a/docs/api/models/image_generation/image_generation_gallery.md b/docs/api/models/image_generation/image_generation_gallery.md new file mode 100644 index 0000000000..2baab4baee --- /dev/null +++ b/docs/api/models/image_generation/image_generation_gallery.md @@ -0,0 +1,10 @@ +# ImageGenerationModel Gallery + +This section contains the existing [`ImageGenerationModel`][distilabel.models.image_generation] subclasses implemented in `distilabel`. 
+
+::: distilabel.models.image_generation
+    options:
+        filters:
+            - "!^ImageGenerationModel$"
+            - "!^AsyncImageGenerationModel$"
+            - "!typing"
\ No newline at end of file
diff --git a/docs/api/models/image_generation/index.md b/docs/api/models/image_generation/index.md
new file mode 100644
index 0000000000..f8d326236b
--- /dev/null
+++ b/docs/api/models/image_generation/index.md
@@ -0,0 +1,7 @@
+# ImageGenerationModel
+
+This section contains the API reference for the `distilabel` image generation models, both for the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] synchronous implementation, and for the [`AsyncImageGenerationModel`][distilabel.models.image_generation.AsyncImageGenerationModel] asynchronous one.
+
+For more information and examples on how to use existing image generation models or create custom ones, please refer to [Tutorial - ImageGenerationModel](../../../sections/how_to_guides/basic/task/image_task.md).
+
+::: distilabel.models.image_generation.base
diff --git a/docs/api/pipeline/typing.md b/docs/api/pipeline/typing.md
deleted file mode 100644
index e4455ece00..0000000000
--- a/docs/api/pipeline/typing.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Pipeline Typing
-
-::: distilabel.pipeline.typing
diff --git a/docs/api/step/typing.md b/docs/api/step/typing.md
deleted file mode 100644
index 1a86e7dac1..0000000000
--- a/docs/api/step/typing.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Step Typing
-
-::: distilabel.steps.typing
\ No newline at end of file
diff --git a/docs/api/task/image_task.md b/docs/api/task/image_task.md
new file mode 100644
index 0000000000..5cb698d548
--- /dev/null
+++ b/docs/api/task/image_task.md
@@ -0,0 +1,7 @@
+# ImageTask
+
+This section contains the API reference for the `distilabel` image generation tasks.
+
+For more information on how the [`ImageTask`][distilabel.steps.tasks.ImageTask] works and to see some examples, check the [Tutorial - Task - ImageTask](../../sections/how_to_guides/basic/task/image_task.md) page.
+
+::: distilabel.steps.tasks.base.ImageTask
diff --git a/docs/api/task/task_gallery.md b/docs/api/task/task_gallery.md
index 4cf90c479d..aa2f3ecf2d 100644
--- a/docs/api/task/task_gallery.md
+++ b/docs/api/task/task_gallery.md
@@ -8,5 +8,6 @@ This section contains the existing [`Task`][distilabel.steps.tasks.Task] subclas
       - "!Task"
       - "!_Task"
       - "!GeneratorTask"
+      - "!ImageTask"
       - "!ChatType"
       - "!typing"
\ No newline at end of file
diff --git a/docs/api/task/typing.md b/docs/api/task/typing.md
deleted file mode 100644
index 818ad070b6..0000000000
--- a/docs/api/task/typing.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Task Typing
-
-::: distilabel.steps.tasks.typing
\ No newline at end of file
diff --git a/docs/api/typing.md b/docs/api/typing.md
new file mode 100644
index 0000000000..53d33868d8
--- /dev/null
+++ b/docs/api/typing.md
@@ -0,0 +1,8 @@
+# Types
+
+This section contains the different types used across the distilabel codebase.
+
+::: distilabel.typing.base
+::: distilabel.typing.steps
+::: distilabel.typing.models
+::: distilabel.typing.pipeline
diff --git a/docs/sections/how_to_guides/advanced/distiset.md b/docs/sections/how_to_guides/advanced/distiset.md
index 1c00554e28..001ec827ed 100644
--- a/docs/sections/how_to_guides/advanced/distiset.md
+++ b/docs/sections/how_to_guides/advanced/distiset.md
@@ -119,6 +119,33 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
 The `Citations` section can include any number of bibtex references.
 To define them, you can add as many elements as needed, just like in the example: each citation will be a block of the form: ` ```@misc{...}``` `. This information will be automatically used in the README of your `Distiset` if you decide to call `distiset.push_to_hub`. Alternatively, if the `Citations` section is not found but any URLs in the `References` section point to `https://arxiv.org/`, we will try to obtain the `Bibtex` equivalent automatically. This way, Hugging Face can automatically track the paper for you, and it's easier to find other datasets citing the same paper or to visit the paper page directly.
 
+#### Image Datasets
+
+!!! info "Keep reading if you are interested in Image datasets"
+
+    The `Distiset` object has a new method `transform_columns_to_image` specifically to transform the images to `PIL.Image.Image` before pushing the dataset to the Hugging Face Hub.
+
+Since version `1.5.0` we have the [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/task/imagegeneration/) task, which is able to generate images from text. By default, the whole process works internally with a string representation of the images, for simplicity. But if the generated dataset is going to be stored on the Hugging Face Hub, a proper `Image` object is preferable, so that, for example, the images can be rendered in the dataset viewer. Let's take a look at the following pipeline, extracted from "examples/image_generation.py" at the root of the repository, to see how we can do it:
+
+```diff
+# Assume all the imports are already done, we are only interested
+# in the highlighted line
+with Pipeline(name="image_generation_pipeline") as pipeline:
+    igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell")
+
+    img_generation = ImageGeneration(
+        name="flux_schnell",
+        image_generation_model=igm,
+    )
+    ...
+
+if __name__ == "__main__":
+    distiset = pipeline.run(use_cache=False, dataset=ds)
+    # Save the images as `PIL.Image.Image`
++    distiset = distiset.transform_columns_to_image("image")
+    distiset.push_to_hub(...)
+```
+
+We call [`transform_columns_to_image`][distilabel.distiset.Distiset.transform_columns_to_image] on the image columns we may have generated (in this case we only want to transform the `image` column, but a list can be passed). The transformation will apply to every leaf node of the pipeline, meaning that if there are different subsets, the `image` column will be transformed in all of them.
+
 ### Save and load from disk
 
 Take into account that these methods work as `datasets.load_from_disk` and `datasets.Dataset.save_to_disk` so the arguments are directly passed to those methods. This means you can also make use of the `storage_options` argument to save your [`Distiset`][distilabel.distiset.Distiset] in your cloud provider, including the distilabel artifacts (`pipeline.yaml`, `pipeline.log` and the `README.md` with the dataset card). You can read more in the `datasets` documentation [here](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets).
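If you just need to inspect a single generated image without transforming the whole dataset, the same conversion can be done row by row with the helper this PR ships. A minimal sketch (assuming a single leaf step named `default` with a `train` split, and images stored as base64 strings):

```python
from distilabel.models.image_generation.utils import image_from_str

# Subset/split names depend on your pipeline; "default"/"train" is the common case
row = distiset["default"]["train"][0]

# Decode the base64 string representation into a `PIL.Image.Image`
pil_image = image_from_str(row["image"])
pil_image.save("sample.jpeg")
```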
diff --git a/docs/sections/how_to_guides/advanced/pipeline_requirements.md b/docs/sections/how_to_guides/advanced/pipeline_requirements.md index 66a2594bd2..3f739cbf2a 100644 --- a/docs/sections/how_to_guides/advanced/pipeline_requirements.md +++ b/docs/sections/how_to_guides/advanced/pipeline_requirements.md @@ -9,7 +9,7 @@ from typing import List from distilabel.steps import Step from distilabel.steps.base import StepInput -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput from distilabel.steps import LoadDataFromDicts from distilabel.utils.requirements import requirements from distilabel.pipeline import Pipeline diff --git a/docs/sections/how_to_guides/advanced/structured_generation.md b/docs/sections/how_to_guides/advanced/structured_generation.md index 3eb1da99af..1675d369a6 100644 --- a/docs/sections/how_to_guides/advanced/structured_generation.md +++ b/docs/sections/how_to_guides/advanced/structured_generation.md @@ -21,7 +21,7 @@ The [`LLM`][distilabel.models.llms.LLM] has an argument named `structured_output We will start with a JSON example, where we initially define a `pydantic.BaseModel` schema to guide the generation of the structured output. !!! NOTE - Take a look at [`StructuredOutputType`][distilabel.steps.tasks.typing.StructuredOutputType] to see the expected format + Take a look at [`StructuredOutputType`][distilabel.typing.models.StructuredOutputType] to see the expected format of the `structured_output` dict variable. ```python @@ -139,7 +139,7 @@ For other LLM providers behind APIs, there's no direct way of accessing the inte ``` !!! Note - Take a look at [`InstructorStructuredOutputType`][distilabel.steps.tasks.typing.InstructorStructuredOutputType] to see the expected format + Take a look at [`InstructorStructuredOutputType`][distilabel.typing.models.InstructorStructuredOutputType] to see the expected format of the `structured_output` dict variable. The following is the same example you can see with `outlines`'s `JSON` section for comparison purposes. diff --git a/docs/sections/how_to_guides/basic/step/generator_step.md b/docs/sections/how_to_guides/basic/step/generator_step.md index 0422644c36..50ca5e52d7 100644 --- a/docs/sections/how_to_guides/basic/step/generator_step.md +++ b/docs/sections/how_to_guides/basic/step/generator_step.md @@ -9,7 +9,7 @@ from typing_extensions import override from distilabel.steps import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, GeneratorStepOutput + from distilabel.typing import StepColumns, GeneratorStepOutput class MyGeneratorStep(GeneratorStep): instructions: List[str] @@ -67,7 +67,7 @@ We can define a custom generator step by creating a new subclass of the [`Genera The default signature for the `process` method is `process(self, offset: int = 0) -> GeneratorStepOutput`. The argument `offset` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. !!! 
WARNING - For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. + For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. === "Inherit from `GeneratorStep`" @@ -81,7 +81,7 @@ We can define a custom generator step by creating a new subclass of the [`Genera from distilabel.steps import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, GeneratorStepOutput + from distilabel.typing import StepColumns, GeneratorStepOutput class MyGeneratorStep(GeneratorStep): instructions: List[str] @@ -104,7 +104,7 @@ We can define a custom generator step by creating a new subclass of the [`Genera from distilabel.steps import step if TYPE_CHECKING: - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import GeneratorStepOutput @step(outputs=[...], step_type="generator") def CustomGeneratorStep(offset: int = 0) -> "GeneratorStepOutput": diff --git a/docs/sections/how_to_guides/basic/step/global_step.md b/docs/sections/how_to_guides/basic/step/global_step.md index 814f01a0fb..db050a6dc7 100644 --- a/docs/sections/how_to_guides/basic/step/global_step.md +++ b/docs/sections/how_to_guides/basic/step/global_step.md @@ -16,7 +16,7 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`. The argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. !!! WARNING - For the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. 
+ For the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. === "Inherit from `GlobalStep`" @@ -27,7 +27,7 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis from distilabel.steps import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class CustomStep(Step): @property @@ -61,7 +61,7 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis from distilabel.steps import StepInput, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @step(inputs=[...], outputs=[...], step_type="global") def CustomStep(inputs: StepInput) -> "StepOutput": diff --git a/docs/sections/how_to_guides/basic/step/index.md b/docs/sections/how_to_guides/basic/step/index.md index d03a6b2149..76cae37075 100644 --- a/docs/sections/how_to_guides/basic/step/index.md +++ b/docs/sections/how_to_guides/basic/step/index.md @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING from distilabel.steps import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class MyStep(Step): @property @@ -87,7 +87,7 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`. The argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. !!! WARNING - For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. + For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. 
=== "Inherit from `Step`" @@ -98,7 +98,7 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe from distilabel.steps import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class CustomStep(Step): @property @@ -132,7 +132,7 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe from distilabel.steps import StepInput, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @step(inputs=[...], outputs=[...]) def CustomStep(inputs: StepInput) -> "StepOutput": diff --git a/docs/sections/how_to_guides/basic/task/generator_task.md b/docs/sections/how_to_guides/basic/task/generator_task.md index 6fbb3d742e..bb86c28e31 100644 --- a/docs/sections/how_to_guides/basic/task/generator_task.md +++ b/docs/sections/how_to_guides/basic/task/generator_task.md @@ -12,8 +12,7 @@ from typing import Any, Dict, List, Union from typing_extensions import override from distilabel.steps.tasks.base import GeneratorTask -from distilabel.steps.tasks.typing import ChatType -from distilabel.steps.typing import GeneratorOutput +from distilabel.typing import ChatType, GeneratorOutput class MyCustomTask(GeneratorTask): @@ -78,7 +77,7 @@ We can define a custom generator task by creating a new subclass of the [`Genera from typing import Any, Dict, List, Union from distilabel.steps.tasks.base import GeneratorTask -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType class MyCustomTask(GeneratorTask): diff --git a/docs/sections/how_to_guides/basic/task/image_task.md b/docs/sections/how_to_guides/basic/task/image_task.md new file mode 100644 index 0000000000..ecdee6c66f --- /dev/null +++ b/docs/sections/how_to_guides/basic/task/image_task.md @@ -0,0 +1,104 @@ +# ImageTask to work with Image Generation Models + +## Working with ImageTasks + +The [`ImageTask`][distilabel.steps.tasks.ImageTask] is a custom implementation of a [`Task`][distilabel.steps.tasks.Task] special to deal images. These tasks behave exactly as any other [`Task`][distilabel.steps.tasks.Task], but instead of relying on an [`LLM`][distilabel.models.llms.LLM], they work with a [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel]. + +!!! info "New in version 1.5.0" + This task is new and is expected to work with Image Generation Models. + +These tasks take as attribute an `image_generation_model` instead of `llm` as we would have with the standard `Task`, but everything else remains the same. Let's see an example with [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/imagegeneration/): + +```python +from distilabel.steps.tasks import ImageGeneration +from distilabel.models.image_generation import InferenceEndpointsImageGeneration + +task = ImageGeneration( + name="image-generation", + image_generation_model=InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell"), +) +task.load() + +next(task.process([{"prompt": "a white siamese cat"}])) +# [{'image": "iVBORw0KGgoAAAANSUhEUgA...", "model_name": "black-forest-labs/FLUX.1-schnell"}] +``` + +!!! 
info "Visualize the image in a notebook" + If you are testing the `ImageGeneration` task in a notebook, you can do the following + to see the rendered image: + + ```python + from distilabel.models.image_generation.utils import image_from_str + + result = next(task.process([{"prompt": "a white siamese cat"}])) + image_from_str(result[0]["image"]) # Returns a `PIL.Image.Image` that renders directly + ``` + +!!! tip "Running ImageGeneration in a Pipeline" + This transformation between image as string and as PIL object can be done for the whole dataset if running a pipeline, by calling the method `transform_columns_to_image` on the final distiset and passing the name (or list of names) of the column image. + +## Defining custom ImageTasks + +We can define a custom generator task by creating a new subclass of the [`ImageTask`][distilabel.steps.tasks.ImageTask] and defining the following: + +- `process`: is a method that generates the data based on the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] and the `prompt` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. + +- `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. + +- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the LLM. + +- `format_input`: is a method that receives a dictionary with the input data and returns a *prompt* to be passed to the model. + +- `format_output`: is a method that receives the output from the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. 
+
+```python
+from typing import TYPE_CHECKING, Any, Dict, Union
+
+from typing_extensions import override
+
+from distilabel.models.image_generation.utils import image_from_str, image_to_str
+from distilabel.steps.base import StepInput
+from distilabel.steps.tasks.base import ImageTask
+
+if TYPE_CHECKING:
+    from distilabel.typing import StepColumns, StepOutput
+
+
+class MyCustomImageTask(ImageTask):
+    @override
+    def process(self, inputs: StepInput) -> "StepOutput":
+        formatted_inputs = self._format_inputs(inputs)
+
+        outputs = self.llm.generate_outputs(
+            inputs=formatted_inputs,
+            num_generations=self.num_generations,
+            **self.llm.get_generation_kwargs(),
+        )
+
+        task_outputs = []
+        for input, input_outputs in zip(inputs, outputs):
+            formatted_outputs = self._format_outputs(input_outputs, input)
+            for formatted_output in formatted_outputs:
+                task_outputs.append(
+                    {**input, **formatted_output, "model_name": self.llm.model_name}
+                )
+        yield task_outputs
+
+    @property
+    def inputs(self) -> "StepColumns":
+        return ["prompt"]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        return ["image", "model_name"]
+
+    def format_input(self, input: Dict[str, Any]) -> str:
+        return input["prompt"]
+
+    def format_output(
+        self, output: Union[str, None], input: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # Extract/generate/modify the image from the output, e.g. using
+        # `image_from_str` / `image_to_str` to convert between representations
+        return {"image": ..., "model_name": self.llm.model_name}
+```
+
+!!! Warning
+    Note that in the `process` method we are not dealing with the `image_generation_model` attribute but with `llm`. This is not a bug but intended, as internally the `image_generation_model` attribute is renamed to `llm` to reuse the code.
diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md
index dd5de6f837..c2291c8769 100644
--- a/docs/sections/how_to_guides/basic/task/index.md
+++ b/docs/sections/how_to_guides/basic/task/index.md
@@ -217,7 +217,7 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe
 - `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.
 
-- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.steps.tasks.ChatType] following [the chat-completion OpenAI message formatting](https://platform.openai.com/docs/guides/text-generation).
+- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.typing.models.ChatType] following [the chat-completion OpenAI message formatting](https://platform.openai.com/docs/guides/text-generation).
 
 - `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the LLM.
@@ -233,8 +233,7 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe
 from distilabel.steps.tasks import Task
 
 if TYPE_CHECKING:
-    from distilabel.steps.typing import StepColumns
-    from distilabel.steps.tasks.typing import ChatType
+    from distilabel.typing import StepColumns, ChatType
 
 class MyCustomTask(Task):
diff --git a/docs/sections/pipeline_samples/examples/image_generation.md b/docs/sections/pipeline_samples/examples/image_generation.md
new file mode 100644
index 0000000000..39f8daba3b
--- /dev/null
+++ b/docs/sections/pipeline_samples/examples/image_generation.md
@@ -0,0 +1,108 @@
+---
+hide: toc
+---
+
+# Image generation with `distilabel`
+
+Create synthetic images using `distilabel`.
+
+This example shows how distilabel can be used to generate image data, either using [`InferenceEndpointsImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/image_generation/inferenceendpointsimagegeneration/) or [`OpenAIImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/image_generation/openaiimagegeneration/), thanks to the [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/task/imagegeneration/) task.
+
+
+=== "Inference Endpoints - black-forest-labs/FLUX.1-schnell"
+
+    ```python
+    from distilabel.pipeline import Pipeline
+    from distilabel.steps import KeepColumns
+    from distilabel.models.image_generation import InferenceEndpointsImageGeneration
+    from distilabel.steps.tasks import ImageGeneration
+
+    from datasets import load_dataset
+
+    ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))
+
+    with Pipeline(name="image_generation_pipeline") as pipeline:
+        ilm = InferenceEndpointsImageGeneration(
+            model_id="black-forest-labs/FLUX.1-schnell"
+        )
+
+        img_generation = ImageGeneration(
+            name="flux_schnell",
+            image_generation_model=ilm,
+            input_mappings={"prompt": "persona"}
+        )
+
+        keep_columns = KeepColumns(columns=["persona", "model_name", "image"])
+
+        img_generation >> keep_columns
+    ```
+
+    Sample image for the prompt:
+
+    > A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.
+
+    ![image_ie](https://huggingface.co/datasets/plaguss/test-finepersonas-v0.1-tiny-flux-schnell/resolve/main/artifacts/flux_schnell/images/3333f9870feda32a449994017eb72675.jpeg)
+
+=== "OpenAI - dall-e-3"
+
+    ```python
+    from distilabel.pipeline import Pipeline
+    from distilabel.steps import KeepColumns
+    from distilabel.models.image_generation import OpenAIImageGeneration
+    from distilabel.steps.tasks import ImageGeneration
+
+    from datasets import load_dataset
+
+    ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))
+
+    with Pipeline(name="image_generation_pipeline") as pipeline:
+        ilm = OpenAIImageGeneration(
+            model="dall-e-3",
+            generation_kwargs={
+                "size": "1024x1024",
+                "quality": "standard",
+                "style": "natural"
+            }
+        )
+
+        img_generation = ImageGeneration(
+            name="dalle-3",
+            image_generation_model=ilm,
+            input_mappings={"prompt": "persona"}
+        )
+
+        keep_columns = KeepColumns(columns=["persona", "model_name", "image"])
+
+        img_generation >> keep_columns
+    ```
+
+    Sample image for the prompt:
+
+    > A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.
+
+    ![image_oai](https://huggingface.co/datasets/plaguss/test-finepersonas-v0.1-tiny-dall-e-3/resolve/main/artifacts/dalle-3/images/3333f9870feda32a449994017eb72675.jpeg)
+
+!!! success "Save the Distiset as an Image Dataset"
success "Save the Distiset as an Image Dataset" + + Note the call to `Distiset.transform_columns_to_image`, to have the images uploaded directly as an [`Image dataset`](https://huggingface.co/docs/hub/en/datasets-image): + + ```python + if __name__ == "__main__": + distiset = pipeline.run(use_cache=False, dataset=ds) + # Save the images as `PIL.Image.Image` + distiset = distiset.transform_columns_to_image("image") + distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell") + + ``` + +The full pipeline can be run at the following example. Keep in mind, you need to install `pillow` first: `pip install distilabel[vision]`. + +??? Run + + ```python + python examples/image_generation.py + ``` + +```python title="image_generation.py" +--8<-- "examples/image_generation.py" +``` diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md index 1c95b60b18..956af3b518 100644 --- a/docs/sections/pipeline_samples/index.md +++ b/docs/sections/pipeline_samples/index.md @@ -161,6 +161,14 @@ hide: toc [:octicons-arrow-right-24: Example](examples/exam_questions.md) +- __Image generation with distilabel__ + + --- + + Generate synthetic images using distilabel. + + [:octicons-arrow-right-24: Example](examples/image_generation.md) + - __Text generation with images in distilabel__ --- diff --git a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb index 7b75f7fcaa..6730a80892 100644 --- a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb +++ b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb @@ -196,7 +196,7 @@ " from distilabel.steps import GlobalStep, StepInput\n", "\n", " if TYPE_CHECKING:\n", - " from distilabel.steps.typing import StepOutput\n", + " from distilabel.typing import StepOutput\n", " \n", " import random\n", "\n", diff --git a/examples/image_generation.py b/examples/image_generation.py new file mode 100644 index 0000000000..dbee42ebe4 --- /dev/null +++ b/examples/image_generation.py @@ -0,0 +1,42 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datasets import load_dataset + +from distilabel.models.image_generation import InferenceEndpointsImageGeneration +from distilabel.pipeline import Pipeline +from distilabel.steps import KeepColumns +from distilabel.steps.tasks import ImageGeneration + +ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3)) + +with Pipeline(name="image_generation_pipeline") as pipeline: + igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell") + + img_generation = ImageGeneration( + name="flux_schnell", + image_generation_model=igm, + input_mappings={"prompt": "persona"}, + ) + + keep_columns = KeepColumns(columns=["persona", "model_name", "image"]) + + img_generation >> keep_columns + + +if __name__ == "__main__": + distiset = pipeline.run(use_cache=False, dataset=ds) + # Save the images as `PIL.Image.Image` + distiset = distiset.transform_columns_to_image("image") + distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell") diff --git a/mkdocs.yml b/mkdocs.yml index f5a98be65d..24e5ca9b74 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -184,6 +184,7 @@ nav: - Tasks for generating and judging with LLMs: - "sections/how_to_guides/basic/task/index.md" - GeneratorTask: "sections/how_to_guides/basic/task/generator_task.md" + - ImageTask: "sections/how_to_guides/basic/task/image_task.md" - Executing Tasks with LLMs: "sections/how_to_guides/basic/llm/index.md" - Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md" - Advanced: @@ -222,6 +223,7 @@ nav: - Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md" - Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md" - Create questions and answers for a exam: "sections/pipeline_samples/examples/exam_questions.md" + - Image generation with distilabel: "sections/pipeline_samples/examples/image_generation.md" - Text generation with images in distilabel: "sections/pipeline_samples/examples/text_generation_with_image.md" - API Reference: - Step: @@ -235,22 +237,22 @@ nav: - Hugging Face: "api/step_gallery/hugging_face.md" - Columns: "api/step_gallery/columns.md" - Extra: "api/step_gallery/extra.md" - - Typing: "api/step/typing.md" - Task: - "api/task/index.md" - GeneratorTask: "api/task/generator_task.md" - Task Gallery: "api/task/task_gallery.md" - - Typing: "api/task/typing.md" - LLM: - "api/models/llm/index.md" - LLM Gallery: "api/models/llm/llm_gallery.md" - Embedding: - "api/models/embedding/index.md" - Embedding Gallery: "api/models/embedding/embedding_gallery.md" + - ImageGenerationModels: + - "api/models/image_generation/index.md" + - Image Generation Gallery: "api/models/image_generation/image_generation_gallery.md" - Pipeline: - "api/pipeline/index.md" - Routing Batch Function: "api/pipeline/routing_batch_function.md" - - Typing: "api/pipeline/typing.md" - Step Wrapper: "api/pipeline/step_wrapper.md" - Mixins: - RuntimeParametersMixin: "api/mixins/runtime_parameters.md" @@ -259,6 +261,7 @@ nav: - Errors: "api/errors.md" - Distiset: "api/distiset.md" - CLI: "api/cli.md" + - Types: "api/typing.md" - Community: - sections/community/index.md - How to contribute?: sections/community/contributor.md diff --git a/pyproject.toml b/pyproject.toml index 3123d56b55..1c55ebb1c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ text-clustering = [ "matplotlib >= 3.8.3", # For the figure (even though it's optional) ] mlx = ["mlx >= 0.21.0", 
"mlx-lm"] +vision = ["Pillow >= 10.3.0"] # To work with images. # minhash minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"] diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index e934f9d340..ce4e855858 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -24,7 +24,7 @@ import fsspec import yaml -from datasets import Dataset, load_dataset, load_from_disk +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets.filesystems import is_remote_filesystem from huggingface_hub import DatasetCardData, HfApi, upload_file, upload_folder from huggingface_hub.file_download import hf_hub_download @@ -187,9 +187,14 @@ def _get_card( record = ( dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + from PIL import ImageFile + for key, value in record.items(): + # If the value is an image, we set it to an empty string to avoid the `README.md` to huge + if isinstance(value, ImageFile.ImageFile): + value = "" # If list is too big, the `README.md` generated will be huge so we truncate it - if isinstance(value, list): + elif isinstance(value, list): length = len(value) if length < 10: continue @@ -585,6 +590,51 @@ def __repr__(self): repr = re.sub(r"^", " " * 4, repr, count=0, flags=re.M) return f"Distiset({{\n{repr}\n}})" + def transform_columns_to_image(self, columns: Union[str, list[str]]) -> Self: + """Transforms the columns of the dataset to `PIL.Image` objects. + + Args: + columns: Column or list of columns to transform. + + Returns: + Transforms the columns of the dataset to `PIL.Image` objects before pushing, + so the Hub treats them as Image objects and can be rendered in the dataset + viewer, and cast them to be automatically transformed when downloading + the dataset back. 
+ """ + from datasets import Image + + from distilabel.models.image_generation.utils import image_from_str + + columns = [columns] if isinstance(columns, str) else columns + + def cast_to_image(row: dict) -> dict: + for column in columns: + row[column] = image_from_str(row[column]) + return row + + for name, dataset in self.items(): + # In case train_test_split was called + if isinstance(dataset, DatasetDict): + for split, dataset_split in dataset.items(): + dataset_split = dataset_split.map(cast_to_image) + for column in columns: + if column in dataset_split.column_names: + dataset_split = dataset_split.cast_column( + column, Image(decode=True) + ) + self[name][split] = dataset_split + else: + dataset = dataset.map(cast_to_image) + + for column in columns: + if column in dataset.column_names: + dataset = dataset.cast_column(column, Image(decode=True)) + + self[name] = dataset + + return self + def create_distiset( # noqa: C901 data_dir: Path, diff --git a/src/distilabel/llms.py b/src/distilabel/llms.py index b00d891407..8d579048df 100644 --- a/src/distilabel/llms.py +++ b/src/distilabel/llms.py @@ -37,10 +37,10 @@ from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.together import TogetherLLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.typing import GenerateOutput, HiddenState __all__ = [ "LLM", diff --git a/src/distilabel/mixins/runtime_parameters.py b/src/distilabel/mixins/runtime_parameters.py index f8371e30ab..73b5845e8f 100644 --- a/src/distilabel/mixins/runtime_parameters.py +++ b/src/distilabel/mixins/runtime_parameters.py @@ -13,11 +13,14 @@ # limitations under the License. import difflib +import inspect +from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Tuple, TypeVar, Union from pydantic import BaseModel, Field, PrivateAttr from typing_extensions import Annotated, get_args, get_origin +from distilabel.utils.docstring import parse_google_docstring from distilabel.utils.typing_ import ( extract_annotation_inner_type, is_type_pydantic_secret_field, @@ -26,6 +29,9 @@ if TYPE_CHECKING: from pydantic.fields import FieldInfo + from distilabel.utils.docstring import Docstring + + _T = TypeVar("_T") _RUNTIME_PARAMETER_ANNOTATION = "distilabel_step_runtime_parameter" RuntimeParameter = Annotated[ @@ -218,3 +224,84 @@ def _is_runtime_parameter(field: "FieldInfo") -> Tuple[bool, bool]: return True, is_optional return False, False + + +class RuntimeParametersModelMixin(RuntimeParametersMixin): + """Specific mixin for RuntimeParameters that affect the model classes, LLM, + ImageGenerationModel, etc. + """ + + @property + def generate_parameters(self) -> list["inspect.Parameter"]: + """Returns the parameters of the `generate` method. + + Returns: + A list containing the parameters of the `generate` method. + """ + return list(inspect.signature(self.generate).parameters.values()) + + @property + def runtime_parameters_names(self) -> "RuntimeParametersNames": + """Returns the runtime parameters of the `ImageGenerationModel`, which are combination of the + attributes of the `ImageGenerationModel` type hinted with `RuntimeParameter` and the parameters + of the `generate` method that are not `input` and `num_generations`. 
+
+        Returns:
+            A dictionary with the name of the runtime parameters as keys and a boolean
+            indicating if the parameter is optional or not.
+        """
+        runtime_parameters = super().runtime_parameters_names
+        runtime_parameters["generation_kwargs"] = {}
+
+        # runtime parameters from the `generate` method
+        for param in self.generate_parameters:
+            if param.name in ["input", "inputs", "num_generations"]:
+                continue
+            is_optional = param.default != inspect.Parameter.empty
+            runtime_parameters["generation_kwargs"][param.name] = is_optional
+
+        return runtime_parameters
+
+    def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]:
+        """Gets the information of the runtime parameters of the model, such as the name
+        and the description. This function is meant to include the information of the
+        runtime parameters in the serialized data of the model.
+
+        Returns:
+            A list containing the information for each runtime parameter of the model.
+        """
+        runtime_parameters_info = super().get_runtime_parameters_info()
+
+        generation_kwargs_info = next(
+            (
+                runtime_parameter_info
+                for runtime_parameter_info in runtime_parameters_info
+                if runtime_parameter_info["name"] == "generation_kwargs"
+            ),
+            None,
+        )
+
+        # If the `generation_kwargs` attribute is present, we need to include the `generate`
+        # method arguments as the information for this attribute.
+        if generation_kwargs_info:
+            generate_docstring_args = self.generate_parsed_docstring["args"]
+            generation_kwargs_info["keys"] = []
+
+            for key, value in generation_kwargs_info["optional"].items():
+                info = {"name": key, "optional": value}
+                if description := generate_docstring_args.get(key):
+                    info["description"] = description
+                generation_kwargs_info["keys"].append(info)
+
+            generation_kwargs_info.pop("optional")
+
+        return runtime_parameters_info
+
+    @cached_property
+    def generate_parsed_docstring(self) -> "Docstring":
+        """Returns the parsed docstring of the `generate` method.
+
+        Returns:
+            The parsed docstring of the `generate` method.
+ """ + return parse_google_docstring(self.generate) diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py index 86ea2023e4..1c96f5ab0b 100644 --- a/src/distilabel/models/__init__.py +++ b/src/distilabel/models/__init__.py @@ -19,6 +19,14 @@ SentenceTransformerEmbeddings, ) from distilabel.models.embeddings.vllm import vLLMEmbeddings +from distilabel.models.image_generation.base import ( + AsyncImageGenerationModel, + ImageGenerationModel, +) +from distilabel.models.image_generation.huggingface.inference_endpoints import ( + InferenceEndpointsImageGeneration, +) +from distilabel.models.image_generation.openai import OpenAIImageGeneration from distilabel.models.llms.anthropic import AnthropicLLM from distilabel.models.llms.anyscale import AnyscaleLLM from distilabel.models.llms.azure import AzureOpenAILLM @@ -34,15 +42,16 @@ from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.together import TogetherLLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.typing import GenerateOutput, HiddenState __all__ = [ "LLM", "AnthropicLLM", "AnyscaleLLM", + "AsyncImageGenerationModel", "AsyncLLM", "AzureOpenAILLM", "ClientvLLM", @@ -52,6 +61,8 @@ "GenerateOutput", "GroqLLM", "HiddenState", + "ImageGenerationModel", + "InferenceEndpointsImageGeneration", "InferenceEndpointsLLM", "LiteLLM", "LlamaCppEmbeddings", @@ -60,6 +71,7 @@ "MixtureOfAgentsLLM", "MlxLLM", "OllamaLLM", + "OpenAIImageGeneration", "OpenAILLM", "SentenceTransformerEmbeddings", "TogetherLLM", diff --git a/src/distilabel/models/base_clients/__init__.py b/src/distilabel/models/base_clients/__init__.py new file mode 100644 index 0000000000..07e329b4d9 --- /dev/null +++ b/src/distilabel/models/base_clients/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.models.base_clients.inference_endpoints import ( + InferenceEndpointsBaseClient, +) +from distilabel.models.base_clients.openai import OpenAIBaseClient + +__all__ = ["InferenceEndpointsBaseClient", "OpenAIBaseClient"] diff --git a/src/distilabel/models/base_clients/inference_endpoints.py b/src/distilabel/models/base_clients/inference_endpoints.py new file mode 100644 index 0000000000..ebcc84e344 --- /dev/null +++ b/src/distilabel/models/base_clients/inference_endpoints.py @@ -0,0 +1,154 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import ( + TYPE_CHECKING, + Optional, + Union, +) + +from pydantic import ( + BaseModel, + Field, + PrivateAttr, + SecretStr, +) + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.typing import StructuredOutputType +from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR, get_hf_token + +if TYPE_CHECKING: + from huggingface_hub import AsyncInferenceClient + from transformers import PreTrainedTokenizer + + +class InferenceEndpointsBaseClient(BaseModel): + model_id: Optional[str] = None + + endpoint_name: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The name of the Inference Endpoint to use for the LLM.", + ) + endpoint_namespace: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The namespace of the Inference Endpoint to use for the LLM.", + ) + base_url: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The base URL to use for the Inference Endpoints API requests.", + ) + api_key: Optional[RuntimeParameter[SecretStr]] = Field( + default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR), + description="The API key to authenticate the requests to the Inference Endpoints API.", + ) + + tokenizer_id: Optional[str] = None + model_display_name: Optional[str] = None + + structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + + _num_generations_param_supported = False + + _model_name: Optional[str] = PrivateAttr(default=None) + _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) + _api_key_env_var: str = PrivateAttr(HF_TOKEN_ENV_VAR) + _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) + + def load(self) -> None: # noqa: C901 + """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference + Endpoint. + + Raises: + ImportError: if the `huggingface-hub` Python client is not installed. + ValueError: if the model is not currently deployed or is not running the TGI framework. + ImportError: if the `transformers` Python client is not installed. + """ + + try: + from huggingface_hub import ( + AsyncInferenceClient, + InferenceClient, + get_inference_endpoint, + ) + except ImportError as ie: + raise ImportError( + "Hugging Face Hub Python client is not installed. Please install it using" + " `pip install 'distilabel[hf-inference-endpoints]'`." 
+ ) from ie + + if self.api_key is None: + self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key")) + + if self.model_id is not None: + client = InferenceClient( + model=self.model_id, token=self.api_key.get_secret_value() + ) + status = client.get_model_status() + + if ( + status.state not in {"Loadable", "Loaded"} + and status.framework != "text-generation-inference" + ): + raise ValueError( + f"Model {self.model_id} is not currently deployed or is not running the TGI framework" + ) + + self.base_url = client._resolve_url( + model=self.model_id, task="text-generation" + ) + + if self.endpoint_name is not None: + client = get_inference_endpoint( + name=self.endpoint_name, + namespace=self.endpoint_namespace, + token=self.api_key.get_secret_value(), + ) + if client.status in ["paused", "scaledToZero"]: + client.resume().wait(timeout=300) + elif client.status == "initializing": + client.wait(timeout=300) + + self.base_url = client.url + self._model_name = client.repository + + self._aclient = AsyncInferenceClient( + base_url=self.base_url, + token=self.api_key.get_secret_value(), + ) + + if self.tokenizer_id: + try: + from transformers import AutoTokenizer + except ImportError as ie: + raise ImportError( + "Transformers Python client is not installed. Please install it using" + " `pip install 'distilabel[hf-inference-endpoints]'`." + ) from ie + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) + + @property + def model_name(self) -> Union[str, None]: # type: ignore + """Returns the model name used for the model.""" + return ( + self.model_display_name + or self._model_name + or self.model_id + or self.endpoint_name + or self.base_url + ) diff --git a/src/distilabel/models/base_clients/openai.py b/src/distilabel/models/base_clients/openai.py new file mode 100644 index 0000000000..ada4d0b4d7 --- /dev/null +++ b/src/distilabel/models/base_clients/openai.py @@ -0,0 +1,122 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
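+
+# NOTE: `OpenAIBaseClient` centralizes the OpenAI client configuration (model,
+# base URL, API key, retries, timeout, structured output) so that `OpenAILLM`
+# and `OpenAIImageGeneration` can share it through multiple inheritance. A
+# minimal sketch of the pattern (illustrative only):
+#
+#     class OpenAIImageGeneration(OpenAIBaseClient, AsyncImageGenerationModel):
+#         def load(self) -> None:
+#             AsyncImageGenerationModel.load(self)  # sets the logger
+#             OpenAIBaseClient.load(self)  # instantiates the sync/async clients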
+ +import os +from typing import TYPE_CHECKING, Dict, Optional + +from pydantic import BaseModel, Field, PrivateAttr, SecretStr + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.typing import InstructorStructuredOutputType + +if TYPE_CHECKING: + from openai import AsyncOpenAI, OpenAI + + +_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" + + +class OpenAIBaseClient(BaseModel): + model: str + base_url: Optional[RuntimeParameter[str]] = Field( + default_factory=lambda: os.getenv( + "OPENAI_BASE_URL", "https://api.openai.com/v1" + ), + description="The base URL to use for the OpenAI API requests.", + ) + api_key: Optional[RuntimeParameter[SecretStr]] = Field( + default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME), + description="The API key to authenticate the requests to the OpenAI API.", + ) # type: ignore + default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field( + default=None, + description="The default headers to use for the OpenAI API requests.", + ) + max_retries: RuntimeParameter[int] = Field( + default=6, + description="The maximum number of times to retry the request to the API before" + " failing.", + ) + timeout: RuntimeParameter[int] = Field( + default=120, + description="The maximum time in seconds to wait for a response from the API.", + ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) + _client: "OpenAI" = PrivateAttr(None) # type: ignore + _aclient: "AsyncOpenAI" = PrivateAttr(None) # type: ignore + + def load(self) -> None: + """Loads the `AsyncOpenAI` client to benefit from async requests.""" + + try: + from openai import AsyncOpenAI, OpenAI + except ImportError as ie: + raise ImportError( + "OpenAI Python client is not installed. Please install it using" + " `pip install 'distilabel[openai]'`." + ) from ie + + if self.api_key is None: + raise ValueError( + f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`" + f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`." + ) + + self._client = OpenAI( + base_url=self.base_url, + api_key=self.api_key.get_secret_value(), + max_retries=self.max_retries, # type: ignore + timeout=self.timeout, + default_headers=self.default_headers, + ) + + self._aclient = AsyncOpenAI( + base_url=self.base_url, + api_key=self.api_key.get_secret_value(), + max_retries=self.max_retries, # type: ignore + timeout=self.timeout, + default_headers=self.default_headers, + ) + + if self.structured_output: + # This applies only to the LLMs. 
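+            # (`_prepare_structured_output` comes from the `LLM` side of the class
+            # hierarchy; image generation subclasses are not expected to set
+            # `structured_output`.)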
+ result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="openai", + ) + self._aclient = result.get("client") # type: ignore + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + + def unload(self) -> None: + """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled + in case an exception is raised and has to be handled in the main process""" + + self._client = None # type: ignore + self._aclient = None # type: ignore + self.default_headers = None + self.structured_output = None + + @property + def model_name(self) -> str: + """Returns the model name used for the LLM.""" + return self.model diff --git a/src/distilabel/models/embeddings/base.py b/src/distilabel/models/embeddings/base.py index e2ee4af3f1..ad46345d54 100644 --- a/src/distilabel/models/embeddings/base.py +++ b/src/distilabel/models/embeddings/base.py @@ -50,7 +50,9 @@ class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC): def load(self) -> None: """Method to be called to initialize the `Embeddings`""" - self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}") + self._logger = logging.getLogger( + f"distilabel.models.embeddings.{self.model_name}" + ) def unload(self) -> None: """Method to be called to unload the `Embeddings` and release any resources.""" diff --git a/src/distilabel/models/image_generation/__init__.py b/src/distilabel/models/image_generation/__init__.py new file mode 100644 index 0000000000..42a4f5a3db --- /dev/null +++ b/src/distilabel/models/image_generation/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.models.image_generation.base import ( + AsyncImageGenerationModel, + ImageGenerationModel, +) +from distilabel.models.image_generation.huggingface.inference_endpoints import ( + InferenceEndpointsImageGeneration, +) +from distilabel.models.image_generation.openai import OpenAIImageGeneration + +__all__ = [ + "AsyncImageGenerationModel", + "ImageGenerationModel", + "InferenceEndpointsImageGeneration", + "OpenAIImageGeneration", +] diff --git a/src/distilabel/models/image_generation/base.py b/src/distilabel/models/image_generation/base.py new file mode 100644 index 0000000000..bdce07d504 --- /dev/null +++ b/src/distilabel/models/image_generation/base.py @@ -0,0 +1,247 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import inspect
+import logging
+import sys
+from abc import ABC, abstractmethod
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+
+from distilabel.mixins.runtime_parameters import (
+    RuntimeParameter,
+    RuntimeParametersModelMixin,
+)
+from distilabel.utils.docstring import parse_google_docstring
+from distilabel.utils.itertools import grouper
+from distilabel.utils.serialization import _Serializable
+
+if TYPE_CHECKING:
+    from logging import Logger
+
+    from distilabel.utils.docstring import Docstring
+
+
+class ImageGenerationModel(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC):
+    """Base class for `ImageGeneration` models.
+
+    To implement an `ImageGenerationModel` subclass, you need to subclass this class and implement:
+        - `load` method to load the `ImageGenerationModel` if needed. Don't forget to call `super().load()`,
+        so the `_logger` attribute is initialized.
+        - `model_name` property to return the model name used for the model.
+        - `generate` method to generate `num_generations` per input in `inputs`.
+
+    Attributes:
+        generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
+            methods within each `ImageGenerationModel`.
+        _logger: the logger to be used for the `ImageGenerationModel`. It will be initialized
+            when the `load` method is called.
+    """
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        protected_namespaces=(),
+        validate_default=True,
+        validate_assignment=True,
+        extra="forbid",
+    )
+
+    generation_kwargs: Optional[RuntimeParameter[dict[str, Any]]] = Field(
+        default_factory=dict,
+        description="The kwargs to be propagated to either `generate` or `agenerate`"
+        " methods within each `ImageGenerationModel`.",
+    )
+    _logger: "Logger" = PrivateAttr(None)
+
+    def load(self) -> None:
+        """Method to be called to initialize the `ImageGenerationModel`, and its logger."""
+        self._logger = logging.getLogger(
+            f"distilabel.models.image_generation.{self.model_name}"
+        )
+
+    def unload(self) -> None:
+        """Method to be called to unload the `ImageGenerationModel` and release any resources."""
+        pass
+
+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        """Returns the model name used for the `ImageGenerationModel`."""
+        pass
+
+    def get_generation_kwargs(self) -> dict[str, Any]:
+        """Returns the generation kwargs to be used for the generation. This method can
+        be overridden to provide a more complex logic for the generation kwargs.
+
+        Returns:
+            The kwargs to be used for the generation.
+        """
+        return self.generation_kwargs  # type: ignore
+
+    @abstractmethod
+    def generate(
+        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
+    ) -> list[list[dict[str, Any]]]:
+        """Generates images from the provided inputs.
+
+        Args:
+            inputs: the list of prompts to generate images from.
+            num_generations: the number of images to generate per input. Defaults to `1`.
+
+        Returns:
+            A list containing, for each input, a dictionary with the list of generated images.
+        """
+        pass
+
+    def generate_outputs(
+        self,
+        inputs: list[str],
+        num_generations: int = 1,
+        **kwargs: Any,
+    ) -> list[list[dict[str, Any]]]:
+        """This method is defined for compatibility with the `LLMs`. It calls the `generate`
+        method.
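+
+        Returns:
+            A list containing, for each input, a dictionary with the list of generated
+            images, exactly as returned by `generate`.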
+        """
+        return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
+
+
+class AsyncImageGenerationModel(ImageGenerationModel):
+    """Abstract class for asynchronous `ImageGenerationModel`s, to benefit from the async
+    capabilities of the underlying clients. This class is meant to be subclassed by each
+    asynchronous `ImageGenerationModel`, and the method `agenerate` needs to be implemented
+    to provide the asynchronous generation of images.
+
+    Attributes:
+        _event_loop: the event loop to be used for the asynchronous generation of responses.
+    """
+
+    _num_generations_param_supported = True
+    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
+    _new_event_loop: bool = PrivateAttr(default=False)
+
+    @property
+    def generate_parameters(self) -> list[inspect.Parameter]:
+        """Returns the parameters of the `agenerate` method.
+
+        Returns:
+            A list containing the parameters of the `agenerate` method.
+        """
+        return list(inspect.signature(self.agenerate).parameters.values())
+
+    @cached_property
+    def generate_parsed_docstring(self) -> "Docstring":
+        """Returns the parsed docstring of the `agenerate` method.
+
+        Returns:
+            The parsed docstring of the `agenerate` method.
+        """
+        return parse_google_docstring(self.agenerate)
+
+    @property
+    def event_loop(self) -> "asyncio.AbstractEventLoop":
+        if self._event_loop is None:
+            try:
+                self._event_loop = asyncio.get_running_loop()
+                if self._event_loop.is_closed():
+                    self._event_loop = asyncio.new_event_loop()  # type: ignore
+                    self._new_event_loop = True
+            except RuntimeError:
+                self._event_loop = asyncio.new_event_loop()
+                self._new_event_loop = True
+        asyncio.set_event_loop(self._event_loop)
+        return self._event_loop
+
+    @abstractmethod
+    async def agenerate(
+        self, input: str, num_generations: int = 1, **kwargs: Any
+    ) -> list[dict[str, Any]]:
+        """Generates images from the provided input.
+
+        Args:
+            input: the input text to generate the image from.
+            num_generations: the number of images to generate. Defaults to `1`.
+
+        Returns:
+            A list with a dictionary containing the list of generated images.
+        """
+        pass
+
+    async def _agenerate(
+        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
+    ) -> list[list[dict[str, Any]]]:
+        """Internal function to concurrently generate images for a list of inputs.
+
+        Args:
+            inputs: the list of inputs to generate images for.
+            num_generations: the number of generations to generate per input.
+            **kwargs: the additional kwargs to be used for the generation.
+
+        Returns:
+            A list containing the generations for each input.
+        """
+        if self._num_generations_param_supported:
+            tasks = [
+                asyncio.create_task(
+                    self.agenerate(
+                        input=input, num_generations=num_generations, **kwargs
+                    )
+                )
+                for input in inputs
+            ]
+            return await asyncio.gather(*tasks)
+
+        tasks = [
+            asyncio.create_task(self.agenerate(input=input, **kwargs))
+            for input in inputs
+            for _ in range(num_generations)
+        ]
+        outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
+        return [
+            list(group)
+            for group in grouper(outputs, n=num_generations, incomplete="ignore")
+        ]
+
+    def generate(
+        self,
+        inputs: list[str],
+        num_generations: int = 1,
+        **kwargs: Any,
+    ) -> list[list[dict[str, Any]]]:
+        """Method to generate a list of images asynchronously, returning the output
+        synchronously after awaiting the images generated for each input sent to `agenerate`.
+
+        Args:
+            inputs: the list of inputs to generate images for.
+            num_generations: the number of generations to generate per input.
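+                Defaults to `1`.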
+            **kwargs: the additional kwargs to be used for the generation.
+
+        Returns:
+            A list containing the images for each input.
+        """
+        return self.event_loop.run_until_complete(
+            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
+        )
+
+    def __del__(self) -> None:
+        """Closes the event loop when the object is deleted."""
+        if sys.meta_path is None:
+            return
+
+        if self._new_event_loop:
+            if self._event_loop.is_running():
+                self._event_loop.stop()
+            self._event_loop.close()
diff --git a/src/distilabel/models/image_generation/huggingface/__init__.py b/src/distilabel/models/image_generation/huggingface/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/src/distilabel/models/image_generation/huggingface/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/distilabel/models/image_generation/huggingface/inference_endpoints.py b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py
new file mode 100644
index 0000000000..2403fbf018
--- /dev/null
+++ b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py
@@ -0,0 +1,106 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Optional
+
+from pydantic import validate_call
+
+from distilabel.models.base_clients.inference_endpoints import (
+    InferenceEndpointsBaseClient,
+)
+from distilabel.models.image_generation.base import AsyncImageGenerationModel
+from distilabel.models.image_generation.utils import image_to_str
+
+if TYPE_CHECKING:
+    from PIL.Image import Image
+
+
+class InferenceEndpointsImageGeneration(  # type: ignore
+    InferenceEndpointsBaseClient, AsyncImageGenerationModel
+):
+    """Inference Endpoint image generation implementation running the async API client.
+
+    Attributes:
+        model_id: the model ID to use for the ImageGenerationModel as available in the Hugging Face Hub, which
+            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
+            Defaults to `None`.
+        endpoint_name: the name of the Inference Endpoint to use for the image generation model. Defaults to `None`.
+        endpoint_namespace: the namespace of the Inference Endpoint to use for the image generation model. Defaults to `None`.
+        base_url: the base URL to use for the Inference Endpoints API requests.
+        api_key: the API key to authenticate the requests to the Inference Endpoints API.
+
+    Icon:
+        `:hugging:`
+
+    Examples:
+        Generate images from text prompts:
+
+        ```python
+        from distilabel.models.image_generation import InferenceEndpointsImageGeneration
+
+        igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell", api_key="api.key")
+        igm.load()
+
+        output = igm.generate_outputs(
+            inputs=["a white siamese cat"],
+        )
+        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
+        ```
+    """
+
+    def load(self) -> None:
+        # Sets the logger and calls the load method of the BaseClient
+        AsyncImageGenerationModel.load(self)
+        InferenceEndpointsBaseClient.load(self)
+
+    @validate_call
+    async def agenerate(  # type: ignore
+        self,
+        input: str,
+        negative_prompt: Optional[str] = None,
+        height: Optional[float] = None,
+        width: Optional[float] = None,
+        num_inference_steps: Optional[float] = None,
+        guidance_scale: Optional[float] = None,
+        num_generations: int = 1,
+    ) -> list[dict[str, Any]]:
+        """Generates images from text prompts using `huggingface_hub.AsyncInferenceClient.text_to_image`.
+
+        Args:
+            input: Prompt to generate an image from.
+            negative_prompt: An optional negative prompt for the image generation. Defaults to `None`.
+            height: The height in pixels of the image to generate.
+            width: The width in pixels of the image to generate.
+            num_inference_steps: The number of denoising steps. More denoising steps usually lead
+                to a higher quality image at the expense of slower inference.
+            guidance_scale: A higher guidance scale encourages the model to generate images that
+                are closely linked to the text `prompt`, usually at the expense of lower image quality.
+            num_generations: The number of images to generate. Defaults to `1`.
+                It's here to ensure the validation succeeds, but it won't have any effect.
+
+        Returns:
+            A list with a dictionary containing a list with the image as a base64 string.
+        """
+
+        image: "Image" = await self._aclient.text_to_image(  # type: ignore
+            input,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+        )
+        img_str = image_to_str(image, image_format="JPEG")
+
+        return [{"images": [img_str]}]
diff --git a/src/distilabel/models/image_generation/openai.py b/src/distilabel/models/image_generation/openai.py
new file mode 100644
index 0000000000..6315eb8046
--- /dev/null
+++ b/src/distilabel/models/image_generation/openai.py
@@ -0,0 +1,129 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from typing import TYPE_CHECKING, Any, Literal, Optional
+
+import requests
+from pydantic import validate_call
+
+from distilabel.models.base_clients.openai import OpenAIBaseClient
+from distilabel.models.image_generation.base import AsyncImageGenerationModel
+
+if TYPE_CHECKING:
+    from openai.types import ImagesResponse
+
+
+class OpenAIImageGeneration(OpenAIBaseClient, AsyncImageGenerationModel):
+    """OpenAI image generation implementation running the async API client.
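+
+    Depending on `response_format`, the generated images are either downloaded from the
+    URLs returned by the API or read directly from the response, and are returned as
+    base64-encoded strings in both cases.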
+ + Attributes: + model: the model name to use for the ImageGenerationModel e.g. "dall-e-3", etc. + Supported models can be found [here](https://platform.openai.com/docs/guides/images). + base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which + means that the value set for the environment variable `OPENAI_BASE_URL` will + be used, or "https://api.openai.com/v1" if not set. + api_key: the API key to authenticate the requests to the OpenAI API. Defaults to + `None` which means that the value set for the environment variable `OPENAI_API_KEY` + will be used, or `None` if not set. + max_retries: the maximum number of times to retry the request to the API before + failing. Defaults to `6`. + timeout: the maximum time in seconds to wait for a response from the API. Defaults + to `120`. + + Icon: + `:simple-openai:` + + Examples: + Generate images from text prompts: + + ```python + from distilabel.models.image_generation import OpenAIImageGeneration + + igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key") + + igm.load() + + output = igm.generate_outputs( + inputs=["a white siamese cat"], + size="1024x1024", + quality="standard", + style="natural", + ) + # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}] + ``` + """ + + def load(self) -> None: + # Sets the logger and calls the load method of the BaseClient + AsyncImageGenerationModel.load(self) + OpenAIBaseClient.load(self) + + @validate_call + async def agenerate( # type: ignore + self, + input: str, + num_generations: int = 1, + quality: Optional[Literal["standard", "hd"]] = "standard", + response_format: Optional[Literal["url", "b64_json"]] = "url", + size: Optional[ + Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + ] = None, + style: Optional[Literal["vivid", "natural"]] = None, + ) -> list[dict[str, Any]]: + """Generates `num_generations` images for the given input using the OpenAI async + client. The images are base64 string representations. + + Args: + input: A text description of the desired image(s). The maximum length is 1000 + characters for `dall-e-2` and 4000 characters for `dall-e-3`. + num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only + `n=1` is supported. + quality: The quality of the image that will be generated. `hd` creates images with finer + details and greater consistency across the image. This param is only supported + for `dall-e-3`. + response_format: The format in which the generated images are returned. Must be one of `url` or + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. + size: The size of the generated images. Must be one of `256x256`, `512x512`, or + `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or + `1024x1792` for `dall-e-3` models. + style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid + causes the model to lean towards generating hyper-real and dramatic images. + Natural causes the model to produce more natural, less hyper-real looking + images. This param is only supported for `dall-e-3`. + + Returns: + A list with a dictionary with the list of images generated. 
+ """ + images_response: "ImagesResponse" = await self._aclient.images.generate( + model=self.model_name, + prompt=input, + n=num_generations, + quality=quality, + response_format=response_format, + size=size, + style=style, + ) + images = [] + for image in images_response.data: + if response_format == "url": + image_data = requests.get( + image.url + ).content # TODO: Keep a requests/httpx session instead + image_str = base64.b64encode(image_data).decode() + images.append(image_str) + elif response_format == "b64_json": + images.append(image.b64_json) + return [{"images": images}] diff --git a/src/distilabel/models/image_generation/utils.py b/src/distilabel/models/image_generation/utils.py new file mode 100644 index 0000000000..e5f08ca343 --- /dev/null +++ b/src/distilabel/models/image_generation/utils.py @@ -0,0 +1,31 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import io + +from PIL import Image + + +def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: + """Converts a PIL Image to a base64 encoded string.""" + buffered = io.BytesIO() + image.save(buffered, format=image_format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +def image_from_str(image_str: str) -> Image.Image: + """Converts a base64 encoded string to a PIL Image.""" + image_bytes = base64.b64decode(image_str) + return Image.open(io.BytesIO(image_bytes)) diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py index 0b0f3a7a9c..3469c1e2bc 100644 --- a/src/distilabel/models/llms/__init__.py +++ b/src/distilabel/models/llms/__init__.py @@ -27,10 +27,10 @@ from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.together import TogetherLLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.typing import GenerateOutput, HiddenState __all__ = [ "LLM", diff --git a/src/distilabel/models/llms/anthropic.py b/src/distilabel/models/llms/anthropic.py index ab364bad58..3650671118 100644 --- a/src/distilabel/models/llms/anthropic.py +++ b/src/distilabel/models/llms/anthropic.py @@ -29,20 +29,19 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) if TYPE_CHECKING: - from typing import BaseModel - from anthropic import AsyncAnthropic from anthropic.types import Message + from pydantic import BaseModel - from distilabel.models.llms.typing import LLMStatistics + from 
distilabel.typing import LLMStatistics _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" diff --git a/src/distilabel/models/llms/base.py b/src/distilabel/models/llms/base.py index df274df402..912839b27b 100644 --- a/src/distilabel/models/llms/base.py +++ b/src/distilabel/models/llms/base.py @@ -31,7 +31,7 @@ from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.mixins.runtime_parameters import ( RuntimeParameter, - RuntimeParametersMixin, + RuntimeParametersModelMixin, ) from distilabel.utils.docstring import parse_google_docstring from distilabel.utils.notebook import in_notebook @@ -40,16 +40,13 @@ if TYPE_CHECKING: from logging import Logger - from distilabel.mixins.runtime_parameters import ( - RuntimeParameterInfo, - RuntimeParametersNames, - ) - from distilabel.models.llms.typing import GenerateOutput, HiddenState - from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType - from distilabel.steps.tasks.typing import ( + from distilabel.typing import ( FormattedInput, + GenerateOutput, + HiddenState, InstructorStructuredOutputType, StandardInput, + StructuredOutputType, ) from distilabel.utils.docstring import Docstring @@ -59,7 +56,7 @@ nest_asyncio.apply() -class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC): +class LLM(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC): """Base class for `LLM`s to be used in `distilabel` framework. To implement an `LLM` subclass, you need to subclass this class and implement: @@ -241,81 +238,6 @@ def _offline_batch_generate_polling( jobs_ids=self.jobs_ids # type: ignore ) from e - @property - def generate_parameters(self) -> List["inspect.Parameter"]: - """Returns the parameters of the `generate` method. - - Returns: - A list containing the parameters of the `generate` method. - """ - return list(inspect.signature(self.generate).parameters.values()) - - @property - def runtime_parameters_names(self) -> "RuntimeParametersNames": - """Returns the runtime parameters of the `LLM`, which are combination of the - attributes of the `LLM` type hinted with `RuntimeParameter` and the parameters - of the `generate` method that are not `input` and `num_generations`. - - Returns: - A dictionary with the name of the runtime parameters as keys and a boolean - indicating if the parameter is optional or not. - """ - runtime_parameters = super().runtime_parameters_names - runtime_parameters["generation_kwargs"] = {} - - # runtime parameters from the `generate` method - for param in self.generate_parameters: - if param.name in ["input", "inputs", "num_generations"]: - continue - is_optional = param.default != inspect.Parameter.empty - runtime_parameters["generation_kwargs"][param.name] = is_optional - - return runtime_parameters - - def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]: - """Gets the information of the runtime parameters of the `LLM` such as the name - and the description. This function is meant to include the information of the runtime - parameters in the serialized data of the `LLM`. - - Returns: - A list containing the information for each runtime parameter of the `LLM`. 
- """ - runtime_parameters_info = super().get_runtime_parameters_info() - - generation_kwargs_info = next( - ( - runtime_parameter_info - for runtime_parameter_info in runtime_parameters_info - if runtime_parameter_info["name"] == "generation_kwargs" - ), - None, - ) - - # If `generation_kwargs` attribute is present, we need to include the `generate` - # method arguments as the information for this attribute. - if generation_kwargs_info: - generate_docstring_args = self.generate_parsed_docstring["args"] - - generation_kwargs_info["keys"] = [] - for key, value in generation_kwargs_info["optional"].items(): - info = {"name": key, "optional": value} - if description := generate_docstring_args.get(key): - info["description"] = description - generation_kwargs_info["keys"].append(info) - - generation_kwargs_info.pop("optional") - - return runtime_parameters_info - - @cached_property - def generate_parsed_docstring(self) -> "Docstring": - """Returns the parsed docstring of the `generate` method. - - Returns: - The parsed docstring of the `generate` method. - """ - return parse_google_docstring(self.generate) - def get_last_hidden_states( self, inputs: List["StandardInput"] ) -> List["HiddenState"]: diff --git a/src/distilabel/models/llms/cohere.py b/src/distilabel/models/llms/cohere.py index 8b081a762e..0c9a342aea 100644 --- a/src/distilabel/models/llms/cohere.py +++ b/src/distilabel/models/llms/cohere.py @@ -28,10 +28,10 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import compute_tokens, prepare_output -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) @@ -40,7 +40,7 @@ from pydantic import BaseModel from tokenizers import Tokenizer - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" diff --git a/src/distilabel/models/llms/groq.py b/src/distilabel/models/llms/groq.py index fec511bbee..4334d72bdd 100644 --- a/src/distilabel/models/llms/groq.py +++ b/src/distilabel/models/llms/groq.py @@ -18,11 +18,11 @@ from pydantic import Field, PrivateAttr, SecretStr, validate_call from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output from distilabel.steps.base import RuntimeParameter -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) @@ -30,7 +30,7 @@ from groq import AsyncGroq from groq.types.chat.chat_completion import ChatCompletion - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index 6f97c5814a..8956529999 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import random import sys import warnings @@ -30,29 +29,22 @@ from pydantic import ( Field, PositiveInt, - PrivateAttr, - SecretStr, ValidationError, model_validator, validate_call, ) from pydantic._internal._model_construction import ModelMetaclass -from typing_extensions import Annotated, override +from typing_extensions import Annotated -from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.base_clients.inference_endpoints import ( + InferenceEndpointsBaseClient, +) from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput, Logprob from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import ( - FormattedInput, - StandardInput, - StructuredOutputType, -) -from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR, get_hf_token +from distilabel.typing import FormattedInput, GenerateOutput, Logprob, StandardInput if TYPE_CHECKING: - from huggingface_hub import AsyncInferenceClient from huggingface_hub.inference._generated.types.chat_completion import ( ChatCompletionOutput, ChatCompletionOutputComplete, @@ -60,12 +52,13 @@ from huggingface_hub.inference._generated.types.text_generation import ( TextGenerationOutput, ) - from transformers import PreTrainedTokenizer - from distilabel.models.llms.typing import Logprob + from distilabel.typing import Logprob -class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): +class InferenceEndpointsLLM( + InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin +): """InferenceEndpoints LLM implementation running the async API client. This LLM will internally use `huggingface_hub.AsyncInferenceClient`. @@ -164,39 +157,11 @@ class User(BaseModel): ``` """ - model_id: Optional[str] = None - - endpoint_name: Optional[RuntimeParameter[str]] = Field( - default=None, - description="The name of the Inference Endpoint to use for the LLM.", - ) - endpoint_namespace: Optional[RuntimeParameter[str]] = Field( - default=None, - description="The namespace of the Inference Endpoint to use for the LLM.", - ) - base_url: Optional[RuntimeParameter[str]] = Field( - default=None, - description="The base URL to use for the Inference Endpoints API requests.", - ) - api_key: Optional[RuntimeParameter[SecretStr]] = Field( - default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR), - description="The API key to authenticate the requests to the Inference Endpoints API.", - ) - - tokenizer_id: Optional[str] = None - model_display_name: Optional[str] = None - - structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( - default=None, - description="The structured output format to use across all the generations.", - ) - - _num_generations_param_supported = False - - _model_name: Optional[str] = PrivateAttr(default=None) - _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) - _api_key_env_var: str = PrivateAttr(HF_TOKEN_ENV_VAR) - _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) 
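+    # NOTE: the connection attributes (`model_id`, `endpoint_name`, `endpoint_namespace`,
+    # `base_url`, `api_key`, `tokenizer_id`, `structured_output`, ...) previously defined
+    # here are now inherited from `InferenceEndpointsBaseClient`.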
+ def load(self) -> None: + # Sets the logger and calls the load method of the BaseClient + self._num_generations_param_supported = False + AsyncLLM.load(self) + InferenceEndpointsBaseClient.load(self) @model_validator(mode="after") # type: ignore def only_one_of_model_id_endpoint_name_or_base_url_provided( @@ -242,92 +207,6 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." ) - def load(self) -> None: # noqa: C901 - """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference - Endpoint. - - Raises: - ImportError: if the `huggingface-hub` Python client is not installed. - ValueError: if the model is not currently deployed or is not running the TGI framework. - ImportError: if the `transformers` Python client is not installed. - """ - super().load() - - try: - from huggingface_hub import ( - AsyncInferenceClient, - InferenceClient, - get_inference_endpoint, - ) - except ImportError as ie: - raise ImportError( - "Hugging Face Hub Python client is not installed. Please install it using" - " `pip install 'distilabel[hf-inference-endpoints]'`." - ) from ie - - if self.api_key is None: - self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key")) - - if self.model_id is not None: - client = InferenceClient( - model=self.model_id, token=self.api_key.get_secret_value() - ) - status = client.get_model_status() - - if ( - status.state not in {"Loadable", "Loaded"} - and status.framework != "text-generation-inference" - ): - raise ValueError( - f"Model {self.model_id} is not currently deployed or is not running the TGI framework" - ) - - self.base_url = client._resolve_url( - model=self.model_id, task="text-generation" - ) - - if self.endpoint_name is not None: - client = get_inference_endpoint( - name=self.endpoint_name, - namespace=self.endpoint_namespace, - token=self.api_key.get_secret_value(), - ) - if client.status in ["paused", "scaledToZero"]: - client.resume().wait(timeout=300) - elif client.status == "initializing": - client.wait(timeout=300) - - self.base_url = client.url - self._model_name = client.repository - - self._aclient = AsyncInferenceClient( - base_url=self.base_url, - token=self.api_key.get_secret_value(), - ) - - if self.tokenizer_id: - try: - from transformers import AutoTokenizer - except ImportError as ie: - raise ImportError( - "Transformers Python client is not installed. Please install it using" - " `pip install 'distilabel[hf-inference-endpoints]'`." - ) from ie - - self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) - - @property - @override - def model_name(self) -> Union[str, None]: # type: ignore - """Returns the model name used for the LLM.""" - return ( - self.model_display_name - or self._model_name - or self.model_id - or self.endpoint_name - or self.base_url - ) - def prepare_input(self, input: "StandardInput") -> str: """Prepares the input (applying the chat template and tokenization) for the provided input. @@ -588,6 +467,7 @@ async def agenerate( # type: ignore top_k: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + num_generations: int = 1, ) -> GenerateOutput: """Generates completions for the given input using the async client. This method uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`. @@ -656,6 +536,8 @@ async def agenerate( # type: ignore watermark: whether to add the watermark to the generated text. 
                This argument is exclusive of the `text_generation` method and will be only used
                if `tokenizer_id` is not `None`. Defaults to `None`.
+            num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure
+                the validation succeeds.
 
         Returns:
             A list of lists of strings containing the generated responses for each input.
diff --git a/src/distilabel/models/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py
index 19dc32dd2d..aef8c40e16 100644
--- a/src/distilabel/models/llms/huggingface/transformers.py
+++ b/src/distilabel/models/llms/huggingface/transformers.py
@@ -19,14 +19,17 @@
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.models.llms.base import LLM
-from distilabel.models.llms.typing import GenerateOutput
 from distilabel.models.llms.utils import compute_tokens, prepare_output
 from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.steps.tasks.structured_outputs.outlines import (
     _is_outlines_version_below_0_1_0,
 )
-from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
+from distilabel.typing import (
+    GenerateOutput,
+    OutlinesStructuredOutputType,
+    StandardInput,
+)
 from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR
 
 if TYPE_CHECKING:
@@ -34,7 +37,7 @@
     from transformers.modeling_utils import PreTrainedModel
     from transformers.tokenization_utils import PreTrainedTokenizer
 
-    from distilabel.models.llms.typing import HiddenState
+    from distilabel.typing import HiddenState
 
 
 class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
diff --git a/src/distilabel/models/llms/litellm.py b/src/distilabel/models/llms/litellm.py
index 9b52ad8c71..29c910622b 100644
--- a/src/distilabel/models/llms/litellm.py
+++ b/src/distilabel/models/llms/litellm.py
@@ -20,9 +20,12 @@
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.models.llms.base import AsyncLLM
-from distilabel.models.llms.typing import GenerateOutput
 from distilabel.models.llms.utils import prepare_output
-from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
+from distilabel.typing import (
+    FormattedInput,
+    GenerateOutput,
+    InstructorStructuredOutputType,
+)
 
 if TYPE_CHECKING:
     from litellm import Choices
diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py
index a754f6b84f..87f5eb358f 100644
--- a/src/distilabel/models/llms/llamacpp.py
+++ b/src/distilabel/models/llms/llamacpp.py
@@ -18,10 +18,13 @@
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.models.llms.base import LLM
-from distilabel.models.llms.typing import GenerateOutput
 from distilabel.models.llms.utils import prepare_output
 from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
-from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType
+from distilabel.typing import (
+    FormattedInput,
+    GenerateOutput,
+    OutlinesStructuredOutputType,
+)
 
 if TYPE_CHECKING:
     from llama_cpp import (
@@ -31,7 +34,7 @@
         LogitsProcessorList,
     )
 
-    from distilabel.steps.tasks.typing import FormattedInput, StandardInput
+    from distilabel.typing import FormattedInput, StandardInput
 
 
 class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
diff --git a/src/distilabel/models/llms/mistral.py b/src/distilabel/models/llms/mistral.py
index 4147edaf03..e6047d4be5 100644
---
a/src/distilabel/models/llms/mistral.py +++ b/src/distilabel/models/llms/mistral.py @@ -19,10 +19,10 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) @@ -30,7 +30,7 @@ from mistralai import Mistral from mistralai.models.chatcompletionresponse import ChatCompletionResponse - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY" diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py index 1f8c9b8c65..ffdcf37526 100644 --- a/src/distilabel/models/llms/mlx.py +++ b/src/distilabel/models/llms/mlx.py @@ -27,12 +27,9 @@ ) from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import ( - StandardInput, -) +from distilabel.typing import GenerateOutput, StandardInput if TYPE_CHECKING: import mlx.nn as nn diff --git a/src/distilabel/models/llms/moa.py b/src/distilabel/models/llms/moa.py index 11af619ad4..ea859e95da 100644 --- a/src/distilabel/models/llms/moa.py +++ b/src/distilabel/models/llms/moa.py @@ -19,12 +19,11 @@ from pydantic import Field from distilabel.models.llms.base import LLM, AsyncLLM -from distilabel.steps.tasks.typing import StandardInput +from distilabel.typing import StandardInput if TYPE_CHECKING: from distilabel.mixins.runtime_parameters import RuntimeParametersNames - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput # Mixture-of-Agents system prompt from the paper with the addition instructing the LLM # to not mention that it used responses from previous models to avoid having texts like diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py index a930399114..4cb5aa0428 100644 --- a/src/distilabel/models/llms/ollama.py +++ b/src/distilabel/models/llms/ollama.py @@ -19,17 +19,19 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput +from distilabel.typing import ( + GenerateOutput, + InstructorStructuredOutputType, + StandardInput, +) if TYPE_CHECKING: from ollama import AsyncClient from ollama._types import ChatResponse, GenerateResponse - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import StandardInput + from distilabel.typing import LLMStatistics, StandardInput # Copied from `ollama._types.Options` diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index 91f24a3336..66dbfcff17 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -13,36 +13,32 @@ # limitations under the License. 
import io -import os from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import orjson -from pydantic import Field, PositiveInt, PrivateAttr, SecretStr, validate_call +from pydantic import PositiveInt, validate_call from distilabel import envs from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException -from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.base_clients.openai import OpenAIBaseClient from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType +from distilabel.typing import FormattedInput, GenerateOutput if TYPE_CHECKING: - from openai import AsyncOpenAI, OpenAI from openai.types import Batch as OpenAIBatch from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat.chat_completion import Choice as OpenAIChoice from openai.types.completion import Completion as OpenAICompletion - from distilabel.models.llms.typing import LLMStatistics, Logprob + from distilabel.typing import LLMStatistics, Logprob -_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" _OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB -class OpenAILLM(AsyncLLM): +class OpenAILLM(OpenAIBaseClient, AsyncLLM): """OpenAI LLM implementation running the async API client. Attributes: @@ -143,99 +139,9 @@ class User(BaseModel): ``` """ - model: str - base_url: Optional[RuntimeParameter[str]] = Field( - default_factory=lambda: os.getenv( - "OPENAI_BASE_URL", "https://api.openai.com/v1" - ), - description="The base URL to use for the OpenAI API requests.", - ) - api_key: Optional[RuntimeParameter[SecretStr]] = Field( - default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME), - description="The API key to authenticate the requests to the OpenAI API.", - ) - default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field( - default=None, - description="The default headers to use for the OpenAI API requests.", - ) - max_retries: RuntimeParameter[int] = Field( - default=6, - description="The maximum number of times to retry the request to the API before" - " failing.", - ) - timeout: RuntimeParameter[int] = Field( - default=120, - description="The maximum time in seconds to wait for a response from the API.", - ) - structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( - Field( - default=None, - description="The structured output format to use across all the generations.", - ) - ) - - _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) - _client: "OpenAI" = PrivateAttr(None) - _aclient: "AsyncOpenAI" = PrivateAttr(None) - def load(self) -> None: - """Loads the `AsyncOpenAI` client to benefit from async requests.""" - super().load() - - try: - from openai import AsyncOpenAI, OpenAI - except ImportError as ie: - raise ImportError( - "OpenAI Python client is not installed. Please install it using" - " `pip install 'distilabel[openai]'`." - ) from ie - - if self.api_key is None: - raise ValueError( - f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`" - f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`." 
- ) - - self._client = OpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=self.max_retries, # type: ignore - timeout=self.timeout, - default_headers=self.default_headers, - ) - - self._aclient = AsyncOpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=self.max_retries, # type: ignore - timeout=self.timeout, - default_headers=self.default_headers, - ) - - if self.structured_output: - result = self._prepare_structured_output( - structured_output=self.structured_output, - client=self._aclient, - framework="openai", - ) - self._aclient = result.get("client") # type: ignore - if structured_output := result.get("structured_output"): - self.structured_output = structured_output - - def unload(self) -> None: - """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled - in case an exception is raised and has to be handled in the main process""" - - self._client = None # type: ignore - self._aclient = None # type: ignore - self.default_headers = None - self.structured_output = None - super().unload() - - @property - def model_name(self) -> str: - """Returns the model name used for the LLM.""" - return self.model + AsyncLLM.load(self) + OpenAIBaseClient.load(self) @validate_call async def agenerate( # type: ignore diff --git a/src/distilabel/models/llms/typing.py b/src/distilabel/models/llms/typing.py deleted file mode 100644 index cfa4ec382f..0000000000 --- a/src/distilabel/models/llms/typing.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, TypeVar, Union - -from typing_extensions import NotRequired - -LLMOutput = List[Union[str, None]] - - -class Logprob(TypedDict): - token: str - logprob: float - - -LLMLogprobs = List[List[List[Logprob]]] -"""A type alias representing the probability distributions output by an `LLM`. - -Structure: - - Outermost list: contains multiple generation choices when sampling (`n` sequences) - - Middle list: represents each position in the generated sequence - - Innermost list: contains the log probabilities for each token in the vocabulary at that position -""" - - -class TokenCount(TypedDict): - input_tokens: List[int] - output_tokens: List[int] - - -LLMStatistics = Union[TokenCount, Dict[str, Any]] -"""Initially the LLMStatistics will contain the token count, but can have more variables. -They can be added once we have them defined for every LLM. 
-""" - - -class GenerateOutput(TypedDict): - generations: LLMOutput - statistics: LLMStatistics - logprobs: NotRequired[LLMLogprobs] - - -if TYPE_CHECKING: - from numpy import floating - from numpy.typing import NDArray - - GenericFloat = TypeVar("GenericFloat", bound=floating[Any]) - - HiddenState = NDArray[GenericFloat] -else: - HiddenState = Any diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py index ef97e53e1f..45f9088ca5 100644 --- a/src/distilabel/models/llms/utils.py +++ b/src/distilabel/models/llms/utils.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput, LLMLogprobs, LLMOutput + from distilabel.typing import GenerateOutput, LLMLogprobs, LLMOutput def compute_tokens( diff --git a/src/distilabel/models/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py index 7c1b3e6bb4..b241e4d8d8 100644 --- a/src/distilabel/models/llms/vertexai.py +++ b/src/distilabel/models/llms/vertexai.py @@ -18,14 +18,13 @@ from typing_extensions import TypedDict from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import StandardInput +from distilabel.typing import GenerateOutput, StandardInput if TYPE_CHECKING: from vertexai.generative_models import Content, GenerationResponse, GenerativeModel - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics class VertexChatItem(TypedDict): diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index ceab8e3e30..6075c4f54e 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -33,11 +33,15 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.openai import OpenAILLM -from distilabel.models.llms.typing import GenerateOutput, Logprob from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType +from distilabel.typing import ( + FormattedInput, + GenerateOutput, + Logprob, + OutlinesStructuredOutputType, +) if TYPE_CHECKING: from openai import OpenAI # noqa @@ -45,10 +49,13 @@ from vllm import LLM as _vLLM from vllm.outputs import RequestOutput, CompletionOutput - from distilabel.steps.tasks.typing import StandardInput - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import StructuredInput - from distilabel.models.llms.typing import LLMLogprobs, LLMOutput + from distilabel.typing import ( + StandardInput, + StructuredInput, + LLMStatistics, + LLMLogprobs, + LLMOutput, + ) LogitsProcessorFn = Union[ diff --git a/src/distilabel/models/mixins/magpie.py b/src/distilabel/models/mixins/magpie.py index 8efa3add58..8edc1d92e8 100644 --- a/src/distilabel/models/mixins/magpie.py +++ b/src/distilabel/models/mixins/magpie.py @@ -18,7 +18,7 @@ from typing_extensions import Self if TYPE_CHECKING: - from distilabel.steps.tasks.typing import StandardInput + from distilabel.typing import StandardInput 
MagpieAvailablePreQueryTemplates = Literal["llama3", "qwen2"] """The available predefined pre-query templates.""" diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 168599f782..2a0c89abd5 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -68,13 +68,13 @@ from distilabel.distiset import Distiset from distilabel.pipeline.routing_batch_function import RoutingBatchFunction - from distilabel.pipeline.typing import ( + from distilabel.steps.base import Step + from distilabel.typing import ( InputDataset, LoadGroups, PipelineRuntimeParametersInfo, StepLoadStatus, ) - from distilabel.steps.base import Step class _CacheLocation(TypedDict): """Dictionary to store the filenames and directories of a cached pipeline. diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index e8716f1ade..29ab8131cb 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -44,8 +44,8 @@ from queue import Queue from distilabel.distiset import Distiset - from distilabel.pipeline.typing import InputDataset, LoadGroups from distilabel.steps.base import _Step + from distilabel.typing import InputDataset, LoadGroups _SUBPROCESS_EXCEPTION: Union[Exception, None] = None diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py index c2e85afd86..2d1158aedd 100644 --- a/src/distilabel/pipeline/ray.py +++ b/src/distilabel/pipeline/ray.py @@ -32,8 +32,8 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from distilabel.distiset import Distiset - from distilabel.pipeline.typing import InputDataset, LoadGroups from distilabel.steps.base import _Step + from distilabel.typing import InputDataset, LoadGroups class RayPipeline(BasePipeline): diff --git a/src/distilabel/pipeline/routing_batch_function.py b/src/distilabel/pipeline/routing_batch_function.py index 3f0aaf9ff4..31889acc90 100644 --- a/src/distilabel/pipeline/routing_batch_function.py +++ b/src/distilabel/pipeline/routing_batch_function.py @@ -28,8 +28,8 @@ if TYPE_CHECKING: from distilabel.pipeline.batch import _Batch - from distilabel.pipeline.typing import DownstreamConnectableSteps from distilabel.steps.base import _Step + from distilabel.typing import DownstreamConnectableSteps RoutingBatchFunc = Callable[[List[str]], List[str]] """Type alias for a routing batch function. 
It takes a list of all the downstream steps and diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py index 1caa3a3e38..52937107f3 100644 --- a/src/distilabel/pipeline/step_wrapper.py +++ b/src/distilabel/pipeline/step_wrapper.py @@ -21,8 +21,8 @@ from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.pipeline.batch import _Batch -from distilabel.pipeline.typing import StepLoadStatus from distilabel.steps.base import GeneratorStep, Step, _Step +from distilabel.typing import StepLoadStatus class _StepWrapper: diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 58875bbec3..19d90f9a33 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -55,7 +55,7 @@ from distilabel.steps.globals.huggingface import PushToHub from distilabel.steps.reward_model import RewardModelScore from distilabel.steps.truncate import TruncateTextColumn -from distilabel.steps.typing import GeneratorStepOutput, StepOutput +from distilabel.typing import GeneratorStepOutput, StepOutput __all__ = [ "DBSCAN", diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py index 06db05e05b..1742ac675d 100644 --- a/src/distilabel/steps/argilla/base.py +++ b/src/distilabel/steps/argilla/base.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from argilla import Argilla, Dataset - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput _ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL" diff --git a/src/distilabel/steps/argilla/preference.py b/src/distilabel/steps/argilla/preference.py index 210cca208f..22cb6d02da 100644 --- a/src/distilabel/steps/argilla/preference.py +++ b/src/distilabel/steps/argilla/preference.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from argilla import RatingQuestion, Suggestion, TextField, TextQuestion - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class PreferenceToArgilla(ArgillaBase): diff --git a/src/distilabel/steps/argilla/text_generation.py b/src/distilabel/steps/argilla/text_generation.py index ad5323b0bc..ed590dec57 100644 --- a/src/distilabel/steps/argilla/text_generation.py +++ b/src/distilabel/steps/argilla/text_generation.py @@ -28,7 +28,7 @@ from distilabel.steps.base import StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class TextGenerationToArgilla(ArgillaBase): diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index 128ccefc75..88bed374bc 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -48,12 +48,14 @@ from distilabel.pipeline.base import BasePipeline from distilabel.pipeline.routing_batch_function import RoutingBatchFunction - from distilabel.pipeline.typing import ( + from distilabel.typing import ( DownstreamConnectable, DownstreamConnectableSteps, + GeneratorStepOutput, + StepColumns, + StepOutput, UpstreamConnectableSteps, ) - from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput DEFAULT_INPUT_BATCH_SIZE = 50 diff --git a/src/distilabel/steps/clustering/dbscan.py b/src/distilabel/steps/clustering/dbscan.py index 2124d787c1..238d9338ed 100644 --- a/src/distilabel/steps/clustering/dbscan.py +++ b/src/distilabel/steps/clustering/dbscan.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from 
sklearn.cluster import DBSCAN as _DBSCAN - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class DBSCAN(GlobalStep): diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py index 925ffab229..06358a6189 100644 --- a/src/distilabel/steps/clustering/text_clustering.py +++ b/src/distilabel/steps/clustering/text_clustering.py @@ -28,7 +28,7 @@ from distilabel.utils.itertools import batched if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class TextClustering(TextClassification, GlobalTask): diff --git a/src/distilabel/steps/clustering/umap.py b/src/distilabel/steps/clustering/umap.py index 9bf71c68e3..2688088c6f 100644 --- a/src/distilabel/steps/clustering/umap.py +++ b/src/distilabel/steps/clustering/umap.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from umap import UMAP as _UMAP - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class UMAP(GlobalStep): diff --git a/src/distilabel/steps/columns/combine.py b/src/distilabel/steps/columns/combine.py index 784beffe47..cd08303ac8 100644 --- a/src/distilabel/steps/columns/combine.py +++ b/src/distilabel/steps/columns/combine.py @@ -19,7 +19,7 @@ from distilabel.steps.columns.utils import merge_distilabel_metadata if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class CombineOutputs(Step): diff --git a/src/distilabel/steps/columns/expand.py b/src/distilabel/steps/columns/expand.py index 989924cf8a..aae1c336e7 100644 --- a/src/distilabel/steps/columns/expand.py +++ b/src/distilabel/steps/columns/expand.py @@ -22,7 +22,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class ExpandColumns(Step): diff --git a/src/distilabel/steps/columns/group.py b/src/distilabel/steps/columns/group.py index 4cc77b50f0..ed9ee7a2df 100644 --- a/src/distilabel/steps/columns/group.py +++ b/src/distilabel/steps/columns/group.py @@ -21,7 +21,7 @@ from distilabel.steps.columns.utils import group_columns if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class GroupColumns(Step): diff --git a/src/distilabel/steps/columns/keep.py b/src/distilabel/steps/columns/keep.py index c12dfdd61d..0835cd834c 100644 --- a/src/distilabel/steps/columns/keep.py +++ b/src/distilabel/steps/columns/keep.py @@ -19,7 +19,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class KeepColumns(Step): diff --git a/src/distilabel/steps/columns/merge.py b/src/distilabel/steps/columns/merge.py index 54ab3e3c75..0f2bb66f2f 100644 --- a/src/distilabel/steps/columns/merge.py +++ b/src/distilabel/steps/columns/merge.py @@ -20,7 +20,7 @@ from distilabel.steps.columns.utils import merge_columns if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class MergeColumns(Step): diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py index 3e84df66f2..9bcc6f2dcb 100644 --- a/src/distilabel/steps/decorator.py +++ b/src/distilabel/steps/decorator.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from distilabel.steps.base import _Step - 
from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput + from distilabel.typing import GeneratorStepOutput, StepColumns, StepOutput _STEP_MAPPING = { "normal": Step, diff --git a/src/distilabel/steps/embeddings/embedding_generation.py b/src/distilabel/steps/embeddings/embedding_generation.py index 0aeed03102..5e2a839f69 100644 --- a/src/distilabel/steps/embeddings/embedding_generation.py +++ b/src/distilabel/steps/embeddings/embedding_generation.py @@ -18,7 +18,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class EmbeddingGeneration(Step): diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py index a962ca3b14..ab33be2a4d 100644 --- a/src/distilabel/steps/embeddings/nearest_neighbour.py +++ b/src/distilabel/steps/embeddings/nearest_neighbour.py @@ -23,7 +23,7 @@ from distilabel.steps import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class FaissNearestNeighbour(GlobalStep): diff --git a/src/distilabel/steps/filtering/embedding.py b/src/distilabel/steps/filtering/embedding.py index cb1e710374..4572bca5cf 100644 --- a/src/distilabel/steps/filtering/embedding.py +++ b/src/distilabel/steps/filtering/embedding.py @@ -23,7 +23,7 @@ from distilabel.steps.base import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class EmbeddingDedup(GlobalStep): diff --git a/src/distilabel/steps/filtering/minhash.py b/src/distilabel/steps/filtering/minhash.py index 7e86d30543..3d89b11e50 100644 --- a/src/distilabel/steps/filtering/minhash.py +++ b/src/distilabel/steps/filtering/minhash.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from datasketch import MinHash, MinHashLSH - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput # Copied from: https://github.com/huggingface/datatrove/blob/main/src/datatrove/utils/text.py#L89C1-L95C65 diff --git a/src/distilabel/steps/formatting/conversation.py b/src/distilabel/steps/formatting/conversation.py index 29381521bd..0101aec196 100644 --- a/src/distilabel/steps/formatting/conversation.py +++ b/src/distilabel/steps/formatting/conversation.py @@ -17,7 +17,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class ConversationTemplate(Step): diff --git a/src/distilabel/steps/formatting/dpo.py b/src/distilabel/steps/formatting/dpo.py index 72253eb194..528abbbb87 100644 --- a/src/distilabel/steps/formatting/dpo.py +++ b/src/distilabel/steps/formatting/dpo.py @@ -18,7 +18,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class FormatTextGenerationDPO(Step): diff --git a/src/distilabel/steps/formatting/sft.py b/src/distilabel/steps/formatting/sft.py index 2793b212e6..6122ead0d1 100644 --- a/src/distilabel/steps/formatting/sft.py +++ b/src/distilabel/steps/formatting/sft.py @@ -18,7 +18,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput 
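The long run of `steps/*` hunks above and below repeats the same one-line rewrite of the `TYPE_CHECKING` import. For user code, the upshot is that a custom step now needs a single import for its column and output annotations; a short sketch under that assumption (the `UppercaseText` step is hypothetical, not part of this patch):

```python
from typing import TYPE_CHECKING

from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
    from distilabel.typing import StepColumns, StepOutput


class UppercaseText(Step):
    """Hypothetical step that upper-cases a `text` column."""

    @property
    def inputs(self) -> "StepColumns":
        return ["text"]

    @property
    def outputs(self) -> "StepColumns":
        return ["text"]

    def process(self, inputs: StepInput) -> "StepOutput":
        # Mutate the batch in place and yield it back to the pipeline.
        for row in inputs:
            row["text"] = row["text"].upper()
        yield inputs
```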
class FormatTextGenerationSFT(Step): diff --git a/src/distilabel/steps/generators/data.py b/src/distilabel/steps/generators/data.py index 803ee35eac..3b43e97c8f 100644 --- a/src/distilabel/steps/generators/data.py +++ b/src/distilabel/steps/generators/data.py @@ -20,7 +20,7 @@ from distilabel.steps.base import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import GeneratorStepOutput class LoadDataFromDicts(GeneratorStep): diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py index 721b3d4081..6c3b821a33 100644 --- a/src/distilabel/steps/generators/huggingface.py +++ b/src/distilabel/steps/generators/huggingface.py @@ -47,7 +47,7 @@ from distilabel.steps.base import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import GeneratorStepOutput T = TypeVar("T") diff --git a/src/distilabel/steps/globals/huggingface.py b/src/distilabel/steps/globals/huggingface.py index 82e7f35ab6..e9723f520d 100644 --- a/src/distilabel/steps/globals/huggingface.py +++ b/src/distilabel/steps/globals/huggingface.py @@ -23,7 +23,7 @@ from distilabel.steps.base import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class PushToHub(GlobalStep): diff --git a/src/distilabel/steps/reward_model.py b/src/distilabel/steps/reward_model.py index 0af5d5cfdd..87fef02264 100644 --- a/src/distilabel/steps/reward_model.py +++ b/src/distilabel/steps/reward_model.py @@ -25,8 +25,7 @@ import torch from transformers import PreTrainedModel, PreTrainedTokenizer - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, StepColumns, StepOutput class RewardModelScore(Step, CudaDevicePlacementMixin): diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index f542aea232..977e663992 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -16,7 +16,7 @@ from distilabel.steps.tasks.apigen.generator import APIGenGenerator from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller -from distilabel.steps.tasks.base import GeneratorTask, Task +from distilabel.steps.tasks.base import GeneratorTask, ImageTask, Task from distilabel.steps.tasks.clair import CLAIR from distilabel.steps.tasks.complexity_scorer import ComplexityScorer from distilabel.steps.tasks.decorator import task @@ -29,6 +29,7 @@ from distilabel.steps.tasks.evol_quality.base import EvolQuality from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings from distilabel.steps.tasks.genstruct import Genstruct +from distilabel.steps.tasks.image_generation import ImageGeneration from distilabel.steps.tasks.improving_text_embeddings import ( BitextRetrievalGenerator, EmbeddingTaskGenerator, @@ -55,9 +56,9 @@ from distilabel.steps.tasks.text_classification import TextClassification from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage -from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback from distilabel.steps.tasks.urial import URIAL +from distilabel.typing 
import ChatItem, ChatType __all__ = [ "CLAIR", @@ -86,18 +87,20 @@ "GenerateTextRetrievalData", "GeneratorTask", "Genstruct", + "ImageGeneration", + "ImageTask", "InstructionBacktranslation", "Magpie", "MagpieGenerator", "MathShepherdCompleter", "MathShepherdGenerator", "MonolingualTripletGenerator", "PairRM", "PrometheusEval", "QualityScorer", "SelfInstruct", "StructuredGeneration", "Task", "TextClassification", "TextGeneration", "TextGenerationWithImage", diff --git a/src/distilabel/steps/tasks/apigen/execution_checker.py b/src/distilabel/steps/tasks/apigen/execution_checker.py index 7d30dd1f75..7cd597e88e 100644 --- a/src/distilabel/steps/tasks/apigen/execution_checker.py +++ b/src/distilabel/steps/tasks/apigen/execution_checker.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from types import ModuleType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class APIGenExecutionChecker(Step): diff --git a/src/distilabel/steps/tasks/apigen/generator.py b/src/distilabel/steps/tasks/apigen/generator.py index 39f202d065..941c7b3ea4 100644 --- a/src/distilabel/steps/tasks/apigen/generator.py +++ b/src/distilabel/steps/tasks/apigen/generator.py @@ -26,8 +26,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT_API_GEN: Final[str] = """\ diff --git a/src/distilabel/steps/tasks/apigen/semantic_checker.py b/src/distilabel/steps/tasks/apigen/semantic_checker.py index c5cf0b183b..c5e7582313 100644 --- a/src/distilabel/steps/tasks/apigen/semantic_checker.py +++ b/src/distilabel/steps/tasks/apigen/semantic_checker.py @@ -24,8 +24,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT_SEMANTIC_CHECKER: Final[str] = """\ diff --git a/src/distilabel/steps/tasks/apigen/utils.py b/src/distilabel/steps/tasks/apigen/utils.py index 85ff0b764c..7e07997b06 100644 --- a/src/distilabel/steps/tasks/apigen/utils.py +++ b/src/distilabel/steps/tasks/apigen/utils.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from types import ModuleType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class PrepareExamples(Step): diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py index 1888087e8d..c3fae412c0 100644 --- a/src/distilabel/steps/tasks/argilla_labeller.py +++ b/src/distilabel/steps/tasks/argilla_labeller.py @@ -40,8 +40,7 @@ TextQuestion, ) - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepOutput + from distilabel.typing import ChatType, StepOutput class ArgillaLabeller(Task): diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index ae19a1038f..3a575545d1 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -22,6 +22,7 @@ from distilabel.constants import DISTILABEL_METADATA_KEY from distilabel.errors import DistilabelUserError from distilabel.mixins.runtime_parameters
import RuntimeParameter +from distilabel.models.image_generation.base import ImageGenerationModel from distilabel.models.llms.base import LLM from distilabel.steps.base import ( GeneratorStep, @@ -33,9 +34,13 @@ from distilabel.utils.dicts import group_dicts if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput, LLMStatistics - from distilabel.steps.tasks.typing import ChatType, FormattedInput - from distilabel.steps.typing import StepOutput + from distilabel.typing import ( + ChatType, + FormattedInput, + GenerateOutput, + LLMStatistics, + StepOutput, + ) class _Task(_Step, ABC): @@ -491,6 +496,102 @@ class GlobalTask(_Task, GlobalStep): pass +class ImageTask(_Task, Step): + """`ImageTask` is a class that implements the `_Task` abstract class and adds the `Step` + interface to be used as a step in the pipeline. It differs from the `Task` in that it's + expected to work with `ImageGenerationModel`s instead of `LLM`s. + + Attributes: + image_generation_model: the `ImageGenerationModel` to be used to generate the outputs. + llm: this attribute is only present to comply with the `_Task` interface; it is set + internally to the `image_generation_model`. + group_generations: whether to group the `num_generations` generated per input in + a list or create a row per generation. Defaults to `False`. + num_generations: The number of generations to be produced per input. + """ + + llm: Union[LLM, ImageGenerationModel, None] = None + image_generation_model: ImageGenerationModel + + def model_post_init(self, __context: Any) -> None: + assert self.llm is None, ( + "`ImageTask` cannot use an `LLM` attribute given by the user, pass " + "the `image_generation_model` attribute instead." + ) + self.llm = self.image_generation_model + # Call the post init from `Step` directly, as we don't want the task-specific + # behaviour, which may need to deal with attributes of the `LLM` that are not + # present in the `ImageGenerationModel` + super(Step, self).model_post_init(__context) + + @abstractmethod + def format_input(self, input: dict[str, Any]) -> str: + """Abstract method to format the inputs of the task. It needs to receive an input + as a Python dictionary, and generate a string to be used as the prompt for the model.""" + pass + + def _format_inputs(self, inputs: list[dict[str, Any]]) -> List["FormattedInput"]: + """Formats the inputs of the task using the `format_input` method. + + Args: + inputs: A list of Python dictionaries with the inputs of the task. + + Returns: + A list containing the formatted inputs, which in this case are the prompt + strings produced by `format_input`. + """ + return [self.format_input(input) for input in inputs] + + def _format_outputs( + self, + outputs: list[Union[str, None]], + input: Union[Dict[str, Any], None] = None, + ) -> List[Dict[str, Any]]: + """Formats the outputs of the task using the `format_output` method. If the output + is `None` (i.e. the model failed to generate a response), then the outputs will be + set to `None` as well. + + Args: + outputs: The outputs (`n` generations) for the provided `input`. + input: The input used to generate the output. + + Returns: + A list containing a dictionary with the outputs of the task for each input.
+ """ + inputs = [None] if input is None else [input] + formatted_outputs = [] + + for output, input in zip(outputs, inputs * len(outputs)): # type: ignore + try: + formatted_output = self.format_output(output, input) + formatted_output = self._create_metadata( + formatted_output, + output, + input, + add_raw_output=self.add_raw_output, # type: ignore + add_raw_input=self.add_raw_input, # type: ignore + statistics=None, + ) + formatted_outputs.append(formatted_output) + except Exception as e: + self._logger.warning( # type: ignore + f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore + ) + formatted_outputs.append(self._output_on_failure(output, input)) + return formatted_outputs + + @abstractmethod + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + """Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`. + + Args: + inputs: A list of Python dictionaries with the inputs of the task. + + Yields: + A list of Python dictionaries with the outputs of the task. + """ + pass + + def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": """Transforms the GenerateOutput statistics to have the same length as the generations. diff --git a/src/distilabel/steps/tasks/clair.py b/src/distilabel/steps/tasks/clair.py index 524a1d76c9..b619ef9dbb 100644 --- a/src/distilabel/steps/tasks/clair.py +++ b/src/distilabel/steps/tasks/clair.py @@ -21,8 +21,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT: Final[str] = ( diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py index bd8a99c6b0..d36c7f1d07 100644 --- a/src/distilabel/steps/tasks/complexity_scorer.py +++ b/src/distilabel/steps/tasks/complexity_scorer.py @@ -30,7 +30,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType _PARSE_SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+)", re.IGNORECASE) diff --git a/src/distilabel/steps/tasks/decorator.py b/src/distilabel/steps/tasks/decorator.py index 8862734f8c..c9752f247c 100644 --- a/src/distilabel/steps/tasks/decorator.py +++ b/src/distilabel/steps/tasks/decorator.py @@ -20,10 +20,10 @@ from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import FormattedInput +from distilabel.typing import FormattedInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns + from distilabel.typing import StepColumns TaskFormattingOutputFunc = Callable[..., Dict[str, Any]] diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py index f1a44d6a84..eae066c690 100644 --- a/src/distilabel/steps/tasks/evol_instruct/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/base.py @@ -23,12 +23,11 @@ from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task from distilabel.steps.tasks.evol_instruct.utils import MUTATION_TEMPLATES -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.typing import StepOutput + from
distilabel.typing import LLMStatistics, StepOutput class EvolInstruct(Task): diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index 6f985464eb..415654ba12 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -33,9 +33,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import ChatType, GeneratorStepOutput, LLMStatistics class EvolInstructGenerator(GeneratorTask): diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py index 8ea7061105..41441381df 100644 --- a/src/distilabel/steps/tasks/evol_quality/base.py +++ b/src/distilabel/steps/tasks/evol_quality/base.py @@ -23,10 +23,10 @@ from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task from distilabel.steps.tasks.evol_quality.utils import MUTATION_TEMPLATES -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class EvolQuality(Task): diff --git a/src/distilabel/steps/tasks/generate_embeddings.py b/src/distilabel/steps/tasks/generate_embeddings.py index f73ee1b2b3..bedc8c5419 100644 --- a/src/distilabel/steps/tasks/generate_embeddings.py +++ b/src/distilabel/steps/tasks/generate_embeddings.py @@ -20,8 +20,7 @@ from distilabel.utils.chat import is_openai_format if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, StepColumns, StepOutput class GenerateEmbeddings(Step): diff --git a/src/distilabel/steps/tasks/genstruct.py b/src/distilabel/steps/tasks/genstruct.py index 2b9c307d5b..e63a75f704 100644 --- a/src/distilabel/steps/tasks/genstruct.py +++ b/src/distilabel/steps/tasks/genstruct.py @@ -28,7 +28,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType _PARSE_GENSTRUCT_OUTPUT_REGEX = r"(.+?)\[\[\[Assistant\]\]\](.+)$" diff --git a/src/distilabel/steps/tasks/image_generation.py b/src/distilabel/steps/tasks/image_generation.py new file mode 100644 index 0000000000..3484b90058 --- /dev/null +++ b/src/distilabel/steps/tasks/image_generation.py @@ -0,0 +1,188 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import hashlib +from typing import TYPE_CHECKING, Any + +from distilabel.models.image_generation.utils import image_from_str +from distilabel.steps.base import StepInput +from distilabel.steps.tasks.base import ImageTask + +if TYPE_CHECKING: + from distilabel.typing import StepColumns, StepOutput + + +class ImageGeneration(ImageTask): + """Image generation with a text-to-image model given a prompt. + + `ImageGeneration` is a pre-defined task that allows generating images from a prompt. + It works with any of the models defined under `distilabel.models.image_generation`, + i.e. the implemented models that allow image generation. + By default, the images are generated as base64 strings, and after the dataset + has been generated, the images can be automatically transformed to `PIL.Image.Image` using + `Distiset.transform_columns_to_image`. Take a look at the `Image Generation with distilabel` + example in the documentation for more information. + Using the `save_artifacts` attribute, the images can be saved in the artifacts folder of the + Hugging Face Hub repository. + + Attributes: + save_artifacts: whether to save the generated images in the artifacts folder. + Otherwise, the base64 representation of each image will be saved as + a string. Defaults to `False`. + image_format: Any of the formats supported by PIL. Defaults to `JPEG`. + + Input columns: + - prompt (`str`): A column named `prompt` with the prompts to generate the images from. + + Output columns: + - image (`str`): The generated image. Initially it is a base64 string, for simplicity + during the pipeline run, but it can be transformed to an Image object after the + distiset is returned at the end of a pipeline by calling + `distiset.transform_columns_to_image()`. + - image_path (`str`): The path where the image is saved. Only available if `save_artifacts` + is True. + - model_name (`str`): The name of the model used to generate the image. + + Categories: + - image-generation + + Examples: + Generate an image from a prompt: + + ```python + from distilabel.steps.tasks import ImageGeneration + from distilabel.models.image_generation import InferenceEndpointsImageGeneration + + igm = InferenceEndpointsImageGeneration( + model_id="black-forest-labs/FLUX.1-schnell" + ) + + # `save_artifacts` defaults to False, so the image will be kept as a base64 string. + image_gen = ImageGeneration(image_generation_model=igm) + + image_gen.load() + + result = next( + image_gen.process( + [{"prompt": "a white siamese cat"}] + ) + ) + ``` + + Generate images and save them as artifacts in a Hugging Face Hub repository: + + ```python + from distilabel.steps.tasks import ImageGeneration + # Select the Image Generation model to use + from distilabel.models.image_generation import OpenAIImageGeneration + + igm = OpenAIImageGeneration( + model="dall-e-3", + api_key="api.key", + generation_kwargs={ + "size": "1024x1024", + "quality": "standard", + "style": "natural" + } + ) + + # Save the images as artifacts in the repository instead of keeping base64 strings. + image_gen = ImageGeneration( + image_generation_model=igm, + save_artifacts=True, + image_format="JPEG" # Defaults to JPEG; the available options can be seen in the PIL documentation.
+ ) + + image_gen.load() + + result = next( + image_gen.process( + [{"prompt": "a white siamese cat"}] + ) + ) + ``` + """ + + save_artifacts: bool = False + image_format: str = "JPEG" + + @property + def inputs(self) -> "StepColumns": + return ["prompt"] + + @property + def outputs(self) -> "StepColumns": + return { + "image": True, + "image_path": False, + "model_name": True, + } + + def format_input(self, input: dict[str, Any]) -> str: + return input["prompt"] + + def format_output( + self, output: dict[str, Any], input: dict[str, Any] + ) -> dict[str, Any]: + image = None + if img_str := output.get("images"): + image = img_str[0] # Grab only the first image + + return {"image": image, "model_name": self.llm.model_name} + + def save(self, **kwargs): + if not self.save_artifacts: + super().save(**kwargs) + + def process(self, inputs: StepInput) -> "StepOutput": + formatted_inputs = self._format_inputs(inputs) + + outputs = self.llm.generate_outputs( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.get_generation_kwargs(), + ) + + task_outputs = [] + for input, input_outputs in zip(inputs, outputs): + formatted_outputs = self._format_outputs(input_outputs, input) + for formatted_output in formatted_outputs: + if self.save_artifacts and ( + image := formatted_output.get("image", None) + ): + # use the hash of the prompt as the filename + prompt_hash = hashlib.md5(input["prompt"].encode()).hexdigest() + # Build the PIL image to save it + image = image_from_str(image) + + self.save_artifact( + name="images", + write_function=lambda path, + prompt_hash=prompt_hash, + img=image: img.save( + path / f"{prompt_hash}.{self.image_format.lower()}", + format=self.image_format, + ), + metadata={"type": "image"}, + ) + formatted_output["image_path"] = ( + f"artifacts/{self.name}/images/{prompt_hash}.{self.image_format.lower()}" + ) + + task_outputs.append( + {**input, **formatted_output, "model_name": self.llm.model_name} + ) + yield task_outputs diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py index d806e3aded..8569c12810 100644 --- a/src/distilabel/steps/tasks/improving_text_embeddings.py +++ b/src/distilabel/steps/tasks/improving_text_embeddings.py @@ -23,8 +23,7 @@ from typing_extensions import override from distilabel.steps.tasks.base import GeneratorTask, Task -from distilabel.steps.tasks.typing import ChatType -from distilabel.steps.typing import GeneratorStepOutput +from distilabel.typing import ChatType, GeneratorStepOutput # BASE CLASSES diff --git a/src/distilabel/steps/tasks/instruction_backtranslation.py b/src/distilabel/steps/tasks/instruction_backtranslation.py index a0420ef8f3..52405406ed 100644 --- a/src/distilabel/steps/tasks/instruction_backtranslation.py +++ b/src/distilabel/steps/tasks/instruction_backtranslation.py @@ -26,7 +26,7 @@ from pydantic import PrivateAttr from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType class InstructionBacktranslation(Task): diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 265497409c..13c4f9f0be 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -30,9 +30,7 @@ from distilabel.utils.dicts import merge_dicts if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from
distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, LLMStatistics, StepColumns, StepOutput MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index c9d18d9fca..c1fcd14828 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -24,8 +24,7 @@ from distilabel.steps.tasks.magpie.base import MagpieBase if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import GeneratorStepOutput, StepColumns + from distilabel.typing import ChatType, GeneratorStepOutput, StepColumns class MagpieGenerator(GeneratorTask, MagpieBase): diff --git a/src/distilabel/steps/tasks/math_shepherd/completer.py b/src/distilabel/steps/tasks/math_shepherd/completer.py index 5d3fdd7e15..05ff410ac5 100644 --- a/src/distilabel/steps/tasks/math_shepherd/completer.py +++ b/src/distilabel/steps/tasks/math_shepherd/completer.py @@ -28,9 +28,7 @@ from distilabel.utils.itertools import batched if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, LLMStatistics, StepColumns, StepOutput SYSTEM_PROMPT = """\ diff --git a/src/distilabel/steps/tasks/math_shepherd/generator.py b/src/distilabel/steps/tasks/math_shepherd/generator.py index d9ab565e54..efcd986549 100644 --- a/src/distilabel/steps/tasks/math_shepherd/generator.py +++ b/src/distilabel/steps/tasks/math_shepherd/generator.py @@ -25,8 +25,7 @@ ) if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT = """\ diff --git a/src/distilabel/steps/tasks/math_shepherd/utils.py b/src/distilabel/steps/tasks/math_shepherd/utils.py index 978496996f..8a04f325b5 100644 --- a/src/distilabel/steps/tasks/math_shepherd/utils.py +++ b/src/distilabel/steps/tasks/math_shepherd/utils.py @@ -20,7 +20,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput def split_solution_steps(text: str) -> list[str]: diff --git a/src/distilabel/steps/tasks/pair_rm.py b/src/distilabel/steps/tasks/pair_rm.py index 23262a533f..4def62615f 100644 --- a/src/distilabel/steps/tasks/pair_rm.py +++ b/src/distilabel/steps/tasks/pair_rm.py @@ -20,7 +20,7 @@ from distilabel.steps.tasks.base import Step if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class PairRM(Step): diff --git a/src/distilabel/steps/tasks/prometheus_eval.py b/src/distilabel/steps/tasks/prometheus_eval.py index 4c61c416be..36e8e6ac9f 100644 --- a/src/distilabel/steps/tasks/prometheus_eval.py +++ b/src/distilabel/steps/tasks/prometheus_eval.py @@ -30,7 +30,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType _DEFAULT_RUBRICS = { diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py index efafda2b7a..81dc0c1632 100644 --- a/src/distilabel/steps/tasks/quality_scorer.py +++ 
b/src/distilabel/steps/tasks/quality_scorer.py @@ -28,7 +28,7 @@ from typing_extensions import override from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType _PARSE_SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+)", re.IGNORECASE) diff --git a/src/distilabel/steps/tasks/self_instruct.py b/src/distilabel/steps/tasks/self_instruct.py index dcca46ee67..2b36740c7a 100644 --- a/src/distilabel/steps/tasks/self_instruct.py +++ b/src/distilabel/steps/tasks/self_instruct.py @@ -27,7 +27,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType class SelfInstruct(Task): diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py index 350849e3d0..a4c2feb20b 100644 --- a/src/distilabel/steps/tasks/sentence_transformers.py +++ b/src/distilabel/steps/tasks/sentence_transformers.py @@ -28,7 +28,7 @@ import importlib.resources as importlib_resources if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"] diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py index 905a6672d0..92eb3fd9e1 100644 --- a/src/distilabel/steps/tasks/structured_generation.py +++ b/src/distilabel/steps/tasks/structured_generation.py @@ -17,7 +17,7 @@ from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import StructuredInput +from distilabel.typing import StructuredInput class StructuredGeneration(Task): diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index a5aceacb3b..45b5fe7494 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -39,7 +39,7 @@ from transformers import Pipeline # noqa from vllm import LLM as _vLLM # noqa - from distilabel.steps.tasks.typing import OutlinesStructuredOutputType # noqa + from distilabel.typing import OutlinesStructuredOutputType # noqa Frameworks = Literal["transformers", "llamacpp", "vllm"] diff --git a/src/distilabel/steps/tasks/text_classification.py b/src/distilabel/steps/tasks/text_classification.py index 19df530fb6..ec032241f5 100644 --- a/src/distilabel/steps/tasks/text_classification.py +++ b/src/distilabel/steps/tasks/text_classification.py @@ -23,7 +23,7 @@ from distilabel.steps.tasks import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType TEXT_CLASSIFICATION_TEMPLATE: str = """\ diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index b6620430cc..59cf932423 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -23,8 +23,7 @@ from distilabel.utils.template import check_column_in_template if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns class TextGeneration(Task): diff --git a/src/distilabel/steps/tasks/text_generation_with_image.py b/src/distilabel/steps/tasks/text_generation_with_image.py index 
8494afc9db..8aee386f80 100644 --- a/src/distilabel/steps/tasks/text_generation_with_image.py +++ b/src/distilabel/steps/tasks/text_generation_with_image.py @@ -28,8 +28,7 @@ if TYPE_CHECKING: from PIL.Image import Image - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns class TextGenerationWithImage(TextGeneration): diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py index bac144f54d..1139254abb 100644 --- a/src/distilabel/steps/tasks/ultrafeedback.py +++ b/src/distilabel/steps/tasks/ultrafeedback.py @@ -22,7 +22,7 @@ from typing_extensions import override from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType from distilabel.utils.dicts import group_dicts diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index 24b643ada6..b49c5d9f36 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -20,8 +20,7 @@ from distilabel.steps.tasks import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns class URIAL(Task): diff --git a/src/distilabel/typing.py b/src/distilabel/typing/__init__.py similarity index 72% rename from src/distilabel/typing.py rename to src/distilabel/typing/__init__.py index a3d65d5d75..ec65f26878 100644 --- a/src/distilabel/typing.py +++ b/src/distilabel/typing/__init__.py @@ -12,26 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from distilabel.models.llms.typing import GenerateOutput -from distilabel.pipeline.typing import ( - DownstreamConnectable, - DownstreamConnectableSteps, - InputDataset, - PipelineRuntimeParametersInfo, - StepLoadStatus, - UpstreamConnectableSteps, -) -from distilabel.steps.tasks.typing import ( +from distilabel.typing.base import ( ChatItem, ChatType, + ImageContent, + ImageUrl, + TextContent, +) +from distilabel.typing.models import ( FormattedInput, + GenerateOutput, + HiddenState, InstructorStructuredOutputType, + LLMLogprobs, + LLMOutput, + LLMStatistics, + Logprob, OutlinesStructuredOutputType, StandardInput, StructuredInput, StructuredOutputType, + TokenCount, +) +from distilabel.typing.pipeline import ( + DownstreamConnectable, + DownstreamConnectableSteps, + InputDataset, + LoadGroups, + PipelineRuntimeParametersInfo, + StepLoadStatus, + UpstreamConnectableSteps, ) -from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput +from distilabel.typing.steps import GeneratorStepOutput, StepColumns, StepOutput __all__ = [ "ChatItem", @@ -41,8 +53,16 @@ "FormattedInput", "GenerateOutput", "GeneratorStepOutput", + "HiddenState", + "ImageContent", + "ImageUrl", "InputDataset", "InstructorStructuredOutputType", + "LLMLogprobs", + "LLMOutput", + "LLMStatistics", + "LoadGroups", + "Logprob", "OutlinesStructuredOutputType", "PipelineRuntimeParametersInfo", "StandardInput", @@ -51,5 +71,7 @@ "StepOutput", "StructuredInput", "StructuredOutputType", + "TextContent", + "TokenCount", "UpstreamConnectableSteps", ] diff --git a/src/distilabel/typing/base.py b/src/distilabel/typing/base.py new file mode 100644 index 0000000000..16645c0957 --- /dev/null +++ b/src/distilabel/typing/base.py @@ -0,0 +1,46 @@ +# Copyright 2023-present, Argilla, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal, Union + +from typing_extensions import Required, TypedDict + + +class TextContent(TypedDict, total=False): + type: Required[Literal["text"]] + text: Required[str] + + +class ImageUrl(TypedDict): + url: Required[str] + """Either a URL of the image or the base64 encoded image data.""" + + +class ImageContent(TypedDict, total=False): + """Type alias for the user's message in a conversation that can include text or an image. + It's the standard type for vision language models: + https://platform.openai.com/docs/guides/vision + """ + + type: Required[Literal["image_url"]] + image_url: Required[ImageUrl] + + +class ChatItem(TypedDict): + role: Literal["system", "user", "assistant"] + content: Union[str, list[Union[TextContent, ImageContent]]] + + +ChatType = List[ChatItem] +"""ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format.""" diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/typing/models.py similarity index 66% rename from src/distilabel/steps/tasks/typing.py rename to src/distilabel/typing/models.py index d0d22a6811..aa11305421 100644 --- a/src/distilabel/steps/tasks/typing.py +++ b/src/distilabel/typing/models.py @@ -12,39 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from pydantic import BaseModel -from typing_extensions import Required, TypedDict +from typing_extensions import NotRequired, TypedDict +from distilabel.typing.base import ChatType -class TextContent(TypedDict, total=False): - type: Required[Literal["text"]] - text: Required[str] +LLMOutput = List[Union[str, None]] -class ImageUrl(TypedDict): - url: Required[str] - """Either a URL of the image or the base64 encoded image data.""" +class Logprob(TypedDict): + token: str + logprob: float -class ImageContent(TypedDict, total=False): - """Type alias for the user's message in a conversation that can include text or an image. - It's the standard type for vision language models: - https://platform.openai.com/docs/guides/vision - """ +LLMLogprobs = List[List[List[Logprob]]] +"""A type alias representing the probability distributions output by an `LLM`. 
+ +Structure: + - Outermost list: contains multiple generation choices when sampling (`n` sequences) + - Middle list: represents each position in the generated sequence + - Innermost list: contains the log probabilities for each token in the vocabulary at that position +""" - type: Required[Literal["image_url"]] - image_url: Required[ImageUrl] +class TokenCount(TypedDict): + input_tokens: List[int] + output_tokens: List[int] -class ChatItem(TypedDict): - role: Literal["system", "user", "assistant"] - content: Union[str, list[Union[TextContent, ImageContent]]] +LLMStatistics = Union[TokenCount, Dict[str, Any]] +"""Initially the LLMStatistics will contain the token count, but can have more variables. +They can be added once we have them defined for every LLM. +""" -ChatType = List[ChatItem] -"""ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format.""" + +class GenerateOutput(TypedDict): + generations: LLMOutput + statistics: LLMStatistics + logprobs: NotRequired[LLMLogprobs] class OutlinesStructuredOutputType(TypedDict, total=False): @@ -84,11 +102,21 @@ class InstructorStructuredOutputType(TypedDict, total=False): OutlinesStructuredOutputType, InstructorStructuredOutputType ] """StructuredOutputType is an alias for the union of `OutlinesStructuredOutputType` and `InstructorStructuredOutputType`.""" - StandardInput = ChatType """StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`.""" StructuredInput = Tuple[StandardInput, Union[StructuredOutputType, None]] """StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it.""" -FormattedInput = Union[StandardInput, StructuredInput, ChatType] +FormattedInput = Union[StandardInput, StructuredInput] """FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s, as well as `ConversationType` for the vision language models.""" + + +if TYPE_CHECKING: + from numpy import floating + from numpy.typing import NDArray + + GenericFloat = TypeVar("GenericFloat", bound=floating[Any]) + + HiddenState = NDArray[GenericFloat] +else: + HiddenState = Any diff --git a/src/distilabel/pipeline/typing.py b/src/distilabel/typing/pipeline.py similarity index 98% rename from src/distilabel/pipeline/typing.py rename to src/distilabel/typing/pipeline.py index 3e796948aa..3824cbf116 100644 --- a/src/distilabel/pipeline/typing.py +++ b/src/distilabel/typing/pipeline.py @@ -18,11 +18,12 @@ Dict, List, Literal, - TypedDict, TypeVar, Union, ) +from typing_extensions import TypedDict + if TYPE_CHECKING: import pandas as pd from datasets import Dataset diff --git a/src/distilabel/steps/typing.py b/src/distilabel/typing/steps.py similarity index 100% rename from src/distilabel/steps/typing.py rename to src/distilabel/typing/steps.py diff --git a/src/distilabel/utils/export_components_info.py b/src/distilabel/utils/export_components_info.py index 00144fd041..dcf9e0ecd7 100644 --- a/src/distilabel/utils/export_components_info.py +++ b/src/distilabel/utils/export_components_info.py @@ -16,6 +16,7 @@ from typing import Generator, List, Type, TypedDict, TypeVar from distilabel.models.embeddings.base import Embeddings +from distilabel.models.image_generation.base import ImageGenerationModel from distilabel.models.llms.base import LLM from distilabel.steps.base import _Step from distilabel.steps.tasks.base import _Task @@ -28,6 +29,7 @@ class 
ComponentsInfo(TypedDict): """A dictionary containing `distilabel` components information.""" llms: List + image_generation_models: List steps: List tasks: List embeddings: List @@ -55,6 +57,10 @@ def export_components_info() -> ComponentsInfo: {"name": llm_type.__name__, "docstring": parse_google_docstring(llm_type)} for llm_type in _get_llms() ], + "image_generation_models": [ + {"name": igm_type.__name__, "docstring": parse_google_docstring(igm_type)} + for igm_type in _get_image_generation_models() + ], "embeddings": [ { "name": embeddings_type.__name__, @@ -113,6 +119,19 @@ def _get_llms() -> List[Type["LLM"]]: ] + +def _get_image_generation_models() -> List[Type["ImageGenerationModel"]]: + """Get all `ImageGenerationModel` subclasses, that are not abstract classes. + + Returns: + The list of all the classes under `distilabel.models.image_generation` that are not abstract classes. + """ + return [ + igm_type + for igm_type in _recursive_subclasses(ImageGenerationModel) + if not inspect.isabstract(igm_type) + ] + + def _get_embeddings() -> List[Type["Embeddings"]]: """Get all `Embeddings` subclasses, that are not abstract classes. diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index 005f74748e..7293d90e69 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -75,6 +75,7 @@ ).read() ) + _STEPS_CATEGORY_TO_ICON = { "text-generation": ":material-text-box-edit:", "chat-generation": ":material-chat:", @@ -92,6 +93,7 @@ "load": ":material-file-download:", "execution": ":octicons-code-16:", "save": ":material-content-save:", + "image-generation": ":material-image:", "labelling": ":label:", } @@ -112,6 +114,7 @@ "load": "Load steps are used to load the data.", "execution": "Executes python functions.", "save": "Save steps are used to save the data.", + "image-generation": "Image generation steps are used to generate images based on a given prompt.", "labelling": "Labelling steps are used to label the data.", } @@ -199,6 +202,12 @@ def on_files( self.file_paths["llms"] = self._generate_llms_pages( src_dir=src_dir, llms=components_info["llms"] ) + self.file_paths["image_generation_models"] = ( + self._generate_image_generation_pages( + src_dir=src_dir, + image_generation_models=components_info["image_generation_models"], + ) + ) self.file_paths["embeddings"] = self._generate_embeddings_pages( src_dir=src_dir, embeddings=components_info["embeddings"] ) @@ -209,6 +218,7 @@ def on_files( *self.file_paths["steps"], *self.file_paths["tasks"], *self.file_paths["llms"], + *self.file_paths["image_generation_models"], *self.file_paths["embeddings"], ]: file = File( @@ -429,6 +439,48 @@ def _generate_llms_pages(self, src_dir: Path, llms: list) -> List[str]: return paths + def _generate_image_generation_pages( + self, src_dir: Path, image_generation_models: list + ) -> List[str]: + """Generates the files for the `ImageGenerationModels` subsection of the components gallery. + + Args: + src_dir: The path to the source directory. + image_generation_models: The list of `ImageGenerationModel` components. + + Returns: + The relative paths to the generated files.
+ """ + + paths = ["components-gallery/image_generation/index.md"] + steps_gallery_page_path = src_dir / paths[0] + steps_gallery_page_path.parent.mkdir(parents=True, exist_ok=True) + + # Create a detail page for each `ImageGenerationModel` + for igm in image_generation_models: + content = _LLM_DETAIL_TEMPLATE.render(llm=igm) + + igm_path = f"components-gallery/image_generation/{igm['name'].lower()}.md" + path = src_dir / igm_path + with open(path, "w") as f: + f.write(content) + + paths.append(igm_path) + + # Create the `components-gallery/image_generation/index.md` file + content = _COMPONENTS_LIST_TEMPLATE.render( + title="Image Generation Gallery", + description="", + components=image_generation_models, + component_group="image_generation_models", + default_icon=":material-image:", + ) + + with open(steps_gallery_page_path, "w") as f: + f.write(content) + + return paths + def _generate_embeddings_pages(self, src_dir: Path, embeddings: list) -> List[str]: """Generates the files for the `Embeddings` subsection of the components gallery. @@ -491,6 +543,10 @@ def on_nav( steps_file = files.get_file_from_path(self.file_paths["steps"][0]) tasks_file = files.get_file_from_path(self.file_paths["tasks"][0]) llms_file = files.get_file_from_path(self.file_paths["llms"][0]) + image_generation_file = files.get_file_from_path( + self.file_paths["image_generation_models"][0] + ) + steps_files = [ files.get_file_from_path(path) for path in self.file_paths["steps"][0:] ] tasks_files = [ files.get_file_from_path(path) for path in self.file_paths["tasks"][0:] ] llms_files = [ files.get_file_from_path(path) for path in self.file_paths["llms"][0:] ] + image_generation_files = [ + files.get_file_from_path(path) + for path in self.file_paths["image_generation_models"][0:] + ] # Create subsections steps_page = SectionPage( @@ -511,13 +571,19 @@ llms_page = SectionPage( "LLMs", file=llms_file, config=config, children=llms_files ) # type: ignore + igms_page = SectionPage( + "ImageGenerationModels", + file=image_generation_file, + config=config, + children=image_generation_files, + ) # type: ignore # Create the gallery section page = SectionPage( title=self.config.page_title, file=components_gallery_file, config=config, - children=[steps_page, tasks_page, llms_page], + children=[steps_page, tasks_page, llms_page, igms_page], ) # Add the page diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md index cc3e44aecf..96f09dcc32 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md @@ -31,6 +31,14 @@ hide: [:octicons-arrow-right-24: LLMs](llms/index.md){ .bottom } +- :material-image:{ .lg .middle } __ImageGenerationModels__ + + --- + + Explore all the available `ImageGenerationModel`s integrated with `distilabel`.
+ + [:octicons-arrow-right-24: ImageGenerationModels](image_generation/index.md){ .bottom } + - :material-vector-line:{ .lg .middle } __Embeddings__ --- diff --git a/tests/integration/test_dataset_without_step.py b/tests/integration/test_dataset_without_step.py index b71631c27e..820793355f 100644 --- a/tests/integration/test_dataset_without_step.py +++ b/tests/integration/test_dataset_without_step.py @@ -21,7 +21,7 @@ from distilabel.pipeline import Pipeline from distilabel.steps import make_generator_step from distilabel.steps.base import Step, StepInput -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput if TYPE_CHECKING: pass diff --git a/tests/integration/test_embedding_dedup.py b/tests/integration/test_embedding_dedup.py index 7806cf6761..7ff02f3d70 100644 --- a/tests/integration/test_embedding_dedup.py +++ b/tests/integration/test_embedding_dedup.py @@ -22,7 +22,7 @@ from distilabel.steps.filtering.embedding import EmbeddingDedup if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput SAMPLE_DATA = [ diff --git a/tests/integration/test_load_stages.py b/tests/integration/test_load_stages.py index 9faa771d77..fa7806a6eb 100644 --- a/tests/integration/test_load_stages.py +++ b/tests/integration/test_load_stages.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from distilabel.pipeline.batch import _Batch - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput routing_batch_function = sample_n_steps(2) diff --git a/tests/integration/test_multiple_replicas.py b/tests/integration/test_multiple_replicas.py index 26d0f19b57..210a338e35 100644 --- a/tests/integration/test_multiple_replicas.py +++ b/tests/integration/test_multiple_replicas.py @@ -22,7 +22,7 @@ from distilabel.steps import LoadDataFromDicts, StepInput, StepResources, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @step(outputs=["generation"]) diff --git a/tests/integration/test_offline_batch_generation.py b/tests/integration/test_offline_batch_generation.py index ae34d04159..7c81e94663 100644 --- a/tests/integration/test_offline_batch_generation.py +++ b/tests/integration/test_offline_batch_generation.py @@ -22,8 +22,7 @@ from distilabel.steps.tasks import TextGeneration if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput class DummyOfflineBatchGenerateLLM(LLM): diff --git a/tests/integration/test_pipe_llms.py b/tests/integration/test_pipe_llms.py index c95af1ac3f..b148d00b79 100644 --- a/tests/integration/test_pipe_llms.py +++ b/tests/integration/test_pipe_llms.py @@ -24,7 +24,7 @@ from distilabel.steps.tasks.text_generation import TextGeneration if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class RenameColumns(Step): diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py index eee334677e..fd4548700b 100644 --- a/tests/integration/test_pipe_simple.py +++ b/tests/integration/test_pipe_simple.py @@ -21,7 +21,7 @@ from distilabel.steps.generators.data import LoadDataFromDicts if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput DATA = [ {"prompt": "Tell me a joke"}, diff --git a/tests/integration/test_ray_pipeline.py 
b/tests/integration/test_ray_pipeline.py index 241232b0cf..b29c7f454b 100644 --- a/tests/integration/test_ray_pipeline.py +++ b/tests/integration/test_ray_pipeline.py @@ -22,7 +22,7 @@ from distilabel.steps.generators.data import LoadDataFromDicts if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput DATA = [ {"prompt": "Tell me a joke"}, diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py index 3a48543a8d..60e951f556 100644 --- a/tests/integration/test_routing_batch_function.py +++ b/tests/integration/test_routing_batch_function.py @@ -22,7 +22,7 @@ from distilabel.steps import LoadDataFromDicts, StepInput, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @routing_batch_function() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 32f70133a2..86c8e6c33c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,13 +20,13 @@ import pytest from pydantic import PrivateAttr +from distilabel.models.image_generation.base import AsyncImageGenerationModel from distilabel.models.llms.base import LLM, AsyncLLM from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import ChatType, FormattedInput + from distilabel.typing import ChatType, FormattedInput, GenerateOutput # Defined here too, so that the serde still works @@ -101,6 +101,29 @@ def generate( ] +class DummyAsyncImageGenerationModel(AsyncImageGenerationModel): + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + async def agenerate( # type: ignore + self, input: str, num_generations: int = 1 + ) -> list[dict[str, Any]]: + import numpy as np + from PIL import Image + + np.random.seed(42) + arr = np.random.randint(0, 255, (100, 100, 3)) + random_image = Image.fromarray(arr, "RGB") + from distilabel.models.image_generation.utils import image_to_str + + img_str = image_to_str(random_image) + return [{"images": [img_str]} for _ in range(num_generations)] + + class DummyTask(Task): @property def inputs(self) -> List[str]: diff --git a/tests/unit/models/image_generation/__init__.py b/tests/unit/models/image_generation/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/models/image_generation/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/models/image_generation/huggingface/__init__.py b/tests/unit/models/image_generation/huggingface/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/models/image_generation/huggingface/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/models/image_generation/huggingface/test_inference_endpoints.py b/tests/unit/models/image_generation/huggingface/test_inference_endpoints.py new file mode 100644 index 0000000000..2ca5eeab0d --- /dev/null +++ b/tests/unit/models/image_generation/huggingface/test_inference_endpoints.py @@ -0,0 +1,59 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import AsyncMock, MagicMock, patch + +import nest_asyncio +import numpy as np +import pytest +from PIL import Image + +from distilabel.models.image_generation.huggingface.inference_endpoints import ( + InferenceEndpointsImageGeneration, +) + + +@patch("huggingface_hub.AsyncInferenceClient") +class TestInferenceEndpointsImageGeneration: + @pytest.mark.asyncio + async def test_agenerate(self, mock_inference_client: MagicMock) -> None: + igm = InferenceEndpointsImageGeneration( + model_id="black-forest-labs/FLUX.1-schnell", + api_key="api.key", + ) + igm.load() + + arr = np.random.randint(0, 255, (100, 100, 3)) + random_image = Image.fromarray(arr, "RGB") + igm._aclient.text_to_image = AsyncMock(return_value=random_image) + + assert await igm.agenerate("Aenean hend") + + @pytest.mark.asyncio + async def test_generate(self, mock_inference_client: MagicMock) -> None: + igm = InferenceEndpointsImageGeneration( + model_id="black-forest-labs/FLUX.1-schnell", + api_key="api.key", + ) + igm.load() + + arr = np.random.randint(0, 255, (100, 100, 3)) + random_image = Image.fromarray(arr, "RGB") + igm._aclient.text_to_image = AsyncMock(return_value=random_image) + + nest_asyncio.apply() + + images = igm.generate(inputs=["Aenean hendrerit aliquam velit. ..."]) + assert images[0][0]["images"][0].startswith("/9j/4AAQSkZJRgABAQAAAQABAAD/2w") diff --git a/tests/unit/models/image_generation/test_openai.py b/tests/unit/models/image_generation/test_openai.py new file mode 100644 index 0000000000..d057d6c85f --- /dev/null +++ b/tests/unit/models/image_generation/test_openai.py @@ -0,0 +1,105 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
+import nest_asyncio
+import pytest
+
+from distilabel.models.image_generation.openai import OpenAIImageGeneration
+
+
+@patch("openai.OpenAI")
+@patch("openai.AsyncOpenAI")
+class TestOpenAIImageGeneration:
+    model_id: str = "dall-e-3"
+
+    def test_openai_image_generation(
+        self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
+    ):
+        igm = OpenAIImageGeneration(
+            model="dall-e-3",
+            api_key="api.key",
+            generation_kwargs={
+                "size": "1024x1024",
+                "quality": "standard",
+                "style": "natural",
+            },
+        )
+
+        assert isinstance(igm, OpenAIImageGeneration)
+        assert igm.model_name == self.model_id
+
+    @pytest.mark.parametrize("response_format", ["url", "b64_json"])
+    @pytest.mark.asyncio
+    async def test_agenerate(
+        self,
+        async_openai_mock: MagicMock,
+        _openai_mock: MagicMock,
+        response_format: str,
+    ) -> None:
+        igm = OpenAIImageGeneration(model=self.model_id, api_key="api.key")  # type: ignore
+        igm._aclient = async_openai_mock
+
+        with patch("requests.get") as mock_get:
+            # Mock the download of the image
+            mock_get.return_value = Mock(content=b"iVBORw0KGgoAAAANSUhEUgA...")
+            if response_format == "url":
+                mocked_response = Mock(b64_json=None, url="https://example.com")
+            else:
+                mocked_response = Mock(b64_json="iVBORw0KGgoAAAANSUhEUgA...", url=None)
+
+            mocked_generation = Mock(data=[mocked_response])
+            igm._aclient.images.generate = AsyncMock(return_value=mocked_generation)
+
+            await igm.agenerate(
+                input="a white siamese cat", response_format=response_format
+            )
+
+    @pytest.mark.parametrize("response_format", ["url", "b64_json"])
+    @pytest.mark.asyncio
+    async def test_generate(
+        self,
+        async_openai_mock: MagicMock,
+        _openai_mock: MagicMock,
+        response_format: str,
+    ) -> None:
+        igm = OpenAIImageGeneration(model=self.model_id, api_key="api.key")  # type: ignore
+        igm._aclient = async_openai_mock
+
+        with patch("requests.get") as mock_get:
+            # Mock the download of the image
+            mock_get.return_value = Mock(content=b"iVBORw0KGgoAAAANSUhEUgA...")
+
+            if response_format == "url":
+                mocked_response = Mock(b64_json=None, url="https://example.com")
+            else:
+                mocked_response = Mock(b64_json="iVBORw0KGgoAAAANSUhEUgA...", url=None)
+
+            mocked_generation = Mock(data=[mocked_response])
+            igm._aclient.images.generate = AsyncMock(return_value=mocked_generation)
+
+            nest_asyncio.apply()
+
+            igm.generate(
+                inputs=["a white siamese cat"], response_format=response_format
+            )
+
+        with pytest.raises(ValueError):
+            igm.generate(
+                inputs=[
+                    "a white siamese cat",
+                ],
+                response_format="unknown_format",
+            )
diff --git a/tests/unit/models/mixins/test_cuda_device_placement.py b/tests/unit/models/mixins/test_cuda_device_placement.py
index bdddabf83e..a20bc2098c 100644
--- a/tests/unit/models/mixins/test_cuda_device_placement.py
+++ b/tests/unit/models/mixins/test_cuda_device_placement.py
@@ -23,7 +23,7 @@
 from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin

 if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import ChatType
+    from distilabel.typing import ChatType


 @pytest.fixture
diff --git
a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 3cb680eb06..aa4da987fa 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -47,7 +47,7 @@ ) from distilabel.pipeline.write_buffer import _WriteBuffer from distilabel.steps.base import Step, StepInput, StepResources, _Step -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput from distilabel.utils.requirements import requirements from distilabel.utils.serialization import TYPE_INFO_KEY diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py index 6a6163b75e..1874a14986 100644 --- a/tests/unit/pipeline/test_dag.py +++ b/tests/unit/pipeline/test_dag.py @@ -28,7 +28,7 @@ from .utils import DummyGeneratorStep, DummyGlobalStep, DummyStep1, DummyStep2 if TYPE_CHECKING: - from distilabel.steps.typing import ( + from distilabel.typing import ( GeneratorStepOutput, StepOutput, ) diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py index cb223755aa..bc3a618adb 100644 --- a/tests/unit/pipeline/utils.py +++ b/tests/unit/pipeline/utils.py @@ -16,7 +16,7 @@ from distilabel.pipeline.batch import _Batch from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput -from distilabel.steps.typing import GeneratorStepOutput, StepOutput +from distilabel.typing import GeneratorStepOutput, StepOutput class DummyGeneratorStep(GeneratorStep): diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py index c0a452e72b..da7b971250 100644 --- a/tests/unit/steps/argilla/test_base.py +++ b/tests/unit/steps/argilla/test_base.py @@ -23,7 +23,7 @@ from distilabel.steps.base import StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class CustomArgilla(ArgillaBase): diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py index ddd473bb76..b5eb7a29dd 100644 --- a/tests/unit/steps/clustering/test_text_clustering.py +++ b/tests/unit/steps/clustering/test_text_clustering.py @@ -21,8 +21,7 @@ from tests.unit.conftest import DummyAsyncLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput class ClusteringLLM(DummyAsyncLLM): diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py index 38580c2c42..a73ed68dbf 100644 --- a/tests/unit/steps/tasks/apigen/test_generator.py +++ b/tests/unit/steps/tasks/apigen/test_generator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import random from typing import TYPE_CHECKING, List, Union @@ -21,10 +22,7 @@ from tests.unit.conftest import DummyLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput - -import json + from distilabel.typing import FormattedInput, GenerateOutput class DummyAPIGenLLM(DummyLLM): diff --git a/tests/unit/steps/tasks/math_shepherd/test_completer.py b/tests/unit/steps/tasks/math_shepherd/test_completer.py index c5e8092cd3..5283bb79c3 100644 --- a/tests/unit/steps/tasks/math_shepherd/test_completer.py +++ b/tests/unit/steps/tasks/math_shepherd/test_completer.py @@ -21,7 +21,7 @@ from tests.unit.conftest import DummyLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput + from distilabel.typing import GenerateOutput class MathShepherdCompleterLLM(DummyLLM): diff --git a/tests/unit/steps/tasks/math_shepherd/test_generator.py b/tests/unit/steps/tasks/math_shepherd/test_generator.py index 14ccc87533..6be30405f3 100644 --- a/tests/unit/steps/tasks/math_shepherd/test_generator.py +++ b/tests/unit/steps/tasks/math_shepherd/test_generator.py @@ -26,7 +26,7 @@ from tests.unit.conftest import DummyLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput + from distilabel.typing import GenerateOutput class MathShepherdGeneratorLLM(DummyLLM): diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index 2812c2e48b..d6a7c11126 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -22,7 +22,7 @@ _is_outlines_version_below_0_1_0, model_to_schema, ) -from distilabel.steps.tasks.typing import OutlinesStructuredOutputType +from distilabel.typing import OutlinesStructuredOutputType class DummyUserTest(BaseModel): diff --git a/tests/unit/steps/tasks/test_argilla_labeller.py b/tests/unit/steps/tasks/test_argilla_labeller.py index 9418e899a5..b883b39197 100644 --- a/tests/unit/steps/tasks/test_argilla_labeller.py +++ b/tests/unit/steps/tasks/test_argilla_labeller.py @@ -19,7 +19,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller -from distilabel.steps.tasks.typing import ChatItem +from distilabel.typing import ChatItem from tests.unit.conftest import DummyAsyncLLM diff --git a/tests/unit/steps/tasks/test_image_generation.py b/tests/unit/steps/tasks/test_image_generation.py new file mode 100644 index 0000000000..4c588419e0 --- /dev/null +++ b/tests/unit/steps/tasks/test_image_generation.py @@ -0,0 +1,55 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from distilabel.steps.tasks.image_generation import ImageGeneration +from tests.unit.conftest import DummyAsyncImageGenerationModel + + +class TestImageGeneration: + def test_format_input(self) -> None: + igm = DummyAsyncImageGenerationModel() + task = ImageGeneration(image_generation_model=igm) + task.load() + + assert ( + task.format_input({"prompt": "a white siamese cat"}) + == "a white siamese cat" + ) + + @pytest.mark.parametrize("save_artifacts", [False]) + def test_process(self, save_artifacts: bool) -> None: + igm = DummyAsyncImageGenerationModel() + task = ImageGeneration( + image_generation_model=igm, save_artifacts=save_artifacts + ) + task.load() + img_str = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABkAGQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDw9whjujGGK7EOS3fv2HfJxz0/ixuDrgqv2jciofJjUKiZG7A7jAxgE55z1+b74jkfzBcMWZfkVRsQYbHZsdM4JzzkjJz94OMg23hIALxIACevKnPBGemed3rz98EU1Z+n/toSVtwupVZ7krEQsipyeMcA/rjPJPqdx+anTiZVuMNhfJi38bdwIBHpnse+cbvmxupJ3mfz2YhGaKMsB8u5cA9Mc9j7/e5+9SzFSt0QikGNCGckEZ5yPc+nPBz82N4UI2S+X/to7p6jZB5guGwqkRIdu7bxgdBgbucHuep55YOdVjS9VlCsYkOHbnJIOVPGQevfg5wcbwXEnNyvmAkxRqSp4bgE5wBnnnvkjPzffBJuj+2fMwV4EHQrnJVgCMjPTP8AFnrz98NO6VvL/wBsJd0guFmVrkSGNXMUZI4XKkAjA/i/hOec/e5+8ImQQpOrFWLImDg55w2ePYd8g57/AHg0fvBc7AmwIDk4U4BGMDPJ9ue57bhPdSNFJOiKcSQxAnGM/KrZ4AzkjPcd8scPRH7Kt2/9tDrYZcghrk4VwVX5mzkEnOQc8/rnJPON1LO/k/aEZXBkjRQTxkcNk465wD3Hfk4YJNcEtdBGwHVVbDY3Ac8468gHqeRnk/NS3BZmuHkVlLQpgMNpOcEHqOo57k5zz96iG135f+2lT313FddqXXlFoovLTcrH72ecc9s8gc9AecbhGw2LchDLGrRoGCtuDngkE8cZBYdfujr96pJyE+1hGbY6ISS2ck84JPqecc9P4sbgXAAM5VQo8tBwSwyQCRnj39emfm+/RFp2v5f+2hJakWprtvTwfmVW5HJyAc/jnPfPq33iUmpGM3f7oKEEaYCjA+6PYf1+rfeJQvhXovyFr1HSqI3mV42jYxhlXHY4Pr0IOQefx+9Trpjvm+980UYJVQA3yg88DrjOeckZ+b71E5K+cjRlWaNMBlwcYznj1GD75zz96iSIJHcAExnyo229mzg45wSOc8Z6DqPmD/lfp/7aLrqx7xLEt4AQFEaMu3ockEDk579t3TPI+cMnLYnADIAiBjlQG/Lrn73Gc4zz96lmMkbXQlRgXRcZXkg8g9ehHPfPB5+8JJpDKL0kBT5UY5KksQQCQRjOeT/ET1O4guFFtJddv/bP6/4cp7tlZyCbk9cjjAyMk5xnPpn16d/vCaYQr9pGN37mMRsq9+Cc4xg4B5+b/gX3ws6uFuAsiriGLftYKGGBx0G7nB4znG75vv0XOGa4fzMbo4yFVcbs4POcfU9ckZ+b79EW218v/bRO0nd7iTOyPdqJAQ8S5IGNwyDg88+vfJGefv0l1E/mXG/ch2I5BGd2Rnr6EHPfPB5HzUt15ckkxMQVvJjKg8Y+UcgYGc/jwSfm+/THLSJcuVVcovYjvkd/T6568/eDgtE/T/20E73aZNKFCXuPLKmKMAoNoHIwByMn1+9nBPzffEM2VWdVLKdqbg7glvUg45BOG4Pp97G4SSOVF2GwzPEgyhO0ZIYjtnp1OQcZ5++GGQf6YTnEiDBOSSSwPPP167v/AGYKC27af+2jva7X9LXoPv40SSUNlSsUW0CIfMSo74GARk5GcnHLffpJPMk+1tIqqxjVum3IyMdTk5BB756nP3gtzJGrXScx7o4wqgdeh7Y4PXvnj733w102R3IYKxMMbDdlWGQGyMgZ689c5zzjeFCXw38v/bRN293+v61ItRwbrIXb8i9gM8Dn8evvnq33iVHdtun6AYUDAxjge3+T6nqSn0XovyC1ieUxgzqkLhWRdu49OhyPr178ev3qU7hHcfvEBEKIVjOAw44wMA8gHvkjPP3gtwrJ9o8xOqpgsuDzyD+I56nOc8/eEcsiuZmlTLmNVUgZweOeMdgeTnPuTuFQtZfL/wBtCUetgl8orOYgEXahCk5Oe+D6Z7c9vvY3VJcqm6cLJjbFHjhRu4A9vrxnnn5vv0+7jiWW4DZV/JjaMYPOQCeuOxzn5v8AgWd9RvJs+1AzmTzEAyu7nJDYPPbHOcgkcZ4
YTDo15f8AtoPVXW6/IddkLNO2XHmQocKOCSFODnHuc4OcdW+/TDII1ulVsCWFAR8wzyre2enfP44DB8zf8fO503NEnCdDyDj3x685Izz98I4DLdvGoCKijBI457c8+uOT1PONwIpWSfl/7aLlbGkGGO5T513RrkjO05IbB9u46jjv94OuJHL3DvECZI0BIUgDIBz2zwOpznk8n5qW4WWRrmQblXy037zgsDgg++SN2OT35wWpSSsd4QkiGSFAd7HnJDe2c4yM545wcbwR6S9P/bRsjuVkBkEiEErGRiMLkbflJwO45z368/eoeWKQXDPFtcxIqYXhSMemOoB5Oe+ck7wk5Iln3xuHaNcbhjIIBz75HOefXn71EiCMzq2Y90alVC43A4Izz0xg988dfvBws0reX/tvYTa+4SVFiMyyqDKUTZgcDIBz27d+c9ec7hPO7RC5HQyQxA4yAQQrdMDPQHnOevzffEckZ2XAE0bBUTJTjd7e5B64zkjPI+YNmj8nzkEuRsXJTo2ecH+fGRxkZHzUoxvbXt/7b9w7EF0rLOQxJOAcnvkZz+v/ANc9aKffBVnXZ90xocemVBPYf57t94lGtlfsvyC99SxIUl+2Nt4WNACVUEsMDPBHUZPG4nqc8uC4VnFw8igNsQrmPaSD0P4rz3z15+8FkQbbvzV2usUZH3eTx9M5BzxnPXn74Jnmf7W7ps3xoW+XZkHBX3ORg9843HP3hNO1l8v/AG0aa6fd9/4ELSMEuQCRvRc5G0kZBHGec8Hv68/eDn3wi6KHfHJGoZiWX7xDDr1PHQ56ZGcBqddkrJOWiYEoi5kPOSAdwIwDuxkZzwc8n5qUMXhvSZAT5a5OfvHcCe4z69+mcHG8ONnZry/9tB/3thbgSMblxLuxFGJGBChgccYwNxyAe+SCfm5an3XzLdMgXBiiLEnBPAPoMknnHPr82N4jcu8dyVYQr5KExqMbxwQOcEjv3JIB5wWEc6+Z58iMGUBGYkgnJHOCR6knHJ7/ADY3URitL+X/ALaEbD3XfHcsFgZRFHkj5dpwOnAyeCCOc8nnG8SOyyR3zFSpMaYBI9R05Gc9f4j3wfvhk4ljW4wzorQxeYrHBfIDDsMgnDY5zwfmxuolCzfa5FbywiICqsMMeMjPfkZ7njPPLgglovT/ANtEr8um3/DiHe6Xsmcfu1Dcj5vmHvz0z3PGcHG4LLIifahCWMbxKhGWOTwx6YGMqeDn8cb6hYvtnwDgqFJDcYznHHXJGe/rz1Fi4heL7UqoI08qMlSexwRjpkHqBzkc/NjeHHRr5f8Ato2rt3RFOhLT+ZF5TiNHClgMggcjuc5B4zkc8/eC+ZF5N0Akg3RKoJbcNwIJ5BHXBI6/Qn5wtxIy/aSCCskaKdoKDBwwGO54HXOeTz96mu8aJPsLfPEinDZGeCQencZ79O/3gR2Sfl/7aS09mRXylbgZUqTGhORjOVBz0HXrnvnqepKbeYFwQIzGAB8pIPbqD6HqOvXqepKFsvRfkNK2jJ59xM7AkAxoOm3cMA8gYz0B7+vP3qdOjkzGRgHEEbjK7SwIHY4ycHORnPXn71SXkSiS4LblxDFs+XAOVB54HXk55z1y2d9Muv8AWXB3lB5SDCLgNwCAfyznnJGct96lTa0a8v8A20Vno0EzjfeFVkTeiqfmyG5BOeeQcbh16d/vBJSMTmf7xiQoDEQSTgg+3y5Oec5zz98LKix/ahHuAESLkEbT0yO2c4yOvTPP3wyRpnS5Z5OSqq2xR8+PUjjtnvkgHn7wdPZW8v8A23+mU022xHIk89mIjxEoRUUAEccY47DPcnqc5LCSVN4uS8TRlYUYByM545B4yCCWHXjnnlxG7F47hn2SMQvzkYOfUe/r1zknB+8HXChXmSUMsgiQrkg54HPQcHOcjOffO4OO6Xp/7b+AmreQyVWQzKyr/q1IyoU44wR+H1znPP3qklkj3XSgAb4xxncdwIJII7dfXt1++Gyq7NOcGMCFTjaE3LxtyO+Rhu5OM88tT5MTx3MnlgERxk7mGc9yDxnPXHJwcnOC4ILZvy/9tEno1f7iM7IFuYzuO6JVDZOM5DdiM5x7j68MFaI+XctISHCq43Dlt3156NnjOcZwR8wGuiY7hUVB5kaodvyAKCOw6nheue5OT8wdNNHIbpiisXRNrHsRjJ4xyffPcnJ+cKPMmvl/7aNe7ewsgaL7ZkH95EuSSe7K3qM9M/xevP3wSSlVuwn3ZI0XhSvHDe3pnnOcZ5OGBcwFWuMHGI42fLZyxAJwSBkZ57+vzAb6JYoVjuticCOMpkngnBPp78c8f3vviY2aT9P/AG0N3fuV74g3TEDAIB785Gf89fqepKZdFjMN6hTtXAC44xx+nfv1yc5JVdF6L8gvfUtMUiW8WN1KsiqAhbGCQxHvgj3HGRn7wbMXj+0Isi7SiK21Qu8cEA+vY98kZ5+9T5lIa7KloV8lAVBHzn5ep4yDjcOp4B55emyuyfagNzCWNdxyW5JDHnI44J5yPrgNUxTaXfT/ANtDvpqOnhRGuYyCNsaMmV5JODnORgEEnjdn3++ImfCTKcfMibcrg4xnsP8A9fXn7wmbYsd55bAhok7EdSGx29Pf15xvC3K83J3YYwxsRnGQQDjkDPOD39fm++Kg3dX8v/bQvqRkmNbxUKlWjUMVfjqDjnG7ntz0zzjcCUtH542OokjTrxkY3Z6d8A859efvBd8ckV2zMGby12HHJOefx656/Q/fV1wgie4XlB5EYUEY3AhTnAwOevf1+b79ELJq/l/7aJ6PQSZuLqR0kRnjQDd3zg5PTrjcM5P1+8HTRqgu8jIEUeM+pIPByPc/xZ68/fEMyhDNhtxZFJJ3fxDceo5/H8M/eqbywkF6EkkVfKjJHA8zJBwc44/iwM/dHUDeEla1n2/9tKdnqNuUSJ7hQxBMaFFUcMCAec9u+eeg+998RSW7qs7OHBUIx3HltwznJHOev055HzCQEvHeuspQNGpYZyZDuHBJI4745PAODgsGjYYbx4htXaoO5iOCc/jyBxk/jjcCN1a77f8Ato1u7f1uFwFd7iRF3DC/MT0J6/U9fXv1+9Sygj7Qdu3EaBsEYPT884z36Z5+8GuBG10sqksYwIzs6HIIPBxyuTn5s5/4EJphJGbxRKCjQpkjjIJVgOoz6/xZIzzjeHDpby/9tFJ6u6Kt+E+1EoSVZVbJzkkgE5z7/X6t94lO1IMLw7sZKIeFwMFQfx69ec9ctncSkvhXovyEWLlFSGViNzFIBlh03Rlyfz4/HJyearGdtkxCgb1VMAkAD73rz0HXPr15ooqruz+X/tgb0035fqKHzZzuVXJ8uPgYwME547/KP59eaex+0RzzygGT5FBAxj5Sc8dT8vU9cknJ5oooiv3n3f8AtpSXu/15iXyLBOUQYV4o5MHnBZAxAPpkn9Op5p8qho5myRlY+B05Qvj8wP65PNFFFLVxv5f+2lLr/XRi3LmBrgLyJ4oi2WPG5Q5788jvn16gEJeILe5eNCxWW3jc5Y8FkWQ/UZ9c/nzRRWNFty
in/XwmM3rL1H3Ci3inCE4kjhzkn+JPMP6jofr1ANMv/luinUPBE5OBnJjDfzP49Tk80UVvT+Nei/KA2yO7fbKQFX5oY+gxj5VPb+vXqcnmpLqT7O8saKu2aCInPUZVX4x7+ufU5IBooqdvuX/tpD0Wncr3pzc7j1ZEY/UqD/X6+uTRRRSWy9Eay3Z//9k=" + + assert next(task.process([{"prompt": "a white siamese cat"}])) == [ + { + "prompt": "a white siamese cat", + "image": img_str, + "model_name": "test", + "distilabel_metadata": { + "raw_input_image_generation_0": "a white siamese cat", + "raw_output_image_generation_0": { + "images": [ + "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABkAGQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDw9whjujGGK7EOS3fv2HfJxz0/ixuDrgqv2jciofJjUKiZG7A7jAxgE55z1+b74jkfzBcMWZfkVRsQYbHZsdM4JzzkjJz94OMg23hIALxIACevKnPBGemed3rz98EU1Z+n/toSVtwupVZ7krEQsipyeMcA/rjPJPqdx+anTiZVuMNhfJi38bdwIBHpnse+cbvmxupJ3mfz2YhGaKMsB8u5cA9Mc9j7/e5+9SzFSt0QikGNCGckEZ5yPc+nPBz82N4UI2S+X/to7p6jZB5guGwqkRIdu7bxgdBgbucHuep55YOdVjS9VlCsYkOHbnJIOVPGQevfg5wcbwXEnNyvmAkxRqSp4bgE5wBnnnvkjPzffBJuj+2fMwV4EHQrnJVgCMjPTP8AFnrz98NO6VvL/wBsJd0guFmVrkSGNXMUZI4XKkAjA/i/hOec/e5+8ImQQpOrFWLImDg55w2ePYd8g57/AHg0fvBc7AmwIDk4U4BGMDPJ9ue57bhPdSNFJOiKcSQxAnGM/KrZ4AzkjPcd8scPRH7Kt2/9tDrYZcghrk4VwVX5mzkEnOQc8/rnJPON1LO/k/aEZXBkjRQTxkcNk465wD3Hfk4YJNcEtdBGwHVVbDY3Ac8468gHqeRnk/NS3BZmuHkVlLQpgMNpOcEHqOo57k5zz96iG135f+2lT313FddqXXlFoovLTcrH72ecc9s8gc9AecbhGw2LchDLGrRoGCtuDngkE8cZBYdfujr96pJyE+1hGbY6ISS2ck84JPqecc9P4sbgXAAM5VQo8tBwSwyQCRnj39emfm+/RFp2v5f+2hJakWprtvTwfmVW5HJyAc/jnPfPq33iUmpGM3f7oKEEaYCjA+6PYf1+rfeJQvhXovyFr1HSqI3mV42jYxhlXHY4Pr0IOQefx+9Trpjvm+980UYJVQA3yg88DrjOeckZ+b71E5K+cjRlWaNMBlwcYznj1GD75zz96iSIJHcAExnyo229mzg45wSOc8Z6DqPmD/lfp/7aLrqx7xLEt4AQFEaMu3ockEDk579t3TPI+cMnLYnADIAiBjlQG/Lrn73Gc4zz96lmMkbXQlRgXRcZXkg8g9ehHPfPB5+8JJpDKL0kBT5UY5KksQQCQRjOeT/ET1O4guFFtJddv/bP6/4cp7tlZyCbk9cjjAyMk5xnPpn16d/vCaYQr9pGN37mMRsq9+Cc4xg4B5+b/gX3ws6uFuAsiriGLftYKGGBx0G7nB4znG75vv0XOGa4fzMbo4yFVcbs4POcfU9ckZ+b79EW218v/bRO0nd7iTOyPdqJAQ8S5IGNwyDg88+vfJGefv0l1E/mXG/ch2I5BGd2Rnr6EHPfPB5HzUt15ckkxMQVvJjKg8Y+UcgYGc/jwSfm+/THLSJcuVVcovYjvkd/T6568/eDgtE/T/20E73aZNKFCXuPLKmKMAoNoHIwByMn1+9nBPzffEM2VWdVLKdqbg7glvUg45BOG4Pp97G4SSOVF2GwzPEgyhO0ZIYjtnp1OQcZ5++GGQf6YTnEiDBOSSSwPPP167v/AGYKC27af+2jva7X9LXoPv40SSUNlSsUW0CIfMSo74GARk5GcnHLffpJPMk+1tIqqxjVum3IyMdTk5BB756nP3gtzJGrXScx7o4wqgdeh7Y4PXvnj733w102R3IYKxMMbDdlWGQGyMgZ689c5zzjeFCXw38v/bRN293+v61ItRwbrIXb8i9gM8Dn8evvnq33iVHdtun6AYUDAxjge3+T6nqSn0XovyC1ieUxgzqkLhWRdu49OhyPr178ev3qU7hHcfvEBEKIVjOAw44wMA8gHvkjPP3gtwrJ9o8xOqpgsuDzyD+I56nOc8/eEcsiuZmlTLmNVUgZweOeMdgeTnPuTuFQtZfL/wBtCUetgl8orOYgEXahCk5Oe+D6Z7c9vvY3VJcqm6cLJjbFHjhRu4A9vrxnnn5vv0+7jiWW4DZV/JjaMYPOQCeuOxzn5v8AgWd9RvJs+1AzmTzEAyu7nJDYPPbHOcgkcZ4YTDo15f8AtoPVXW6/IddkLNO2XHmQocKOCSFODnHuc4OcdW+/TDII1ulVsCWFAR8wzyre2enfP44DB8zf8fO503NEnCdDyDj3x685Izz98I4DLdvGoCKijBI457c8+uOT1PONwIpWSfl/7aLlbGkGGO5T513RrkjO05IbB9u46jjv94OuJHL3DvECZI0BIUgDIBz2zwOpznk8n5qW4WWRrm
QblXy037zgsDgg++SN2OT35wWpSSsd4QkiGSFAd7HnJDe2c4yM545wcbwR6S9P/bRsjuVkBkEiEErGRiMLkbflJwO45z368/eoeWKQXDPFtcxIqYXhSMemOoB5Oe+ck7wk5Iln3xuHaNcbhjIIBz75HOefXn71EiCMzq2Y90alVC43A4Izz0xg988dfvBws0reX/tvYTa+4SVFiMyyqDKUTZgcDIBz27d+c9ec7hPO7RC5HQyQxA4yAQQrdMDPQHnOevzffEckZ2XAE0bBUTJTjd7e5B64zkjPI+YNmj8nzkEuRsXJTo2ecH+fGRxkZHzUoxvbXt/7b9w7EF0rLOQxJOAcnvkZz+v/ANc9aKffBVnXZ90xocemVBPYf57t94lGtlfsvyC99SxIUl+2Nt4WNACVUEsMDPBHUZPG4nqc8uC4VnFw8igNsQrmPaSD0P4rz3z15+8FkQbbvzV2usUZH3eTx9M5BzxnPXn74Jnmf7W7ps3xoW+XZkHBX3ORg9843HP3hNO1l8v/AG0aa6fd9/4ELSMEuQCRvRc5G0kZBHGec8Hv68/eDn3wi6KHfHJGoZiWX7xDDr1PHQ56ZGcBqddkrJOWiYEoi5kPOSAdwIwDuxkZzwc8n5qUMXhvSZAT5a5OfvHcCe4z69+mcHG8ONnZry/9tB/3thbgSMblxLuxFGJGBChgccYwNxyAe+SCfm5an3XzLdMgXBiiLEnBPAPoMknnHPr82N4jcu8dyVYQr5KExqMbxwQOcEjv3JIB5wWEc6+Z58iMGUBGYkgnJHOCR6knHJ7/ADY3URitL+X/ALaEbD3XfHcsFgZRFHkj5dpwOnAyeCCOc8nnG8SOyyR3zFSpMaYBI9R05Gc9f4j3wfvhk4ljW4wzorQxeYrHBfIDDsMgnDY5zwfmxuolCzfa5FbywiICqsMMeMjPfkZ7njPPLgglovT/ANtEr8um3/DiHe6Xsmcfu1Dcj5vmHvz0z3PGcHG4LLIifahCWMbxKhGWOTwx6YGMqeDn8cb6hYvtnwDgqFJDcYznHHXJGe/rz1Fi4heL7UqoI08qMlSexwRjpkHqBzkc/NjeHHRr5f8Ato2rt3RFOhLT+ZF5TiNHClgMggcjuc5B4zkc8/eC+ZF5N0Akg3RKoJbcNwIJ5BHXBI6/Qn5wtxIy/aSCCskaKdoKDBwwGO54HXOeTz96mu8aJPsLfPEinDZGeCQencZ79O/3gR2Sfl/7aS09mRXylbgZUqTGhORjOVBz0HXrnvnqepKbeYFwQIzGAB8pIPbqD6HqOvXqepKFsvRfkNK2jJ59xM7AkAxoOm3cMA8gYz0B7+vP3qdOjkzGRgHEEbjK7SwIHY4ycHORnPXn71SXkSiS4LblxDFs+XAOVB54HXk55z1y2d9Muv8AWXB3lB5SDCLgNwCAfyznnJGct96lTa0a8v8A20Vno0EzjfeFVkTeiqfmyG5BOeeQcbh16d/vBJSMTmf7xiQoDEQSTgg+3y5Oec5zz98LKix/ahHuAESLkEbT0yO2c4yOvTPP3wyRpnS5Z5OSqq2xR8+PUjjtnvkgHn7wdPZW8v8A23+mU022xHIk89mIjxEoRUUAEccY47DPcnqc5LCSVN4uS8TRlYUYByM545B4yCCWHXjnnlxG7F47hn2SMQvzkYOfUe/r1zknB+8HXChXmSUMsgiQrkg54HPQcHOcjOffO4OO6Xp/7b+AmreQyVWQzKyr/q1IyoU44wR+H1znPP3qklkj3XSgAb4xxncdwIJII7dfXt1++Gyq7NOcGMCFTjaE3LxtyO+Rhu5OM88tT5MTx3MnlgERxk7mGc9yDxnPXHJwcnOC4ILZvy/9tEno1f7iM7IFuYzuO6JVDZOM5DdiM5x7j68MFaI+XctISHCq43Dlt3156NnjOcZwR8wGuiY7hUVB5kaodvyAKCOw6nheue5OT8wdNNHIbpiisXRNrHsRjJ4xyffPcnJ+cKPMmvl/7aNe7ewsgaL7ZkH95EuSSe7K3qM9M/xevP3wSSlVuwn3ZI0XhSvHDe3pnnOcZ5OGBcwFWuMHGI42fLZyxAJwSBkZ57+vzAb6JYoVjuticCOMpkngnBPp78c8f3vviY2aT9P/AG0N3fuV74g3TEDAIB785Gf89fqepKZdFjMN6hTtXAC44xx+nfv1yc5JVdF6L8gvfUtMUiW8WN1KsiqAhbGCQxHvgj3HGRn7wbMXj+0Isi7SiK21Qu8cEA+vY98kZ5+9T5lIa7KloV8lAVBHzn5ep4yDjcOp4B55emyuyfagNzCWNdxyW5JDHnI44J5yPrgNUxTaXfT/ANtDvpqOnhRGuYyCNsaMmV5JODnORgEEnjdn3++ImfCTKcfMibcrg4xnsP8A9fXn7wmbYsd55bAhok7EdSGx29Pf15xvC3K83J3YYwxsRnGQQDjkDPOD39fm++Kg3dX8v/bQvqRkmNbxUKlWjUMVfjqDjnG7ntz0zzjcCUtH542OokjTrxkY3Z6d8A859efvBd8ckV2zMGby12HHJOefx656/Q/fV1wgie4XlB5EYUEY3AhTnAwOevf1+b79ELJq/l/7aJ6PQSZuLqR0kRnjQDd3zg5PTrjcM5P1+8HTRqgu8jIEUeM+pIPByPc/xZ68/fEMyhDNhtxZFJJ3fxDceo5/H8M/eqbywkF6EkkVfKjJHA8zJBwc44/iwM/dHUDeEla1n2/9tKdnqNuUSJ7hQxBMaFFUcMCAec9u+eeg+998RSW7qs7OHBUIx3HltwznJHOev055HzCQEvHeuspQNGpYZyZDuHBJI4745PAODgsGjYYbx4htXaoO5iOCc/jyBxk/jjcCN1a77f8Ato1u7f1uFwFd7iRF3DC/MT0J6/U9fXv1+9Sygj7Qdu3EaBsEYPT884z36Z5+8GuBG10sqksYwIzs6HIIPBxyuTn5s5/4EJphJGbxRKCjQpkjjIJVgOoz6/xZIzzjeHDpby/9tFJ6u6Kt+E+1EoSVZVbJzkkgE5z7/X6t94lO1IMLw7sZKIeFwMFQfx69ec9ctncSkvhXovyEWLlFSGViNzFIBlh03Rlyfz4/HJyearGdtkxCgb1VMAkAD73rz0HXPr15ooqruz+X/tgb0035fqKHzZzuVXJ8uPgYwME547/KP59eaex+0RzzygGT5FBAxj5Sc8dT8vU9cknJ5oooiv3n3f8AtpSXu/15iXyLBOUQYV4o5MHnBZAxAPpkn9Op5p8qho5myRlY+B05Qvj8wP65PNFFFLVxv5f+2lLr/XRi3LmBrgLyJ4oi2WPG5Q5788jvn16gEJeILe5eNCxWW3jc5Y8FkWQ/UZ9c/nzRRWNFtyin/XwmM3rL1H3Ci3inCE4kjhzkn+JPMP6jofr1ANMv/luinUPBE5OBnJjDfzP49Tk80UVvT+Nei/KA2yO7fbKQFX5oY+gxj5VPb+vXqcnmpLqT7O8saKu2aCInPUZVX4x7+ufU5IBooqdvuX/tpD0Wncr3pzc7j1ZEY/UqD/X6+uTRRRSWy9Eay3Z//9k=" + ] + }, + }, + } + ] 
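For orientation, the pieces tested above compose end to end in the usual `distilabel` way. The following sketch is not part of the patch: it mirrors the names exercised by the tests (`ImageGeneration`, `InferenceEndpointsImageGeneration`, `Distiset.transform_columns_to_image`), and the import paths and `model_id` are assumptions taken from the test files.

```python
# Minimal sketch, assuming the public import paths introduced by this patch.
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import ImageGeneration

with Pipeline(name="image-generation-pipeline") as pipeline:
    # Feed prompts into the task; "prompt" is the input column consumed
    # by `ImageGeneration.format_input` (see the test above).
    loader = LoadDataFromDicts(data=[{"prompt": "a white siamese cat"}])

    # Model id borrowed from the Inference Endpoints tests; any
    # text-to-image endpoint should work the same way.
    igm = InferenceEndpointsImageGeneration(
        model_id="black-forest-labs/FLUX.1-schnell"
    )

    task = ImageGeneration(image_generation_model=igm)
    loader >> task

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
    # Generated images are stored as base64 strings; this helper decodes
    # them back into `PIL.Image.Image` objects (exercised in test_distiset.py).
    distiset.transform_columns_to_image("image")
```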
diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py index 1bc4128c7c..25e60cab0b 100644 --- a/tests/unit/steps/tasks/test_improving_text_embeddings.py +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -18,7 +18,6 @@ import pytest from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.improving_text_embeddings import ( BitextRetrievalGenerator, @@ -29,7 +28,7 @@ GenerateTextRetrievalData, MonolingualTripletGenerator, ) -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType, GenerateOutput class MockLLM(LLM): diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py index 5e54d94658..4139eef525 100644 --- a/tests/unit/steps/tasks/test_instruction_backtranslation.py +++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py @@ -15,12 +15,11 @@ from typing import Any, List from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.instruction_backtranslation import ( InstructionBacktranslation, ) -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType, GenerateOutput class InstructionBacktranslationLLM(LLM): diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py index 125b26ed37..689f18e007 100644 --- a/tests/unit/steps/tasks/test_structured_generation.py +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -18,10 +18,9 @@ from typing_extensions import override from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.structured_generation import StructuredGeneration -from distilabel.steps.tasks.typing import StructuredInput +from distilabel.typing import GenerateOutput, StructuredInput class DummyStructuredLLM(LLM): diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py index c1bcf47e24..a3d3b0518b 100644 --- a/tests/unit/steps/tasks/test_text_classification.py +++ b/tests/unit/steps/tasks/test_text_classification.py @@ -21,8 +21,7 @@ from tests.unit.conftest import DummyAsyncLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput class TextClassificationLLM(DummyAsyncLLM): diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 3754c8803d..d94a4d8721 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -17,9 +17,8 @@ import pytest from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput -from distilabel.steps.tasks.typing import ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback +from distilabel.typing import ChatType, GenerateOutput class UltraFeedbackLLM(LLM): diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py index 6e8297bb06..e3e479e5c1 100644 --- a/tests/unit/steps/test_base.py +++ b/tests/unit/steps/test_base.py 
@@ -24,7 +24,7 @@
 from distilabel.pipeline.local import Pipeline
 from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput
 from distilabel.steps.decorator import step
-from distilabel.steps.typing import GeneratorStepOutput, StepOutput
+from distilabel.typing import GeneratorStepOutput, StepOutput
 from distilabel.utils.serialization import TYPE_INFO_KEY
diff --git a/tests/unit/steps/test_decorator.py b/tests/unit/steps/test_decorator.py
index 0071507cf0..d4b2aeb839 100644
--- a/tests/unit/steps/test_decorator.py
+++ b/tests/unit/steps/test_decorator.py
@@ -25,7 +25,7 @@
     StepInput,
 )
 from distilabel.steps.decorator import step
-from distilabel.steps.typing import GeneratorStepOutput, StepOutput
+from distilabel.typing import GeneratorStepOutput, StepOutput


 class TestStepDecorator:
diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py
index 1649a2ff18..1eb47a5e96 100644
--- a/tests/unit/test_distiset.py
+++ b/tests/unit/test_distiset.py
@@ -236,3 +236,58 @@ def test_dataset_card(self, distiset: Distiset) -> None:
             "size_categories": "n<1K",
             "tags": ["synthetic", "distilabel", "rlaif"],
         }
+
+    def test_transform_columns_to_image(self):
+        import numpy as np
+        from PIL import Image
+
+        arr = np.random.randint(0, 255, (100, 100, 3))
+        image = Image.fromarray(arr, "RGB")
+        from distilabel.models.image_generation.utils import image_to_str
+
+        img_str = image_to_str(image)
+
+        distiset_with_images = Distiset(
+            {
+                "leaf_step_1": Dataset.from_dict({"image": [img_str] * 3}),
+                "leaf_step_2": Dataset.from_dict(
+                    {"image": [img_str] * 4, "column": [5, 6, 7, 8]}
+                ),
+            }
+        )
+        distiset_with_images.transform_columns_to_image("image")
+        assert all(
+            isinstance(img, Image.Image)
+            for img in distiset_with_images["leaf_step_1"]["image"]
+        )
+        assert all(
+            isinstance(img, Image.Image)
+            for img in distiset_with_images["leaf_step_2"]["image"]
+        )
+
+        distiset_with_images = Distiset(
+            {
+                "leaf_step_1": Dataset.from_dict({"image": [img_str] * 3}),
+                "leaf_step_2": Dataset.from_dict(
+                    {"image": [img_str] * 4, "column": [5, 6, 7, 8]}
+                ),
+            }
+        )
+        distiset_with_images = distiset_with_images.train_test_split(0.8)
+        distiset_with_images.transform_columns_to_image("image")
+        assert all(
+            isinstance(img, Image.Image)
+            for img in distiset_with_images["leaf_step_1"]["train"]["image"]
+        )
+        assert all(
+            isinstance(img, Image.Image)
+            for img in distiset_with_images["leaf_step_1"]["test"]["image"]
+        )
+        assert all(
+            isinstance(img, Image.Image)
+            for img in distiset_with_images["leaf_step_2"]["train"]["image"]
+        )
+        assert all(
+            isinstance(img, Image.Image)
+            for img in distiset_with_images["leaf_step_2"]["test"]["image"]
+        )
diff --git a/tests/unit/utils/test_requirements.py b/tests/unit/utils/test_requirements.py
index 32ed762560..04125242ee 100644
--- a/tests/unit/utils/test_requirements.py
+++ b/tests/unit/utils/test_requirements.py
@@ -19,7 +19,7 @@
 from distilabel.pipeline import Pipeline
 from distilabel.steps import Step
 from distilabel.steps.base import StepInput
-from distilabel.steps.typing import StepOutput
+from distilabel.typing import StepOutput
 from distilabel.utils.requirements import requirements

 from ..pipeline.utils import DummyGeneratorStep
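The bulk of the test churn above is mechanical: every `distilabel.steps.typing`, `distilabel.models.llms.typing`, and `distilabel.steps.tasks.typing` import now resolves through the unified `distilabel.typing` package. A custom step written against the new layout would look like the sketch below; `UppercasePrompt` is a hypothetical name used only to illustrate the new import path.

```python
from typing import TYPE_CHECKING, List

from distilabel.steps import Step, StepInput

if TYPE_CHECKING:
    # Previously: from distilabel.steps.typing import StepOutput
    from distilabel.typing import StepOutput


class UppercasePrompt(Step):
    """Hypothetical step, shown only to illustrate the unified typing imports."""

    @property
    def inputs(self) -> List[str]:
        return ["prompt"]

    @property
    def outputs(self) -> List[str]:
        return ["prompt"]

    def process(self, inputs: StepInput) -> "StepOutput":
        # Uppercase each prompt and yield the batch back otherwise unchanged.
        for row in inputs:
            row["prompt"] = row["prompt"].upper()
        yield inputs
```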