
Image Language Models and ImageGeneration task #1060

Merged
@plaguss merged 57 commits into develop from vision-language-models on Jan 15, 2025
Conversation

@plaguss (Contributor) commented Nov 14, 2024

Description

This PR adds a new module, models/image_generation, to store image generation models (InferenceEndpointsImageGeneration and OpenAIImageGeneration), with two new base classes, ImageGenerationModel and AsyncImageGenerationModel, and a new ImageGeneration task.

Sample pipeline and dataset below. Note the distiset.transform_columns_to_image method, which is necessary to push the dataset with the images as image objects instead of strings.

from datasets import load_dataset

from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import ImageGeneration

ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell")

    img_generation = ImageGeneration(
        name="flux_schnell", image_generation_model=igm, input_mappings={"prompt": "persona"}
    )

    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    img_generation >> keep_columns


if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False, dataset=ds)
    # Save the images as `PIL.Image.Image`
    distiset = distiset.transform_columns_to_image("image")
    distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")
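
The description also mentions OpenAIImageGeneration; swapping the backend in the snippet above might look like the following. This is only a hedged sketch: the constructor argument name and model value are assumptions for illustration, not taken from this PR.

from distilabel.models.image_generation import OpenAIImageGeneration

# Hypothetical arguments: check the class signature added in this PR for the exact names.
igm = OpenAIImageGeneration(model="dall-e-3")  # assumes an OpenAI-style `model` argument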

@plaguss added the enhancement (New feature or request) label Nov 14, 2024
@plaguss added this to the 1.5.0 milestone Nov 14, 2024
@plaguss self-assigned this Nov 14, 2024
@plaguss requested a review from gabrielmbmb November 14, 2024 11:58

Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1060/

codspeed-hq bot commented Nov 14, 2024

CodSpeed Performance Report

Merging #1060 will not alter performance

Comparing vision-language-models (7debafd) with develop (e866345)

Summary

✅ 1 untouched benchmarks

@plaguss marked this pull request as ready for review November 15, 2024 08:24
@plaguss requested a review from dvsrepo November 15, 2024 11:51
@gabrielmbmb (Member) left a comment

Very cool! I think we need to fix some issues related to inheritance, but maybe we can tackle those in separate PRs before the release.

docs/api/models/image_generation/index.md (outdated; resolved)
docs/api/task/image_task.md (outdated; resolved)
docs/sections/how_to_guides/advanced/distiset.md (outdated; resolved)
Comment on lines 156 to 195
def get_runtime_parameters_info(self) -> list["RuntimeParameterInfo"]:
    """Gets the information of the runtime parameters of the `ImageGenerationModel` such as the name
    and the description. This function is meant to include the information of the runtime
    parameters in the serialized data of the `ImageGenerationModel`.

    Returns:
        A list containing the information for each runtime parameter of the `ImageGenerationModel`.
    """
    runtime_parameters_info = super().get_runtime_parameters_info()

    generation_kwargs_info = next(
        (
            runtime_parameter_info
            for runtime_parameter_info in runtime_parameters_info
            if runtime_parameter_info["name"] == "generation_kwargs"
        ),
        None,
    )

    # If the `generation_kwargs` attribute is present, we need to include the `generate`
    # method arguments as the information for this attribute.
    if generation_kwargs_info:
        generate_docstring_args = self.generate_parsed_docstring["args"]
        generation_kwargs_info["keys"] = []
        # TODO: This doesn't happen with LLM, but with ImageGenerationModel the "optional" key
        # is not found when running in a pipeline, due to some bug. For the moment this does
        # the job. It may be related to, for example, InferenceEndpointsImageGeneration being
        # both an ImageGenerationModel and an InferenceEndpointsLLM, but I cannot find the
        # point that makes the error appear.
        if "optional" not in generation_kwargs_info:
            return runtime_parameters_info
        for key, value in generation_kwargs_info["optional"].items():
            info = {"name": key, "optional": value}
            if description := generate_docstring_args.get(key):
                info["description"] = description
            generation_kwargs_info["keys"].append(info)

        generation_kwargs_info.pop("optional")

    return runtime_parameters_info
Member
The issue here is that when calling super().get_runtime_parameters_info(), the InferenceEndpointsLLM.get_runtime_parameters_info method is being called instead of the one from the RuntimeParametersMixin class, which returns the dictionary without the optional key (already popped). I guess something similar is happening with the OpenAI class.
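
For readers following along, here is a minimal standalone sketch of the resolution-order problem described above; the class names mirror the real ones but the method bodies are simplified and hypothetical, not the actual distilabel code.

class RuntimeParametersMixin:
    def get_runtime_parameters_info(self):
        # The original info still contains the "optional" sub-dict.
        return [{"name": "generation_kwargs", "optional": {"num_images": True}}]


class InferenceEndpointsLLM(RuntimeParametersMixin):
    def get_runtime_parameters_info(self):
        info = super().get_runtime_parameters_info()
        # This override pops "optional" before returning it...
        for item in info:
            item.pop("optional", None)
        return info


class InferenceEndpointsImageGeneration(InferenceEndpointsLLM):
    def get_runtime_parameters_info(self):
        # ...so super() here resolves to InferenceEndpointsLLM.get_runtime_parameters_info,
        # not to the mixin's method, and "optional" is already gone.
        return super().get_runtime_parameters_info()


print(InferenceEndpointsImageGeneration().get_runtime_parameters_info())
# [{'name': 'generation_kwargs'}]  -> no "optional" key left to expand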

Member

I think the inheritance from InferenceEndpointsLLM is a bit messy, not only in this method but also in other parts of the class. For example:

from distilabel.models import InferenceEndpointsImageGeneration, InferenceEndpointsLLM

igm = InferenceEndpointsImageGeneration()

This should raise a ValidationError because of the validator only_one_of_model_id_endpoint_name_or_base_url_provided (no endpoint name or model id was provided), but instead it raises a TypeError:

TypeError: ValidationError.__new__() missing 1 required positional argument: 'line_errors'

which I think is caused by the multiple inheritance.

Member

I think we should:

  1. Move the extended methods from RuntimeParametersMixin (runtime_parameters_names, get_runtime_parameters_info and generate_parsed_docstring) to another class that we can then use in the base LLM, ImageGenerationModel and the future base classes to come.
  2. Just duplicate the code for classes that use a client, like Inference Endpoints or OpenAI, because inheriting from the LLM class is messy or can get messy. Maybe we can create a base class that offers the client functionality for OpenAI, Inference Endpoints, etc., and then use it in the LLMs, ImageGenerationModel, etc. (a rough sketch follows below).
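
A rough sketch of that idea; every class name and the overall structure here are hypothetical illustrations of the proposal, not what this PR or any follow-up actually implements. Shared runtime-parameter helpers live in one small mixin, client setup lives in a per-provider base class, and the image generation model composes the two instead of inheriting from the LLM.

class RuntimeParametersModelMixin:
    """Shared helpers: runtime_parameters_names, get_runtime_parameters_info,
    generate_parsed_docstring, reusable by LLM, ImageGenerationModel, etc."""


class InferenceEndpointsBaseClient:
    """Owns client creation/configuration for the Inference Endpoints API."""


class InferenceEndpointsLLM(RuntimeParametersModelMixin, InferenceEndpointsBaseClient):
    """Text generation model."""


class InferenceEndpointsImageGeneration(RuntimeParametersModelMixin, InferenceEndpointsBaseClient):
    """Image generation model; no longer inherits from InferenceEndpointsLLM."""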

ReviewNB: Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

@plaguss requested a review from gabrielmbmb January 15, 2025 09:50
@gabrielmbmb (Member) left a comment

LGTM!

src/distilabel/mixins/runtime_parameters.py (outdated; resolved)
src/distilabel/models/image_generation/base.py (outdated; resolved)
@plaguss merged commit 5257600 into develop on Jan 15, 2025
8 checks passed
@plaguss deleted the vision-language-models branch January 15, 2025 11:28
Labels: enhancement (New feature or request)
3 participants