Skip to content

Commit

Permalink
Allow model selection with -m / --model
Browse files Browse the repository at this point in the history
  • Loading branch information
sndrtj committed Jun 1, 2024
1 parent aff2ab9 commit 168c050
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
24 changes: 21 additions & 3 deletions src/droombot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import enum
import logging
import shlex
from typing import Annotated, Literal

import pydantic

logger = logging.getLogger(__name__)


class FinishReason(enum.Enum):
CONTENT_FILTERED = "CONTENT_FILTERED"
Expand Down Expand Up @@ -108,5 +112,19 @@ def pubsub_to_t2i(
:param message: the message
:return: text to image request
"""
# FIXME: core only for now...
return TextToImageRequestV2Core(prompt=message.text_prompt)
parser = argparse.ArgumentParser(exit_on_error=False)
parser.add_argument(
"-m", "--model", choices=["core", "sd3", "sd3-turbo"], default="core"
)

# we want to consider everything before a `-` as the text prompt, without quoting.
prompt, maybe_dash, options = message.text_prompt.partition("-")

args, failures = parser.parse_known_args(shlex.split(maybe_dash + options), None)
if failures:
logger.warning(f"Unrecognized arguments: {''.join(failures)}, ignoring...")

if args.model == "core":
return TextToImageRequestV2Core(prompt=prompt.strip())

return TextToImageRequestV2SD3(prompt=prompt.strip(), model=args.model)
35 changes: 33 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import pydantic
import pytest
from droombot.models import TextToImageRequestV2Core, TextToImageResponse
from droombot.models import (
PubSubMessage,
TextToImageRequestV2Core,
TextToImageRequestV2SD3,
TextToImageResponse,
pubsub_to_t2i,
)


def test_generation_text_prompts_not_defined():
Expand All @@ -11,7 +17,32 @@ def test_generation_text_prompts_not_defined():


def test_response_from_raw_api():
example_response = {"base64": "foo", "finishReason": "SUCCESS", "seed": 0}
example_response = {"image": "foo", "finish_reason": "SUCCESS", "seed": 0}
assert TextToImageResponse.from_raw_api(example_response) == TextToImageResponse(
base64="foo", finish_reason="SUCCESS", seed=0
)


PUBSUB_TO_T2I_DATA = [
(
PubSubMessage(interaction_id="0", text_prompt="foo bar"),
TextToImageRequestV2Core(prompt="foo bar"),
),
(
PubSubMessage(interaction_id="0", text_prompt="foo bar -m sd3"),
TextToImageRequestV2SD3(prompt="foo bar"),
),
(
PubSubMessage(interaction_id="0", text_prompt="foo bar -m sd3-turbo"),
TextToImageRequestV2SD3(prompt="foo bar", model="sd3-turbo"),
),
(
PubSubMessage(interaction_id="0", text_prompt="foo bar --model sd3"),
TextToImageRequestV2SD3(prompt="foo bar"),
),
]


@pytest.mark.parametrize("message, t2i_request", PUBSUB_TO_T2I_DATA)
def test_request_from_pubsub(message, t2i_request):
assert pubsub_to_t2i(message) == t2i_request

0 comments on commit 168c050

Please sign in to comment.