|  | 
| 14 | 14 | 
 | 
| 15 | 15 | 
 | 
| 16 | 16 | import base64 | 
| 17 |  | -from typing import Any, Callable, Union | 
|  | 17 | +import os | 
|  | 18 | +from io import BytesIO | 
|  | 19 | +from typing import Callable, Union | 
| 18 | 20 | 
 | 
| 19 |  | -import cv2 | 
| 20 | 21 | import numpy as np | 
| 21 | 22 | import requests | 
|  | 23 | +from PIL import Image as PILImage | 
| 22 | 24 | from PIL.Image import Image | 
| 23 | 25 | 
 | 
| 24 | 26 | 
 | 
# Image modes that Pillow's PNG writer rejects, mapped to the closest mode it
# can write. Anything not listed here is passed to ``save`` unchanged.
_PNG_MODE_FALLBACKS = {
    "CMYK": "RGB",
    "YCbCr": "RGB",
    "HSV": "RGB",
    "RGBX": "RGB",
    "RGBa": "RGBA",
    "PA": "RGBA",
    "La": "LA",
}

# (connect, read) timeout in seconds applied to HTTP(S) downloads.
_HTTP_TIMEOUT = (5, 15)


def _ndarray_to_pil(arr: np.ndarray) -> Image:
    """Build a PIL image from a ``uint8`` or floating-point array."""
    if np.issubdtype(arr.dtype, np.floating):
        # Floats are treated as intensities in [0, 1]. NaN has no defined
        # uint8 cast, so it is mapped to 0 before clipping and scaling.
        arr = np.clip(np.nan_to_num(arr, nan=0.0), 0.0, 1.0)
        arr = (arr * 255.0).round().astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 1:
        # ``Image.fromarray`` rejects (H, W, 1); treat it as plain grayscale.
        arr = arr[:, :, 0]
    return PILImage.fromarray(np.ascontiguousarray(arr))


def _load_pil(source: Union[Image, str, bytes, np.ndarray]) -> Image:
    """Resolve any supported input into a PIL image."""
    if isinstance(source, Image):
        return source
    if isinstance(source, np.ndarray):
        return _ndarray_to_pil(source)
    if isinstance(source, (bytes, bytearray)):
        return PILImage.open(BytesIO(source))
    if isinstance(source, str):
        if source.startswith(("http://", "https://")):
            response = requests.get(source, timeout=_HTTP_TIMEOUT)
            response.raise_for_status()
            return PILImage.open(BytesIO(response.content))
        # Anything else is a local path, optionally written as file://...
        if source.startswith("file://"):
            file_path = source[len("file://"):]
        else:
            file_path = source
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        # ``PILImage.open(path)`` is lazy and would keep the file handle open
        # until garbage collection; read the bytes up front so the handle is
        # closed deterministically.
        with open(file_path, "rb") as file_handle:
            return PILImage.open(BytesIO(file_handle.read()))
    raise TypeError(f"Unsupported image type: {type(source).__name__}")


def preprocess_image(
    image: Union[Image, str, bytes, np.ndarray],
    encoding_function: Callable[[bytes], str] = lambda b: base64.b64encode(b).decode(
        "utf-8"
    ),
) -> str:
    """Convert various image inputs into an encoded PNG string.

    Parameters
    ----------
    image : PIL.Image.Image or str or bytes or numpy.ndarray
        Supported inputs:
        - PIL Image (used as-is; never closed or modified)
        - Path to a file, ``file://`` URL, or HTTP(S) URL
        - Raw bytes (or ``bytearray``) containing encoded image data
        - ``numpy.ndarray`` with dtype ``uint8`` or any floating dtype;
          ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` and ``(H, W, 4)`` layouts
          are supported.
    encoding_function : callable, optional
        Function that converts the PNG bytes to the final string
        representation. By default, base64-encodes them and decodes the
        result as UTF-8.

    Returns
    -------
    str
        Result of ``encoding_function`` applied to the PNG bytes; a
        base64-encoded PNG string with the default encoder.

    Raises
    ------
    FileNotFoundError
        If a file path (or ``file://`` URL) does not exist.
    TypeError
        If the input type is not supported.
    requests.HTTPError
        If fetching an HTTP(S) URL fails with a non-2xx response.
    requests.RequestException
        If a network error occurs while fetching an HTTP(S) URL.
    OSError
        If the input cannot be decoded as an image by Pillow, or its mode
        cannot be written as PNG.

    Notes
    -----
    - All inputs are decoded and re-encoded to PNG to guarantee consistent output.
    - Float arrays are assumed to be in [0, 1]; values are clipped to that
      range, NaN is treated as 0, and the result is scaled to ``uint8``.
    - Images in modes PNG cannot store (e.g. CMYK or YCbCr JPEGs) are
      converted to RGB/RGBA/LA before encoding.
    - Network requests use a timeout of ``(5, 15)`` seconds (connect, read).

    Examples
    --------
    >>> b64_png = preprocess_image("path/to/image.jpg")
    >>> import numpy as np
    >>> arr = np.random.rand(64, 64, 3).astype(np.float32)
    >>> b64_png = preprocess_image(arr)
    """
    pil_image = _load_pil(image)

    # ``convert`` returns a new image, so a caller-supplied PIL image is
    # never mutated.
    fallback_mode = _PNG_MODE_FALLBACKS.get(pil_image.mode)
    if fallback_mode is not None:
        pil_image = pil_image.convert(fallback_mode)

    # Normalize to PNG bytes
    with BytesIO() as buffer:
        pil_image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()

    return encoding_function(png_bytes)
0 commit comments