diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c3b6439..05627545 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,4 +19,4 @@ repos: hooks: - id: mypy exclude: ^examples/.*$ # FIXME - additional_dependencies: [pydantic, strenum, types-colorama, types-docutils, types-Pillow, types-psutil, types-Pygments, types-pywin32, types-PyYAML, types-regex, types-requests, types-setuptools, types-tabulate, types-tqdm, types-urllib3, horde_sdk] + additional_dependencies: [pydantic, strenum, types-colorama, types-docutils, types-Pillow, types-psutil, types-Pygments, types-pywin32, types-PyYAML, types-regex, types-requests, types-setuptools, types-tabulate, types-tqdm, types-urllib3, horde_sdk==0.9.1] diff --git a/examples/run_sdk_inference_example.py b/examples/run_sdk_inference_example.py index 4979bceb..98f97cc2 100644 --- a/examples/run_sdk_inference_example.py +++ b/examples/run_sdk_inference_example.py @@ -2,10 +2,13 @@ hordelib.initialise(setup_logging=False, logging_verbosity=5) +import asyncio from uuid import UUID +from aiohttp import ClientSession from horde_sdk.ai_horde_api import KNOWN_SOURCE_PROCESSING from horde_sdk.ai_horde_api.apimodels import ( + ExtraSourceImageEntry, ImageGenerateJobPopPayload, ImageGenerateJobPopResponse, ImageGenerateJobPopSkippedStatus, @@ -20,6 +23,16 @@ SharedModelManager.load_model_managers() +async def _download_images(payload: ImageGenerateJobPopResponse) -> None: + async with ClientSession() as session: + tasks = [] + tasks.append(payload.async_download_source_image(session)) + tasks.append(payload.async_download_source_mask(session)) + tasks.append(payload.async_download_extra_source_images(session)) + + await asyncio.gather(*tasks) + + def main(): example_response = ImageGenerateJobPopResponse( ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))], @@ -31,8 +44,20 @@ def main(): post_processing=["4x_AnimeSharp", "CodeFormers"], n_iter=1, ), + source_image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/0.jpg", + source_mask="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/1.jpg", + extra_source_images=[ + ExtraSourceImageEntry( + image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/2.jpg", + ), + ExtraSourceImageEntry( + image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/3.jpg", + ), + ], ) + asyncio.run(_download_images(example_response)) + result = horde.basic_inference(example_response) print(result) diff --git a/hordelib/horde.py b/hordelib/horde.py index f36300f2..2002bc09 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -1031,10 +1031,60 @@ def basic_inference( logger.debug(f"Post-processing requested: {post_processor_requested}") sub_payload = payload.payload.model_dump() - source_image = payload.source_image - mask_image = payload.source_mask + + def handle_images( + payload: ImageGenerateJobPopResponse, + image_type: str, + get_downloaded_image_func: Callable, + ): + image = getattr(payload, image_type) + + if image is not None and "http" in image: + image = get_downloaded_image_func() + + if image is None: + logger.error( + f"{image_type.capitalize()} is a URL but wasn't downloaded, " + "this is not supported in this context. Run the `async_download_*` methods first.", + ) + + return None + + return image + + source_image = handle_images( + payload, + "source_image", + payload.get_downloaded_source_image, + ) + if source_image is None: + logger.info("No source image found in payload.") + + mask_image = handle_images( + payload, + "source_mask", + payload.get_downloaded_source_mask, + ) + if mask_image is None: + logger.info("No mask image found in payload.") + extra_source_images = payload.extra_source_images + if extra_source_images is not None: + extra_source_images = payload.get_downloaded_extra_source_images() + if extra_source_images is not None: + logger.info(f"Using {len(extra_source_images)} downloaded extra source images.") + else: + logger.info("No extra source images found in payload.") + + esi_to_remove = [] + if extra_source_images is not None: + for esi in extra_source_images: + if "http" in esi.image: + logger.warning("Extra source image is a URL, this is not supported in this context.") + esi_to_remove.append(esi) + + extra_source_images = [esi for esi in extra_source_images if esi not in esi_to_remove] # If its a base64 encoded image, decode it if isinstance(source_image, str): try: diff --git a/requirements.txt b/requirements.txt index 6dcf9a55..4ff5caa0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Add this in for tox, comment out for build --extra-index-url https://download.pytorch.org/whl/cu121 -horde_sdk>=0.8.3 +horde_sdk>=0.9.1 horde_model_reference>=0.5.2 pydantic torch>=2.1.0