diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..17f821d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,51 @@ +FROM registry.access.redhat.com/ubi9/ubi-minimal:latest as base + +RUN microdnf update -y && \ + microdnf install -y \ + python3-devel python-pip && \ + pip install --upgrade --no-cache-dir pip wheel && \ + microdnf clean all + +FROM base as builder +WORKDIR /build + +RUN pip install --no-cache tox +COPY README.md . +COPY pyproject.toml . +COPY tox.ini . +COPY caikit_computer_vision caikit_computer_vision +# .git is required for setuptools-scm get the version +RUN --mount=source=.git,target=.git,type=bind \ + --mount=type=cache,target=/root/.cache/pip \ + tox -e build + + +FROM base as deploy + +RUN python -m venv --upgrade-deps /opt/caikit/ + +ENV VIRTUAL_ENV=/opt/caikit +ENV PATH="$VIRTUAL_ENV/bin:$PATH" + +COPY --from=builder /build/dist/caikit_computer_vision*.whl /tmp/ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install /tmp/caikit_computer_vision*.whl && \ + rm /tmp/caikit_computer_vision*.whl + +COPY LICENSE /opt/caikit/ +COPY README.md /opt/caikit/ + +RUN groupadd --system caikit --gid 1001 && \ + adduser --system --uid 1001 --gid 0 --groups caikit \ + --home-dir /caikit --shell /sbin/nologin \ + --comment "Caikit User" caikit + +USER caikit + +ENV RUNTIME_LIBRARY=caikit_computer_vision +# Optional: use `CONFIG_FILES` and the /caikit/ volume to explicitly provide a configuration file and models +# ENV CONFIG_FILES=/caikit/caikit.yml +VOLUME ["/caikit/"] +WORKDIR /caikit + +CMD ["python"] \ No newline at end of file diff --git a/caikit_computer_vision/data_model/__init__.py b/caikit_computer_vision/data_model/__init__.py index fc1f138..05d88eb 100644 --- a/caikit_computer_vision/data_model/__init__.py +++ b/caikit_computer_vision/data_model/__init__.py @@ -18,3 +18,4 @@ from .image_classification import * from .image_segmentation import * from .object_detection import * +from .text_to_image import * diff --git a/caikit_computer_vision/data_model/tasks.py b/caikit_computer_vision/data_model/tasks.py index bcc2a29..91bfe3a 100644 --- a/caikit_computer_vision/data_model/tasks.py +++ b/caikit_computer_vision/data_model/tasks.py @@ -26,6 +26,7 @@ from .image_classification import ImageClassificationResult from .image_segmentation import ImageSegmentationResult from .object_detection import ObjectDetectionResult +from .text_to_image import CaptionedImage # TODO - add support for image DM primitives @@ -61,3 +62,14 @@ class ImageSegmentationTask(TaskBase): Note that at the moment, this task encapsulates all segmentation types, I.e., instance, object, semantic, etc... """ + + +@task( + required_parameters={"inputs": str}, + output_type=CaptionedImage, +) +class TextToImageTask(TaskBase): + """The text to image task is responsible for taking an input text prompt, along with + other optional image generation parameters, e.g., image height and width, + and generating an image. + """ diff --git a/caikit_computer_vision/data_model/text_to_image.py b/caikit_computer_vision/data_model/text_to_image.py new file mode 100644 index 0000000..544b586 --- /dev/null +++ b/caikit_computer_vision/data_model/text_to_image.py @@ -0,0 +1,33 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data structures for text to image.""" + + +# Third Party +from py_to_proto.dataclass_to_proto import Annotated, FieldNumber + +# First Party +from caikit.core import DataObjectBase, dataobject +from caikit.interfaces.common.data_model import ProducerId +from caikit.interfaces.vision import data_model as caikit_dm +import alog + +log = alog.use_channel("DATAM") + + +@dataobject(package="caikit_data_model.caikit_computer_vision") +class CaptionedImage(DataObjectBase): + output: Annotated[caikit_dm.Image, FieldNumber(1)] + caption: Annotated[str, FieldNumber(2)] + producer_id: Annotated[ProducerId, FieldNumber(3)] diff --git a/caikit_computer_vision/modules/__init__.py b/caikit_computer_vision/modules/__init__.py index 2d3a94c..b4dbe7a 100644 --- a/caikit_computer_vision/modules/__init__.py +++ b/caikit_computer_vision/modules/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # Local -from . import object_detection, segmentation +from . import object_detection, segmentation, text_to_image diff --git a/caikit_computer_vision/modules/text_to_image/__init__.py b/caikit_computer_vision/modules/text_to_image/__init__.py new file mode 100644 index 0000000..ba34abe --- /dev/null +++ b/caikit_computer_vision/modules/text_to_image/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .tti_stub import TTIStub diff --git a/caikit_computer_vision/modules/text_to_image/tti_stub.py b/caikit_computer_vision/modules/text_to_image/tti_stub.py new file mode 100644 index 0000000..cab805b --- /dev/null +++ b/caikit_computer_vision/modules/text_to_image/tti_stub.py @@ -0,0 +1,77 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stub module for text to image for testing runtime interfaces. +""" +# Standard +from typing import Union, get_args +import os + +# Third Party +import numpy as np + +# First Party +from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.interfaces.vision import data_model as caikit_dm +import alog + +# Local +from ...data_model import CaptionedImage +from ...data_model.tasks import TextToImageTask + +log = alog.use_channel("TTI_STUB") + + +@module( + id="28aa938b-1a33-11a0-11a3-bb9c3b1cbb11", + name="Stub module for Text to Image", + version="0.1.0", + task=TextToImageTask, +) +class TTIStub(ModuleBase): + def __init__( + self, + model_name, + ) -> "TTIStub": + log.debug("STUB - initializing text to image instance") + super().__init__() + self.model_name = model_name + + @classmethod + def load(cls, model_path: Union[str, "ModuleConfig"]) -> "TTIStub": + config = ModuleConfig.load(model_path) + return cls.bootstrap(config.model_name) + + @classmethod + def bootstrap(cls, model_name: str) -> "TTIStub": + return cls(model_name) + + def save(self, model_path: str): + saver = ModuleSaver( + self, + model_path=model_path, + ) + with saver: + saver.update_config({"model_name": self.model_name}) + + def run(self, inputs: str, height: int, width: int) -> CaptionedImage: + """Generates an image matching the provided height and width.""" + log.debug("STUB - running text to image inference") + r_channel = np.full((height, width), 0, dtype=np.uint8) + g_channel = np.full((height, width), 100, dtype=np.uint8) + b_channel = np.full((height, width), 200, dtype=np.uint8) + img = np.stack((r_channel, g_channel, b_channel), axis=2) + return CaptionedImage( + output=caikit_dm.Image(img), + caption=inputs, + ) diff --git a/tests/modules/text_to_image/test_tti_stub.py b/tests/modules/text_to_image/test_tti_stub.py new file mode 100644 index 0000000..4e4f7f5 --- /dev/null +++ b/tests/modules/text_to_image/test_tti_stub.py @@ -0,0 +1,43 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from tempfile import TemporaryDirectory +import os + +# Local +from caikit_computer_vision.modules.text_to_image import TTIStub +import caikit_computer_vision + + +def test_tti_stub(): + """Ensure that the stubs for load / save / run work as expected.""" + # Make sure we can bootstrap a model + model = TTIStub.bootstrap("foo") + assert isinstance(model, TTIStub) + + # Make sure we can run a fake inference on it + pred = model.run("This is a prompt", height=500, width=550) + assert isinstance(pred, caikit_computer_vision.data_model.CaptionedImage) + assert pred.caption == "This is a prompt" + pil_img = pred.output.as_pil() + assert pil_img.width == 550 + assert pil_img.height == 500 + + # Make sure we can save the model + model_dirname = "my_model" + with TemporaryDirectory() as tmpdirname: + model_path = os.path.join(tmpdirname, model_dirname) + model.save(model_path) + reloaded_model = model.load(model_path) diff --git a/tox.ini b/tox.ini index 88317a9..224552b 100644 --- a/tox.ini +++ b/tox.ini @@ -42,4 +42,11 @@ passenv = setenv = FLIT_USERNAME = __token__ commands = flit publish -skip_install = True \ No newline at end of file +skip_install = True + +[testenv:build] +description = build wheel +deps = + build +commands = python -m build +skip_install = True