Image Language Models and ImageGeneration task #1060

Conversation
Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1060/

CodSpeed Performance Report: Merging #1060 will not alter performance.
Very cool! I think we need to fix some issues related to inheritance, but maybe we can tackle those in separate PRs before the release.
```python
def get_runtime_parameters_info(self) -> list["RuntimeParameterInfo"]:
    """Gets the information of the runtime parameters of the `ImageGenerationModel` such as the name
    and the description. This function is meant to include the information of the runtime
    parameters in the serialized data of the `ImageGenerationModel`.

    Returns:
        A list containing the information for each runtime parameter of the `ImageGenerationModel`.
    """
    runtime_parameters_info = super().get_runtime_parameters_info()

    generation_kwargs_info = next(
        (
            runtime_parameter_info
            for runtime_parameter_info in runtime_parameters_info
            if runtime_parameter_info["name"] == "generation_kwargs"
        ),
        None,
    )

    # If `generation_kwargs` attribute is present, we need to include the `generate`
    # method arguments as the information for this attribute.
    if generation_kwargs_info:
        generate_docstring_args = self.generate_parsed_docstring["args"]
        generation_kwargs_info["keys"] = []
        # TODO: This doesn't happen with LLM, but with ImageGenerationModel the
        # optional key is not found in a pipeline, due to some bug. For the moment
        # this does the job. It may be related to InferenceEndpointsImageGeneration,
        # for example, being both an ImageGenerationModel and an InferenceEndpointsLLM,
        # but cannot find the point that makes the error appear.
        if "optional" not in generation_kwargs_info:
            return runtime_parameters_info
        for key, value in generation_kwargs_info["optional"].items():
            info = {"name": key, "optional": value}
            if description := generate_docstring_args.get(key):
                info["description"] = description
            generation_kwargs_info["keys"].append(info)

        generation_kwargs_info.pop("optional")

    return runtime_parameters_info
```
The issue here is that when calling `super().get_runtime_parameters_info()`, the `InferenceEndpointsLLM.get_runtime_parameters_info` method is being called instead of the one from the `RuntimeParametersMixin` class, so the dictionary is returned without the `optional` key (already popped). I guess something similar is happening with the `OpenAI` class.
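For reference, here is a minimal, self-contained sketch of how that MRO resolution plays out. The class and method names mirror the ones discussed above, but the bodies are simplified stand-ins, not the real implementations:

```python
class RuntimeParametersMixin:
    def get_runtime_parameters_info(self):
        # Simplified stand-in for the real mixin method.
        return {"name": "generation_kwargs", "optional": {"num_generations": True}}


class InferenceEndpointsLLM(RuntimeParametersMixin):
    def get_runtime_parameters_info(self):
        info = super().get_runtime_parameters_info()
        info.pop("optional")  # this override consumes the key...
        return info


class ImageGenerationModel(RuntimeParametersMixin):
    def get_runtime_parameters_info(self):
        # ...so when this runs, `super()` has already routed through the
        # sibling class in the MRO and "optional" is gone.
        return super().get_runtime_parameters_info()


class InferenceEndpointsImageGeneration(ImageGenerationModel, InferenceEndpointsLLM):
    pass


print([cls.__name__ for cls in InferenceEndpointsImageGeneration.__mro__])
# ['InferenceEndpointsImageGeneration', 'ImageGenerationModel',
#  'InferenceEndpointsLLM', 'RuntimeParametersMixin', 'object']
print(InferenceEndpointsImageGeneration().get_runtime_parameters_info())
# {'name': 'generation_kwargs'}  <- no "optional" key
```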
I think the inheritance from `InferenceEndpointsLLM` is a bit messy, not only in this method but also in other parts of the class. For example:

```python
from distilabel.models import InferenceEndpointsImageGeneration, InferenceEndpointsLLM

igm = InferenceEndpointsImageGeneration()
```

This should raise a `ValidationError` because of the validator `only_one_of_model_id_endpoint_name_or_base_url_provided` (no endpoint name or model id was provided), but instead it gives a `TypeError`:

```
TypeError: ValidationError.__new__() missing 1 required positional argument: 'line_errors'
```

which I think is caused by the multiple inheritance.
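For comparison, a simplified sketch of the behaviour that would be expected here. The validator body is an assumption inferred from its name, not the actual distilabel implementation:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class InferenceEndpointsBase(BaseModel):
    model_id: Optional[str] = None
    endpoint_name: Optional[str] = None
    base_url: Optional[str] = None

    @model_validator(mode="after")
    def only_one_of_model_id_endpoint_name_or_base_url_provided(self):
        # Assumed behaviour, based on the validator's name.
        if not any((self.model_id, self.endpoint_name, self.base_url)):
            raise ValueError(
                "One of `model_id`, `endpoint_name` or `base_url` must be provided"
            )
        return self


InferenceEndpointsBase()  # raises pydantic.ValidationError, as expected
```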
I think we should:

- Move the methods extended from `RuntimeParametersMixin`, like `runtime_parameters_names`, `get_runtime_parameters_info` and `generate_parsed_docstring`, to another class that we can then use in the base `LLM`, `ImageGenerationModel` and the future base classes to come (see the sketch after this list).
- Just duplicate the code for classes that use a client, like Inference Endpoints or OpenAI, because inheriting from the `LLM` class is messy or can get messy. Maybe we can create a base class that offers the client functionality for OpenAI, Inference Endpoints, etc., and then use it in the `LLM`s, `ImageGenerationModel`, etc.
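One possible shape for that split, sketched below. All class names here are hypothetical illustrations of the proposal, not existing distilabel classes:

```python
class RuntimeParametersModelMixin:
    """Would hold `runtime_parameters_names`, `get_runtime_parameters_info`
    and `generate_parsed_docstring`, shared by every model base class."""

    def get_runtime_parameters_info(self) -> list:
        ...


class ClientMixin:
    """Would own only the client plumbing (auth, base URL, session) for
    providers like OpenAI or Inference Endpoints."""

    def load_client(self) -> None:
        ...


class LLM(RuntimeParametersModelMixin):
    ...


class ImageGenerationModel(RuntimeParametersModelMixin):
    ...


class InferenceEndpointsLLM(ClientMixin, LLM):
    ...


class InferenceEndpointsImageGeneration(ClientMixin, ImageGenerationModel):
    ...
```

This keeps the runtime-parameter introspection in a single place while letting each model family compose the client behaviour it needs, instead of image models inheriting the whole `InferenceEndpointsLLM`.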
Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB.
LGTM!
Description

This PR adds a new module to models: `models/image_generation`, to store image models (`InferenceEndpointsImageGeneration` and `OpenAIImageGeneration`), with 2 new base classes, `ImageGenerationModel` and `AsyncImageGenerationModel`, and a new `ImageGeneration` task.

Sample pipeline and dataset. Take into account the `distiset.transform_columns_to_image` method, necessary to push the dataset with the images as objects instead of strings.
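For orientation, a rough usage sketch based on this description. The import paths, the `image_generation_model` argument, the model id, the `image` column name, and the `transform_columns_to_image` signature are assumptions and may differ from the merged code:

```python
from distilabel.models import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import ImageGeneration

with Pipeline(name="image-generation-example") as pipeline:
    loader = LoadDataFromDicts(data=[{"prompt": "a photo of a white siamese cat"}])
    image_generation = ImageGeneration(
        image_generation_model=InferenceEndpointsImageGeneration(
            model_id="black-forest-labs/FLUX.1-schnell",  # assumed model id
        ),
    )
    loader >> image_generation

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
    # Convert the image column from strings to image objects before pushing,
    # as the description notes.
    distiset = distiset.transform_columns_to_image("image")  # assumed signature
    distiset.push_to_hub("username/flux-image-dataset")  # hypothetical repo id
```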