Skip to content

Commit

Permalink
Merge pull request #2 from sndrtj/sander/june-2024-update
Browse files Browse the repository at this point in the history
Sander/june 2024 update
  • Loading branch information
sndrtj authored Jun 1, 2024
2 parents 08682e1 + 817e224 commit 3269327
Show file tree
Hide file tree
Showing 15 changed files with 1,432 additions and 1,459 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
python-version: '3.10'
cache: 'poetry'
- run: poetry install
- run: poetry run black --check src/ tests/
- run: poetry run ruff format --check src/ tests/
- run: poetry run mypy src/ tests/
- run: poetry run ruff check src/ tests/
- run: poetry run py.test tests/
18 changes: 6 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ repos:
- id: check-added-large-files
- repo: local
hooks:
- id: isort
name: isort
entry: poetry run isort
- id: ruff
name: ruff
entry: poetry run ruff check --fix .
language: system
types: [python]
require_serial: true
- id: black
name: black
entry: poetry run black
- id: ruff format
name: ruff format
entry: poetry run ruff format .
language: system
types: [python]
require_serial: true
Expand All @@ -29,9 +29,3 @@ repos:
language: system
types: [python]
require_serial: true
- id: ruff
name: ruff
entry: poetry run ruff
language: system
types: [python]
require_serial: true
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ To start a worker, run the following in the virtual environment
droombot worker
```

## How to use in discord

Use the `/prompt` command to type your prompt. This by default uses whatever model
Stability AI now considers its "core". You can also select Stable Diffusion 3 and
Stable Diffusion 3 Turbo models by appending `-m sd3` or `-m sd3-turbo` to your
prompt.

You can give individual words in your prompt more some weight by doing something like
the following;
``A table with (red:0.5) raspberries and (purple:0.5) blueberries.``

## Components

Droombot consists of two components:
Expand Down Expand Up @@ -65,7 +76,7 @@ All configuration is handled via environment variables. See the following table

## Container

Droombot can run as a container. For a howto using Docker or Podman, see the
Droombot can run as a container. For a howto using Docker or Podman, see the
[container docs](docs/containers.md).

## Future plans
Expand Down
2,480 changes: 1,218 additions & 1,262 deletions poetry.lock

Large diffs are not rendered by default.

14 changes: 6 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "droombot"
version = "0.2.0a0"
version = "0.2.0"
description = "A Discord Bot for generating images from text prompts"
authors = ["Sander Bollen <sander@sndrtj.eu>"]
readme = "README.md"
Expand All @@ -17,18 +17,17 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.10"
aiohttp = {version ="^3.8.4", extras = ['speedups'] }
pydantic = "^1.10.7"
pydantic = "^2.7.2"
click = "^8.1.3"
py-cord = {extras = ["speed"], version = "^2.4.1"}
redis = {extras = ["hiredis"], version = "^4.5.4"}
aiolimiter = "^1.0.0"
typing-extensions = "^4.12.0"


[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
mypy = "^1.1.1"
isort = "^5.12.0"
ruff = "^0.0.259"
ruff = "^0.4.7"
pytest = "^7.2.2"
pre-commit = "^3.2.1"
types-redis = "^4.5.4.1"
Expand All @@ -41,6 +40,5 @@ droombot = 'droombot.cli:cli'
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


[tool.isort]
profile = "black"
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I"]
73 changes: 42 additions & 31 deletions src/droombot/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sander Bollen
# Copyright 2023-2024 Sander Bollen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,22 +13,30 @@
# limitations under the License.

import asyncio
import json
import logging

import aiohttp

from .config import STABILITY_API_KEY
from .models import TextToImageRequest, TextToImageResponse
from .models import (
TextToImageRequestV2Core,
TextToImageRequestV2SD3,
TextToImageResponse,
)
from .version import VERSION

logger = logging.getLogger(__name__)

TEX_TO_IMAGE_BASE_URL = "https://api.stability.ai/v1/generation"
CORE_TEXT_TO_IMAGE_BASE_URL = (
"https://api.stability.ai/v2beta/stable-image/generate/core"
)
SD3_TEXT_TO_IMAGE_BASE_URL = "https://api.stability.ai/v2beta/stable-image/generate/sd3"


async def text_to_image(
session: aiohttp.ClientSession, request: TextToImageRequest, timeout: int = 300
session: aiohttp.ClientSession,
request: TextToImageRequestV2Core | TextToImageRequestV2SD3,
timeout: int = 300,
) -> list[TextToImageResponse]:
"""Call Stability with a text to image request
Expand All @@ -39,28 +47,36 @@ async def text_to_image(
:return: list of responses, one for each text prompt
:raises: timeout
"""
logger.info(
"Incoming call for text-to-image generation with "
f"engine id {request.engine_id}"
)
match request:
case TextToImageRequestV2Core():
logger.info("Incoming call for text-to-image generation for Core")
url = CORE_TEXT_TO_IMAGE_BASE_URL
case TextToImageRequestV2SD3():
logger.info("Incoming call for text-to-image generation for SD3")
url = SD3_TEXT_TO_IMAGE_BASE_URL
case _:
raise ValueError(f"Unsupported request: {type(request)}")

user_agent = f"droombot/{VERSION}"
headers = {"Authorization": f"Bearer {STABILITY_API_KEY}", "User-Agent": user_agent}
url = f"{TEX_TO_IMAGE_BASE_URL}/{request.engine_id}/text-to-image"
logger.debug(f"Generated url: {url}")

# FIXME: need to load json serialized because enums.
raw_post_data = json.loads(request.json())
# need to filter out engine_id and sampler if it is none
post_data = {}
headers = {
"Authorization": f"Bearer {STABILITY_API_KEY}",
"User-Agent": user_agent,
"accept": "application/json",
}

raw_post_data = request.model_dump(mode="json")
writer = aiohttp.MultipartWriter("form-data")
# filter out values that are None
for k, v in raw_post_data.items():
if k == "engine_id":
if v is None:
continue
if k == "sampler" and v is None:
continue
post_data[k] = v
writer.append(
aiohttp.StringPayload(value=str(v)),
headers={aiohttp.hdrs.CONTENT_DISPOSITION: f'form-data; name="{k}"'},
)

async with session.post(
url=url, json=post_data, timeout=timeout, headers=headers
url=url, data=writer, timeout=timeout, headers=headers
) as resp:
results = await resp.json()
if resp.status >= 400:
Expand All @@ -70,23 +86,18 @@ async def text_to_image(
# responses are encoded in an 'artifacts' item, but this is NOT
# mentioned in the docs.
logger.info("Received a successful response from text-to-image generation.")
return [TextToImageResponse.from_raw_api(r) for r in results["artifacts"]]
return [TextToImageResponse.from_raw_api(results)]


if __name__ == "__main__":
# for testing
async def main():
request = TextToImageRequest(
text_prompts=[
{
"text": "Green trees in a forest with ferns, oil painting",
"weight": 0.5,
}
],
request = TextToImageRequestV2Core(
prompt="Green trees in a forest with ferns, oil painting"
)
async with aiohttp.ClientSession() as session:
results = await text_to_image(session, request)

print(results)
print(results[0].model_dump_json())

asyncio.run(main())
8 changes: 5 additions & 3 deletions src/droombot/bot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sander Bollen
# Copyright 2023-2024 Sander Bollen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -55,7 +55,9 @@ async def poll_interaction_result(
continue

logger.debug("Found results, parsing to response")
return pydantic.parse_raw_as(list[TextToImageResponse], raw_results)
return pydantic.TypeAdapter(list[TextToImageResponse]).validate_json(
raw_results
)


def create_bot() -> discord.Bot:
Expand Down Expand Up @@ -96,7 +98,7 @@ async def prompt(ctx, text: str):
interaction_id=str(ctx.interaction.id), text_prompt=text
)

await redis_connection.publish("droombot-prompts", message.json())
await redis_connection.publish("droombot-prompts", message.model_dump_json())
logger.info("Polling for result")

try:
Expand Down
2 changes: 1 addition & 1 deletion src/droombot/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sander Bollen
# Copyright 2023-2024 Sander Bollen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion src/droombot/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sander Bollen
# Copyright 2023-2024 Sander Bollen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion src/droombot/log.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sander Bollen
# Copyright 2023-2024 Sander Bollen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit 3269327

Please sign in to comment.