From c4fd5e6fc2c4d67d6fc367f07f5ac25202362c3a Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 02:37:24 +0000 Subject: [PATCH] fix: add media_type field to Mistral image format Co-Authored-By: jason@jxnl.co --- instructor/multimodal.py | 158 +++++++++++++++++++++------------------ 1 file changed, 85 insertions(+), 73 deletions(-) diff --git a/instructor/multimodal.py b/instructor/multimodal.py index 57dd99ec6..3aef1f7e6 100644 --- a/instructor/multimodal.py +++ b/instructor/multimodal.py @@ -5,9 +5,12 @@ import mimetypes import re from collections.abc import Mapping -from functools import lru_cache, cache +from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Literal, Optional, TypeVar, TypedDict, ClassVar, Union +from typing import ( + Any, Callable, Final, Literal, Optional, + TypeVar, TypedDict, Union +) from urllib.parse import urlparse import requests @@ -15,6 +18,8 @@ from .mode import Mode +ImgT = TypeVar('ImgT', bound='Image') + # Constants for Mistral image validation VALID_MISTRAL_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"} MAX_MISTRAL_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB in bytes @@ -41,22 +46,22 @@ class ImageParams(ImageParamsBase, total=False): class Image(BaseModel): - VALID_MIME_TYPES: ClassVar[list[str]] = [ - "image/jpeg", - "image/png", - "image/gif", - "image/webp", - ] - source: Union[str, Path] = Field( - description="URL, file path, or base64 data of the image" - ) + """Represents an image that can be loaded from a URL or file path.""" + + VALID_MIME_TYPES: Final[frozenset[str]] = frozenset({ + "image/jpeg", "image/png", "image/gif", "image/webp" + }) + VALID_MISTRAL_MIME_TYPES: Final[frozenset[str]] = frozenset({ + "image/jpeg", "image/png", "image/gif", "image/webp" + }) + + source: Union[str, Path] = Field(description="URL or file path of the image") media_type: str = Field(description="MIME type of the image") - data: Union[str, None] = Field( + data: Optional[str] = Field( None, description="Base64 encoded image data", repr=False ) - @classmethod - def autodetect(cls, source: Union[str, Path]) -> Union[Image, None]: + def autodetect(cls: type[ImgT], source: Union[str, Path]) -> Optional[ImgT]: """Attempt to autodetect an image from a source string or Path. Args: @@ -71,28 +76,33 @@ def autodetect(cls, source: Union[str, Path]) -> Union[Image, None]: try: if isinstance(source, str): if cls.is_base64(source): - return cls.from_base64(source) + result = cls.from_base64(source) + return result if isinstance(result, cls) else None elif urlparse(source).scheme in {"http", "https"}: - return cls.from_url(source) + result = cls.from_url(source) + return result if isinstance(result, cls) else None elif Path(source).is_file(): - return cls.from_path(source) + result = cls.from_path(source) + return result if isinstance(result, cls) else None else: - return cls.from_raw_base64(source) + result = cls.from_raw_base64(source) + return result if isinstance(result, cls) else None elif isinstance(source, Path): - return cls.from_path(source) + result = cls.from_path(source) + return result if isinstance(result, cls) else None return None except Exception: return None @classmethod - def autodetect_safely(cls, source: Union[str, Path]) -> Union[Image, str]: + def autodetect_safely(cls: type[ImgT], source: Union[str, Path]) -> Union[str, ImgT]: """Safely attempt to autodetect an image from a source string or path. Args: source: URL, file path, or base64 data Returns: - Union[Image, str]: An Image instance or the original string if not an image + Union[str, Image]: An Image instance or the original string if not an image """ try: result = cls.autodetect(source) @@ -101,11 +111,11 @@ def autodetect_safely(cls, source: Union[str, Path]) -> Union[Image, str]: return str(source) @classmethod - def is_base64(cls, s: str) -> bool: + def is_base64(cls: type[ImgT], s: str) -> bool: return bool(re.match(r"^data:image/[a-zA-Z]+;base64,", s)) @classmethod - def from_base64(cls, data: str) -> Image: + def from_base64(cls: type[ImgT], data: str) -> ImgT: """Create an Image instance from base64 data.""" if not cls.is_base64(data): raise ValueError("Invalid base64 data") @@ -127,8 +137,8 @@ def from_base64(cls, data: str) -> Image: raise ValueError(f"Unsupported image format: {media_type}") return cls(source=data, media_type=media_type, data=encoded) - @classmethod # Caching likely unnecessary - def from_raw_base64(cls, data: str) -> Union[Image, None]: + @classmethod + def from_raw_base64(cls: type[ImgT], data: str) -> Optional[ImgT]: """Create an Image from raw base64 data. Args: @@ -139,9 +149,9 @@ def from_raw_base64(cls, data: str) -> Union[Image, None]: """ try: decoded: bytes = base64.b64decode(data) - img_type: Union[str, None] = imghdr.what(None, decoded) + img_type: Optional[str] = imghdr.what(None, decoded) if img_type: - media_type = mimetypes.guess_type(data)[0] + media_type = f"image/{img_type}" if media_type in cls.VALID_MIME_TYPES: return cls(source=data, media_type=media_type, data=data) except Exception: @@ -149,19 +159,30 @@ def from_raw_base64(cls, data: str) -> Union[Image, None]: return None @classmethod - @cache # Use cache instead of lru_cache to avoid memory leaks - def from_url(cls, url: str) -> Image: + @lru_cache + def from_url(cls: type[ImgT], url: str) -> ImgT: + """Create an Image instance from a URL. + + Args: + url: The URL of the image + + Returns: + Image: An Image instance + + Raises: + ValueError: If unable to fetch image or unsupported format + """ if cls.is_base64(url): return cls.from_base64(url) parsed_url = urlparse(url) - media_type: Union[str, None] = mimetypes.guess_type(parsed_url.path)[0] + media_type: Optional[str] = mimetypes.guess_type(parsed_url.path)[0] if not media_type: try: response = requests.head(url, allow_redirects=True) media_type = response.headers.get("Content-Type") except requests.RequestException as e: - raise ValueError(f"Failed to fetch image from URL") from e + raise ValueError("Failed to fetch image from URL") from e if media_type not in cls.VALID_MIME_TYPES: raise ValueError(f"Unsupported image format: {media_type}") @@ -169,7 +190,7 @@ def from_url(cls, url: str) -> Image: @classmethod @lru_cache - def from_path(cls, path: Union[str, Path]) -> Image: + def from_path(cls: type[ImgT], path: Union[str, Path]) -> ImgT: path = Path(path) if not path.is_file(): raise FileNotFoundError(f"Image file not found: {path}") @@ -182,11 +203,11 @@ def from_path(cls, path: Union[str, Path]) -> Image: f"Image file size ({path.stat().st_size / 1024 / 1024:.1f}MB) " f"exceeds Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB" ) - media_type: Union[str, None] = mimetypes.guess_type(str(path))[0] - if media_type not in VALID_MISTRAL_MIME_TYPES: + media_type: Optional[str] = mimetypes.guess_type(str(path))[0] + if media_type not in cls.VALID_MIME_TYPES: raise ValueError( f"Unsupported image format: {media_type}. " - f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}" + f"Supported formats are: {', '.join(cls.VALID_MIME_TYPES)}" ) data = base64.b64encode(path.read_bytes()).decode("utf-8") @@ -235,46 +256,42 @@ def to_openai(self) -> dict[str, Any]: raise ValueError("Image data is missing for base64 encoding.") def to_mistral(self) -> dict[str, Any]: - """Convert the image to Mistral's API format. + """Convert the image to Mistral's format. Returns: - dict[str, Any]: Image data in Mistral's API format, either as a URL or base64 data URI. + dict[str, Any]: Image in Mistral's format Raises: - ValueError: If the image format is not supported by Mistral or exceeds size limit. + ValueError: If image data is missing or format is unsupported """ - # Validate media type - if self.media_type not in VALID_MISTRAL_MIME_TYPES: - raise ValueError( - f"Unsupported image format for Mistral: {self.media_type}. " - f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}" - ) - - # For base64 data, validate size - if self.data: - # Calculate size of decoded base64 data - data_size = len(base64.b64decode(self.data)) - if data_size > MAX_MISTRAL_IMAGE_SIZE: - raise ValueError( - f"Image size ({data_size / 1024 / 1024:.1f}MB) exceeds " - f"Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB" - ) + if not self.data: + if urlparse(str(self.source)).scheme in {"http", "https"}: + self.data = self.url_to_base64(str(self.source)) + elif Path(str(self.source)).is_file(): + source_path = Path(str(self.source)) + binary_data = source_path.read_bytes() + self.data = base64.b64encode(binary_data).decode('utf-8') + + if not self.data: + raise ValueError("No image data available") + + if self.media_type not in self.VALID_MISTRAL_MIME_TYPES: + raise ValueError(f"Unsupported image format: {self.media_type}") + + # Ensure data is properly formatted as a data URL + data_url = ( + self.data if self.data.startswith("data:") + else f"data:{self.media_type};base64,{self.data}" + ) - if ( - isinstance(self.source, str) - and self.source.startswith(("http://", "https://")) - and not self.is_base64(self.source) - ): - return {"type": "image_url", "url": self.source} - elif self.data or self.is_base64(str(self.source)): - data = self.data or str(self.source).split(",", 1)[1] - return { - "type": "image_url", - "data": f"data:{self.media_type};base64,{data}", + return { + "type": "image_url", + "source": { + "type": "base64", + "media_type": self.media_type, + "data": data_url } - else: - raise ValueError("Image data is missing for base64 encoding.") - + } class Audio(BaseModel): """Represents an audio that can be loaded from a URL or file path.""" @@ -330,8 +347,6 @@ def convert_contents( str, Image, dict[str, Any], list[Union[str, Image, dict[str, Any]]] ], mode: Mode, - *, # Make autodetect_images keyword-only - autodetect_images: bool = True, ) -> Union[str, list[dict[str, Any]]]: """Convert contents to the appropriate format for the given mode.""" # Handle single string case @@ -377,15 +392,12 @@ def convert_contents( def convert_messages( messages: list[dict[str, Any]], mode: Mode, - *, # Make autodetect_images keyword-only - autodetect_images: bool = True, ) -> list[dict[str, Any]]: """Convert messages to the appropriate format for the given mode. Args: messages: List of message dictionaries to convert mode: The mode to convert messages for (e.g. MISTRAL_JSON) - autodetect_images: Whether to attempt to autodetect images in string content Returns: List of converted message dictionaries