Skip to content

Commit

Permalink
feat: use sdk's new source image/mask/extra image handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Mar 24, 2024
1 parent 0c708d9 commit d6057cc
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ repos:
hooks:
- id: mypy
exclude: ^examples/.*$ # FIXME
additional_dependencies: [pydantic, strenum, types-colorama, types-docutils, types-Pillow, types-psutil, types-Pygments, types-pywin32, types-PyYAML, types-regex, types-requests, types-setuptools, types-tabulate, types-tqdm, types-urllib3, horde_sdk]
additional_dependencies: [pydantic, strenum, types-colorama, types-docutils, types-Pillow, types-psutil, types-Pygments, types-pywin32, types-PyYAML, types-regex, types-requests, types-setuptools, types-tabulate, types-tqdm, types-urllib3, horde_sdk==0.9.1]
25 changes: 25 additions & 0 deletions examples/run_sdk_inference_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

hordelib.initialise(setup_logging=False, logging_verbosity=5)

import asyncio
from uuid import UUID

from aiohttp import ClientSession
from horde_sdk.ai_horde_api import KNOWN_SOURCE_PROCESSING
from horde_sdk.ai_horde_api.apimodels import (
ExtraSourceImageEntry,
ImageGenerateJobPopPayload,
ImageGenerateJobPopResponse,
ImageGenerateJobPopSkippedStatus,
Expand All @@ -20,6 +23,16 @@
SharedModelManager.load_model_managers()


async def _download_images(payload: ImageGenerateJobPopResponse) -> None:
async with ClientSession() as session:
tasks = []
tasks.append(payload.async_download_source_image(session))
tasks.append(payload.async_download_source_mask(session))
tasks.append(payload.async_download_extra_source_images(session))

await asyncio.gather(*tasks)


def main():
example_response = ImageGenerateJobPopResponse(
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
Expand All @@ -31,8 +44,20 @@ def main():
post_processing=["4x_AnimeSharp", "CodeFormers"],
n_iter=1,
),
source_image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/0.jpg",
source_mask="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/1.jpg",
extra_source_images=[
ExtraSourceImageEntry(
image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/2.jpg",
),
ExtraSourceImageEntry(
image="https://raw.githubusercontent.com/db0/Stable-Horde/main/img_stable/3.jpg",
),
],
)

asyncio.run(_download_images(example_response))

result = horde.basic_inference(example_response)
print(result)

Expand Down
54 changes: 52 additions & 2 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,10 +1031,60 @@ def basic_inference(
logger.debug(f"Post-processing requested: {post_processor_requested}")

sub_payload = payload.payload.model_dump()
source_image = payload.source_image
mask_image = payload.source_mask

def handle_images(
payload: ImageGenerateJobPopResponse,
image_type: str,
get_downloaded_image_func: Callable,
):
image = getattr(payload, image_type)

if image is not None and "http" in image:
image = get_downloaded_image_func()

if image is None:
logger.error(
f"{image_type.capitalize()} is a URL but wasn't downloaded, "
"this is not supported in this context. Run the `async_download_*` methods first.",
)

return None

return image

source_image = handle_images(
payload,
"source_image",
payload.get_downloaded_source_image,
)
if source_image is None:
logger.info("No source image found in payload.")

mask_image = handle_images(
payload,
"source_mask",
payload.get_downloaded_source_mask,
)
if mask_image is None:
logger.info("No mask image found in payload.")

extra_source_images = payload.extra_source_images

if extra_source_images is not None:
extra_source_images = payload.get_downloaded_extra_source_images()
if extra_source_images is not None:
logger.info(f"Using {len(extra_source_images)} downloaded extra source images.")
else:
logger.info("No extra source images found in payload.")

esi_to_remove = []
if extra_source_images is not None:
for esi in extra_source_images:
if "http" in esi.image:
logger.warning("Extra source image is a URL, this is not supported in this context.")
esi_to_remove.append(esi)

extra_source_images = [esi for esi in extra_source_images if esi not in esi_to_remove]
# If its a base64 encoded image, decode it
if isinstance(source_image, str):
try:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Add this in for tox, comment out for build
--extra-index-url https://download.pytorch.org/whl/cu121
horde_sdk>=0.8.3
horde_sdk>=0.9.1
horde_model_reference>=0.5.2
pydantic
torch>=2.1.0
Expand Down

0 comments on commit d6057cc

Please sign in to comment.