From 81fb1b2546f32587fb507c337d1307e9d63041ec Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 14:53:50 -0800 Subject: [PATCH 1/8] init udstore Signed-off-by: Sidhant Kohli --- Makefile | 3 +- examples/servingstore/in_memory/Dockerfile | 55 +++++ examples/servingstore/in_memory/Makefile | 23 ++ examples/servingstore/in_memory/README.md | 3 + examples/servingstore/in_memory/entry.sh | 4 + examples/servingstore/in_memory/example.py | 37 +++ .../servingstore/in_memory/pyproject.toml | 15 ++ pynumaflow/_constants.py | 3 + pynumaflow/info/types.py | 2 + pynumaflow/proto/serving/__init__.py | 0 pynumaflow/proto/serving/store.proto | 74 ++++++ pynumaflow/proto/serving/store_pb2.py | 43 ++++ pynumaflow/proto/serving/store_pb2.pyi | 64 ++++++ pynumaflow/proto/serving/store_pb2_grpc.py | 170 ++++++++++++++ pynumaflow/servingstore/__init__.py | 4 + pynumaflow/servingstore/_dtypes.py | 179 +++++++++++++++ pynumaflow/servingstore/server.py | 104 +++++++++ pynumaflow/servingstore/servicer/__init__.py | 0 pynumaflow/servingstore/servicer/servicer.py | 68 ++++++ pynumaflow/shared/server.py | 3 + tests/servingstore/__init__.py | 0 tests/servingstore/test_responses.py | 53 +++++ .../test_side_input_server.py | 0 tests/sideinput/test_responses.py | 60 ++--- tests/sideinput/test_serving_store_server.py | 215 ++++++++++++++++++ 25 files changed, 1151 insertions(+), 31 deletions(-) create mode 100644 examples/servingstore/in_memory/Dockerfile create mode 100644 examples/servingstore/in_memory/Makefile create mode 100644 examples/servingstore/in_memory/README.md create mode 100644 examples/servingstore/in_memory/entry.sh create mode 100644 examples/servingstore/in_memory/example.py create mode 100644 examples/servingstore/in_memory/pyproject.toml create mode 100644 pynumaflow/proto/serving/__init__.py create mode 100644 pynumaflow/proto/serving/store.proto create mode 100644 pynumaflow/proto/serving/store_pb2.py create mode 100644 pynumaflow/proto/serving/store_pb2.pyi create mode 100644 pynumaflow/proto/serving/store_pb2_grpc.py create mode 100644 pynumaflow/servingstore/__init__.py create mode 100644 pynumaflow/servingstore/_dtypes.py create mode 100644 pynumaflow/servingstore/server.py create mode 100644 pynumaflow/servingstore/servicer/__init__.py create mode 100644 pynumaflow/servingstore/servicer/servicer.py create mode 100644 tests/servingstore/__init__.py create mode 100644 tests/servingstore/test_responses.py rename tests/{sideinput => servingstore}/test_side_input_server.py (100%) create mode 100644 tests/sideinput/test_serving_store_server.py diff --git a/Makefile b/Makefile index b403ac66..43913307 100644 --- a/Makefile +++ b/Makefile @@ -28,12 +28,11 @@ setup: proto: python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sinker -I=pynumaflow/proto/sinker --python_out=pynumaflow/proto/sinker --grpc_python_out=pynumaflow/proto/sinker pynumaflow/proto/sinker/*.proto python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/mapper -I=pynumaflow/proto/mapper --python_out=pynumaflow/proto/mapper --grpc_python_out=pynumaflow/proto/mapper pynumaflow/proto/mapper/*.proto - python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/mapstreamer -I=pynumaflow/proto/mapstreamer --python_out=pynumaflow/proto/mapstreamer --grpc_python_out=pynumaflow/proto/mapstreamer pynumaflow/proto/mapstreamer/*.proto python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/reducer -I=pynumaflow/proto/reducer --python_out=pynumaflow/proto/reducer --grpc_python_out=pynumaflow/proto/reducer 
pynumaflow/proto/reducer/*.proto python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sourcetransformer -I=pynumaflow/proto/sourcetransformer --python_out=pynumaflow/proto/sourcetransformer --grpc_python_out=pynumaflow/proto/sourcetransformer pynumaflow/proto/sourcetransformer/*.proto python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sideinput -I=pynumaflow/proto/sideinput --python_out=pynumaflow/proto/sideinput --grpc_python_out=pynumaflow/proto/sideinput pynumaflow/proto/sideinput/*.proto python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sourcer -I=pynumaflow/proto/sourcer --python_out=pynumaflow/proto/sourcer --grpc_python_out=pynumaflow/proto/sourcer pynumaflow/proto/sourcer/*.proto - python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/batchmapper -I=pynumaflow/proto/batchmapper --python_out=pynumaflow/proto/batchmapper --grpc_python_out=pynumaflow/proto/batchmapper pynumaflow/proto/batchmapper/*.proto + python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/serving -I=pynumaflow/proto/serving --python_out=pynumaflow/proto/serving --grpc_python_out=pynumaflow/proto/serving pynumaflow/proto/serving/*.proto sed -i '' 's/^\(import.*_pb2\)/from . \1/' pynumaflow/proto/*/*.py diff --git a/examples/servingstore/in_memory/Dockerfile b/examples/servingstore/in_memory/Dockerfile new file mode 100644 index 00000000..726073ba --- /dev/null +++ b/examples/servingstore/in_memory/Dockerfile @@ -0,0 +1,55 @@ +#################################################################################################### +# builder: install needed dependencies +#################################################################################################### + +FROM python:3.10-slim-bullseye AS builder + +ENV PYTHONFAULTHANDLER=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONHASHSEED=random \ + PIP_NO_CACHE_DIR=on \ + PIP_DISABLE_PIP_VERSION_CHECK=on \ + PIP_DEFAULT_TIMEOUT=100 \ + POETRY_VERSION=1.2.2 \ + POETRY_HOME="/opt/poetry" \ + POETRY_VIRTUALENVS_IN_PROJECT=true \ + POETRY_NO_INTERACTION=1 \ + PYSETUP_PATH="/opt/pysetup" + +ENV EXAMPLE_PATH="$PYSETUP_PATH/examples/servingstore/in_memory" +ENV VENV_PATH="$EXAMPLE_PATH/.venv" +ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH" + +RUN apt-get update \ + && apt-get install --no-install-recommends -y \ + curl \ + wget \ + # deps for building python deps + build-essential \ + && apt-get install -y git \ + && apt-get clean && rm -rf /var/lib/apt/lists/* \ + \ + # install dumb-init + && wget -O /dumb-init https://github.com/Yelp/dumb-init/releases/download/v1.2.5/dumb-init_1.2.5_x86_64 \ + && chmod +x /dumb-init \ + && curl -sSL https://install.python-poetry.org | python3 - + +#################################################################################################### +# udf: used for running the udf vertices +#################################################################################################### +FROM builder AS udf + +WORKDIR $PYSETUP_PATH +COPY ./ ./ + +WORKDIR $EXAMPLE_PATH +RUN poetry lock +RUN poetry install --no-cache --no-root && \ + rm -rf ~/.cache/pypoetry/ + +RUN chmod +x entry.sh + +ENTRYPOINT ["/dumb-init", "--"] +CMD ["sh", "-c", "$EXAMPLE_PATH/entry.sh"] + +EXPOSE 5000 diff --git a/examples/servingstore/in_memory/Makefile b/examples/servingstore/in_memory/Makefile new file mode 100644 index 00000000..0237f20f --- /dev/null +++ b/examples/servingstore/in_memory/Makefile @@ -0,0 +1,23 @@ +TAG ?= stable +PUSH ?= false +IMAGE_REGISTRY = quay.io/numaio/numaflow-python/serving-store-example:${TAG} +DOCKER_FILE_PATH = 
examples/servingstore/in_memory/Dockerfile + + +.PHONY: update +update: + poetry update -vv + +.PHONY: image-push +image-push: update + cd ../../../ && docker buildx build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} \ + --platform linux/amd64,linux/arm64 . --push + +.PHONY: image +image: update + cd ../../../ && docker build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi diff --git a/examples/servingstore/in_memory/README.md b/examples/servingstore/in_memory/README.md new file mode 100644 index 00000000..cae0bb85 --- /dev/null +++ b/examples/servingstore/in_memory/README.md @@ -0,0 +1,3 @@ +# Serving Store Example + +An example that demonstrates how to write a `serving store` in python. \ No newline at end of file diff --git a/examples/servingstore/in_memory/entry.sh b/examples/servingstore/in_memory/entry.sh new file mode 100644 index 00000000..073b05e3 --- /dev/null +++ b/examples/servingstore/in_memory/entry.sh @@ -0,0 +1,4 @@ +#!/bin/sh +set -eux + +python example.py diff --git a/examples/servingstore/in_memory/example.py b/examples/servingstore/in_memory/example.py new file mode 100644 index 00000000..ce448b3b --- /dev/null +++ b/examples/servingstore/in_memory/example.py @@ -0,0 +1,37 @@ +from pynumaflow.servingstore import ( + ServingStorer, + PutDatum, + GetDatum, + StoredResult, + ServingStoreServer, + Payload, +) + + +class InMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + cur_payloads.append(Payload(x.origin, x.value)) + self.store[req_id] = cur_payloads + + def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + resp = [] + if req_id in self.store: + resp = self.store[req_id] + return StoredResult(id_=req_id, payloads=resp) + + +if __name__ == "__main__": + grpc_server = ServingStoreServer(InMemoryStore()) + grpc_server.start() diff --git a/examples/servingstore/in_memory/pyproject.toml b/examples/servingstore/in_memory/pyproject.toml new file mode 100644 index 00000000..783dc8a7 --- /dev/null +++ b/examples/servingstore/in_memory/pyproject.toml @@ -0,0 +1,15 @@ +[tool.poetry] +name = "in-memory-servingstore" +version = "0.2.4" +description = "" +authors = ["Numaflow developers"] + +[tool.poetry.dependencies] +python = ">=3.10,<3.13" +pynumaflow = { path = "../../../"} + +[tool.poetry.dev-dependencies] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/pynumaflow/_constants.py b/pynumaflow/_constants.py index 174870ce..b0638ea7 100644 --- a/pynumaflow/_constants.py +++ b/pynumaflow/_constants.py @@ -18,6 +18,7 @@ MULTIPROC_MAP_SOCK_ADDR = "/var/run/numaflow/multiproc" FALLBACK_SINK_SOCK_PATH = "/var/run/numaflow/fb-sink.sock" BATCH_MAP_SOCK_PATH = "/var/run/numaflow/batchmap.sock" +SERVING_STORE_SOCK_PATH = "/var/run/numaflow/serving.sock" # Server information file configs MAP_SERVER_INFO_FILE_PATH = "/var/run/numaflow/mapper-server-info" @@ -28,6 +29,7 @@ SIDE_INPUT_SERVER_INFO_FILE_PATH = "/var/run/numaflow/sideinput-server-info" SOURCE_SERVER_INFO_FILE_PATH = "/var/run/numaflow/sourcer-server-info" FALLBACK_SINK_SERVER_INFO_FILE_PATH = "/var/run/numaflow/fb-sinker-server-info" +SERVING_STORE_SERVER_INFO_FILE_PATH = 
"/var/run/numaflow/serving-server-info" ENV_UD_CONTAINER_TYPE = "NUMAFLOW_UD_CONTAINER_TYPE" UD_CONTAINER_FALLBACK_SINK = "fb-udsink" @@ -64,3 +66,4 @@ class UDFType(str, Enum): Source = "source" SideInput = "sideinput" SourceTransformer = "sourcetransformer" + ServingStore = "servingstore" diff --git a/pynumaflow/info/types.py b/pynumaflow/info/types.py index 2845c264..8f1f0b51 100644 --- a/pynumaflow/info/types.py +++ b/pynumaflow/info/types.py @@ -71,6 +71,7 @@ class ContainerType(str, Enum): Sessionreducer = "sessionreducer" Sideinput = "sideinput" Fbsinker = "fb-sinker" + Serving = "serving" # Minimum version of Numaflow required by the current SDK version @@ -86,6 +87,7 @@ class ContainerType(str, Enum): ContainerType.Sessionreducer: "1.4.0-z", ContainerType.Sideinput: "1.4.0-z", ContainerType.Fbsinker: "1.4.0-z", + ContainerType.Serving: "1.5.0-z", } diff --git a/pynumaflow/proto/serving/__init__.py b/pynumaflow/proto/serving/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/proto/serving/store.proto b/pynumaflow/proto/serving/store.proto new file mode 100644 index 00000000..3dc469e0 --- /dev/null +++ b/pynumaflow/proto/serving/store.proto @@ -0,0 +1,74 @@ +/* +Copyright 2022 The Numaproj Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +syntax = "proto3"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +package serving.v1; + +// ServingStore defines a set of methods to interface with a user-defined Store. +service ServingStore { + // Put is to put the PutRequest into the Store. + rpc Put(PutRequest) returns (PutResponse); + + // Get gets the GetRequest from the Store. + rpc Get(GetRequest) returns (GetResponse); + + // IsReady checks the health of the container interfacing the Store. + rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); +} + +// Payload that represent the output that is to be written into to the store. +message Payload { + // Origin is the Vertex that generated this result. + string origin = 1; + // Value is the result of the computation. + bytes value = 2; +} + +// PutRequest is the request sent to the Store. +message PutRequest { + // ID is the unique id as provided by the user in the original request. If not provided, it will be a system generated + // uuid. + string id = 1; + // Payloads are one or more results generated (could be more than one due to flat-map). + repeated Payload payloads = 2; +} + +// PutResponse is the result of the Put call. +message PutResponse { + bool success = 1; +} + +// GetRequest is the call to get the result stored in the Store. +message GetRequest { + // ID is the unique id as provided by the user in the original request. If not provided, it will be a system generated + // uuid. + string id = 1; +} + +// GetResponse is the result stored in the Store. +message GetResponse { + string id = 1; + // Payloads are one or more results generated (could be more than one due to flat-map). + repeated Payload payloads = 2; +} + +/** + * ReadyResponse is the health check result. 
+ */ +message ReadyResponse { + bool ready = 1; +} diff --git a/pynumaflow/proto/serving/store_pb2.py b/pynumaflow/proto/serving/store_pb2.py new file mode 100644 index 00000000..7ce1e005 --- /dev/null +++ b/pynumaflow/proto/serving/store_pb2.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: store.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0bstore.proto\x12\nserving.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"(\n\x07Payload\x12\x0e\n\x06origin\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c"?\n\nPutRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12%\n\x08payloads\x18\x02 \x03(\x0b\x32\x13.serving.v1.Payload"\x1e\n\x0bPutResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"\x18\n\nGetRequest\x12\n\n\x02id\x18\x01 \x01(\t"@\n\x0bGetResponse\x12\n\n\x02id\x18\x01 \x01(\t\x12%\n\x08payloads\x18\x02 \x03(\x0b\x32\x13.serving.v1.Payload"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\x32\xbc\x01\n\x0cServingStore\x12\x36\n\x03Put\x12\x16.serving.v1.PutRequest\x1a\x17.serving.v1.PutResponse\x12\x36\n\x03Get\x12\x16.serving.v1.GetRequest\x1a\x17.serving.v1.GetResponse\x12<\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x19.serving.v1.ReadyResponseb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "store_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_PAYLOAD"]._serialized_start = 89 + _globals["_PAYLOAD"]._serialized_end = 129 + _globals["_PUTREQUEST"]._serialized_start = 131 + _globals["_PUTREQUEST"]._serialized_end = 194 + _globals["_PUTRESPONSE"]._serialized_start = 196 + _globals["_PUTRESPONSE"]._serialized_end = 226 + _globals["_GETREQUEST"]._serialized_start = 228 + _globals["_GETREQUEST"]._serialized_end = 252 + _globals["_GETRESPONSE"]._serialized_start = 254 + _globals["_GETRESPONSE"]._serialized_end = 318 + _globals["_READYRESPONSE"]._serialized_start = 320 + _globals["_READYRESPONSE"]._serialized_end = 350 + _globals["_SERVINGSTORE"]._serialized_start = 353 + _globals["_SERVINGSTORE"]._serialized_end = 541 +# @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/serving/store_pb2.pyi b/pynumaflow/proto/serving/store_pb2.pyi new file mode 100644 index 00000000..121ed99c --- /dev/null +++ b/pynumaflow/proto/serving/store_pb2.pyi @@ -0,0 +1,64 @@ +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ( + ClassVar as _ClassVar, + Iterable as _Iterable, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) + +DESCRIPTOR: _descriptor.FileDescriptor + +class Payload(_message.Message): 
+ __slots__ = ("origin", "value") + ORIGIN_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + origin: str + value: bytes + def __init__(self, origin: _Optional[str] = ..., value: _Optional[bytes] = ...) -> None: ... + +class PutRequest(_message.Message): + __slots__ = ("id", "payloads") + ID_FIELD_NUMBER: _ClassVar[int] + PAYLOADS_FIELD_NUMBER: _ClassVar[int] + id: str + payloads: _containers.RepeatedCompositeFieldContainer[Payload] + def __init__( + self, + id: _Optional[str] = ..., + payloads: _Optional[_Iterable[_Union[Payload, _Mapping]]] = ..., + ) -> None: ... + +class PutResponse(_message.Message): + __slots__ = ("success",) + SUCCESS_FIELD_NUMBER: _ClassVar[int] + success: bool + def __init__(self, success: bool = ...) -> None: ... + +class GetRequest(_message.Message): + __slots__ = ("id",) + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class GetResponse(_message.Message): + __slots__ = ("id", "payloads") + ID_FIELD_NUMBER: _ClassVar[int] + PAYLOADS_FIELD_NUMBER: _ClassVar[int] + id: str + payloads: _containers.RepeatedCompositeFieldContainer[Payload] + def __init__( + self, + id: _Optional[str] = ..., + payloads: _Optional[_Iterable[_Union[Payload, _Mapping]]] = ..., + ) -> None: ... + +class ReadyResponse(_message.Message): + __slots__ = ("ready",) + READY_FIELD_NUMBER: _ClassVar[int] + ready: bool + def __init__(self, ready: bool = ...) -> None: ... diff --git a/pynumaflow/proto/serving/store_pb2_grpc.py b/pynumaflow/proto/serving/store_pb2_grpc.py new file mode 100644 index 00000000..97f1da61 --- /dev/null +++ b/pynumaflow/proto/serving/store_pb2_grpc.py @@ -0,0 +1,170 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from . import store_pb2 as store__pb2 + + +class ServingStoreStub(object): + """ServingStore defines a set of methods to interface with a user-defined Store.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Put = channel.unary_unary( + "/serving.v1.ServingStore/Put", + request_serializer=store__pb2.PutRequest.SerializeToString, + response_deserializer=store__pb2.PutResponse.FromString, + ) + self.Get = channel.unary_unary( + "/serving.v1.ServingStore/Get", + request_serializer=store__pb2.GetRequest.SerializeToString, + response_deserializer=store__pb2.GetResponse.FromString, + ) + self.IsReady = channel.unary_unary( + "/serving.v1.ServingStore/IsReady", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=store__pb2.ReadyResponse.FromString, + ) + + +class ServingStoreServicer(object): + """ServingStore defines a set of methods to interface with a user-defined Store.""" + + def Put(self, request, context): + """Put is to put the PutRequest into the Store.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Get(self, request, context): + """Get gets the GetRequest from the Store.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def IsReady(self, request, context): + """IsReady checks the health of the container interfacing the Store.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_ServingStoreServicer_to_server(servicer, server): + rpc_method_handlers = { + "Put": grpc.unary_unary_rpc_method_handler( + servicer.Put, + request_deserializer=store__pb2.PutRequest.FromString, + response_serializer=store__pb2.PutResponse.SerializeToString, + ), + "Get": grpc.unary_unary_rpc_method_handler( + servicer.Get, + request_deserializer=store__pb2.GetRequest.FromString, + response_serializer=store__pb2.GetResponse.SerializeToString, + ), + "IsReady": grpc.unary_unary_rpc_method_handler( + servicer.IsReady, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=store__pb2.ReadyResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "serving.v1.ServingStore", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class ServingStore(object): + """ServingStore defines a set of methods to interface with a user-defined Store.""" + + @staticmethod + def Put( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/serving.v1.ServingStore/Put", + store__pb2.PutRequest.SerializeToString, + store__pb2.PutResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Get( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/serving.v1.ServingStore/Get", + store__pb2.GetRequest.SerializeToString, + store__pb2.GetResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def IsReady( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/serving.v1.ServingStore/IsReady", + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + store__pb2.ReadyResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/pynumaflow/servingstore/__init__.py b/pynumaflow/servingstore/__init__.py new file mode 100644 index 00000000..4c4dfae9 --- /dev/null +++ b/pynumaflow/servingstore/__init__.py @@ -0,0 +1,4 @@ +from pynumaflow.servingstore._dtypes import PutDatum, GetDatum, ServingStorer, StoredResult, Payload +from pynumaflow.servingstore.server import ServingStoreServer + +__all__ = ["PutDatum", "GetDatum", "ServingStorer", "ServingStoreServer", "StoredResult", "Payload"] diff --git a/pynumaflow/servingstore/_dtypes.py b/pynumaflow/servingstore/_dtypes.py new file mode 100644 index 00000000..78c9193d --- /dev/null +++ b/pynumaflow/servingstore/_dtypes.py @@ -0,0 +1,179 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import TypeVar + +P = TypeVar("P", bound="Payload") + + +# class Payload: +# """ +# Class to define each independent result stored in the Store for the given ID. +# Args: +# origin: origin for a given payload +# value: the data associated for the given message +# """ +# __slots__ = ("_origin", "_value") +# +# def __init__( +# self, +# origin: str, +# value: bytes, +# ): +# """ +# Creates a Payload object to send value retrieved from the store. +# """ +# self._origin = origin +# self._value = value +# +# @property +# def value(self) -> bytes: +# return self._value +# +# @property +# def origin(self) -> str: +# return self._origin + + +@dataclass +class Payload: + """ + Class to define each independent result stored in the Store for the given ID. + + Attributes: + origin (str): The origin of a given payload, typically describing where or what the data + comes from, for example, a sensor ID or a message source. + value (bytes): The data associated with the payload, stored as bytes to accommodate various + types of binary data or encoded string data. 
+    """
+
+    origin: str
+    value: bytes
+
+
+@dataclass(init=False)
+class PutDatum:
+    """
+    Class to define data for the Put rpc.
+    Args:
+        id_: the id of the request.
+        payloads: the payload to be stored.
+
+    >>> # Example usage
+    >>> from pynumaflow.servingstore import PutDatum
+    >>> from datetime import datetime, timezone
+    >>> payload = bytes("test_mock_message", encoding="utf-8")
+    >>> d = PutDatum(
+    ...    id_ = "avc", payloads = [payload]
+    ... )
+    """
+
+    __slots__ = ("_id", "_payloads")
+
+    _id: str
+    _payloads: list[Payload]
+
+    def __init__(
+        self,
+        id_: str,
+        payloads: list[Payload],
+    ):
+        self._id = id_
+        self._payloads = payloads or []
+
+    @property
+    def id(self) -> str:
+        """Returns the id of the event"""
+        return self._id
+
+    @property
+    def payloads(self) -> list[Payload]:
+        """Returns the payloads of the event."""
+        return self._payloads
+
+
+@dataclass(init=False)
+class GetDatum:
+    """
+    Class to retrieve data from the Get rpc.
+    Args:
+        id_: the id of the request.
+
+    >>> # Example usage
+    >>> from pynumaflow.servingstore import GetDatum
+    >>> from datetime import datetime, timezone
+    >>> payload = bytes("test_mock_message", encoding="utf-8")
+    >>> d = GetDatum(
+    ...    id_ = "avc"
+    ... )
+    """
+
+    __slots__ = ("_id",)
+
+    _id: str
+
+    def __init__(
+        self,
+        id_: str,
+    ):
+        self._id = id_
+
+    @property
+    def id(self) -> str:
+        """Returns the id of the event"""
+        return self._id
+
+
+@dataclass
+class StoredResult:
+    """
+    Class to define the data stored in the store per origin.
+    Args:
+        id_: unique ID for the response
+        payloads: the payloads of the given ID
+    """
+
+    __slots__ = ("_id", "_payloads")
+
+    _id: str
+    _payloads: list[Payload]
+
+    def __init__(self, id_: str, payloads: list[Payload] = None):
+        """
+        Creates a StoredResult object to send value to a vertex.
+        """
+        self._id = id_
+        self._payloads = payloads or []
+
+    @property
+    def payloads(self) -> list[Payload]:
+        """Returns the payloads of the event"""
+        return self._payloads
+
+    @property
+    def id(self) -> str:
+        """Returns the id of the event"""
+        return self._id
+
+
+class ServingStorer(metaclass=ABCMeta):
+    """
+    Provides an interface to write a Serving Store Class
+    which will be exposed over gRPC.
+    """
+
+    @abstractmethod
+    def put(self, datum: PutDatum):
+        """
+        This function is called when a Put request is received; it should store the
+        given payloads against the request ID.
+        """
+        pass
+
+    @abstractmethod
+    def get(self, datum: GetDatum) -> StoredResult:
+        """
+        This function is called when a Get request is received; it should return the
+        payloads stored for the given request ID.
+        """
+        pass
+
+
+ServingStoreCallable = ServingStorer
diff --git a/pynumaflow/servingstore/server.py b/pynumaflow/servingstore/server.py
new file mode 100644
index 00000000..7fdef066
--- /dev/null
+++ b/pynumaflow/servingstore/server.py
@@ -0,0 +1,104 @@
+from pynumaflow._constants import (
+    NUM_THREADS_DEFAULT,
+    MAX_MESSAGE_SIZE,
+    _LOGGER,
+    UDFType,
+    MAX_NUM_THREADS,
+    SERVING_STORE_SOCK_PATH,
+    SERVING_STORE_SERVER_INFO_FILE_PATH,
+)
+from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType
+from pynumaflow.servingstore._dtypes import ServingStoreCallable
+from pynumaflow.servingstore.servicer.servicer import SyncServingStoreServicer
+from pynumaflow.shared import NumaflowServer
+from pynumaflow.shared.server import sync_server_start
+
+
+class ServingStoreServer(NumaflowServer):
+    """
+    Class for a new Serving Store instance.
+ Args: + serving_store_instance: The serving store instance to be used + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + + Example invocation: + import datetime + from pynumaflow.servingstore import Response, ServingStoreServer, ServingStorer + + class InMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + cur_payloads.append(Payload(x.origin, x.value)) + self.store[req_id] = cur_payloads + + def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + resp = [] + if req_id in self.store: + resp = self.store[req_id] + return StoredResult(id_=req_id, payloads=resp) + + if __name__ == "__main__": + grpc_server = ServingStoreServer(InMemoryStore()) + grpc_server.start() + + """ + + def __init__( + self, + serving_store_instance: ServingStoreCallable, + sock_path=SERVING_STORE_SOCK_PATH, + max_message_size=MAX_MESSAGE_SIZE, + max_threads=NUM_THREADS_DEFAULT, + server_info_file=SERVING_STORE_SERVER_INFO_FILE_PATH, + ): + self.sock_path = f"unix://{sock_path}" + self.max_threads = min(max_threads, MAX_NUM_THREADS) + self.max_message_size = max_message_size + self.server_info_file = server_info_file + + self._server_options = [ + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), + ] + + self.serving_store_instance = serving_store_instance + self.servicer = SyncServingStoreServicer(serving_store_instance) + + def start(self): + """ + Starts the Synchronous gRPC server on the given UNIX socket with given max threads. 
+        """
+        # Get the servicer instance based on the server type
+        serving_store_servicer = self.servicer
+
+        _LOGGER.info(
+            "Serving Store GRPC Server listening on: %s with max threads: %s",
+            self.sock_path,
+            self.max_threads,
+        )
+
+        serv_info = ServerInfo.get_default_server_info()
+        serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Serving]
+        # Start the server
+        sync_server_start(
+            servicer=serving_store_servicer,
+            bind_address=self.sock_path,
+            max_threads=self.max_threads,
+            server_info_file=self.server_info_file,
+            server_options=self._server_options,
+            udf_type=UDFType.ServingStore,
+            server_info=serv_info,
+        )
diff --git a/pynumaflow/servingstore/servicer/__init__.py b/pynumaflow/servingstore/servicer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pynumaflow/servingstore/servicer/servicer.py b/pynumaflow/servingstore/servicer/servicer.py
new file mode 100644
index 00000000..d62b459e
--- /dev/null
+++ b/pynumaflow/servingstore/servicer/servicer.py
@@ -0,0 +1,68 @@
+from google.protobuf import empty_pb2 as _empty_pb2
+
+from pynumaflow._constants import (
+    _LOGGER,
+)
+from pynumaflow.proto.serving import store_pb2_grpc, store_pb2
+from pynumaflow.servingstore._dtypes import ServingStoreCallable, Payload, PutDatum, GetDatum
+from pynumaflow.shared.server import exit_on_error
+from pynumaflow.types import NumaflowServicerContext
+
+
+class SyncServingStoreServicer(store_pb2_grpc.ServingStoreServicer):
+    def __init__(
+        self,
+        handler: ServingStoreCallable,
+    ):
+        self.__serving_store_instance: ServingStoreCallable = handler
+
+    def Put(
+        self, request: store_pb2.PutRequest, context: NumaflowServicerContext
+    ) -> store_pb2.PutResponse:
+        """
+        Applies the Put function for a store request.
+        The pascal case function name comes from the proto store_pb2_grpc.py file.
+        """
+        # if there is an exception, we will mark all the responses as a failure
+        try:
+            input_payloads = []
+            for x in request.payloads:
+                input_payloads.append(Payload(origin=x.origin, value=x.value))
+            self.__serving_store_instance.put(
+                datum=PutDatum(id_=request.id, payloads=input_payloads)
+            )
+            return store_pb2.PutResponse(success=True)
+        except BaseException as err:
+            err_msg = f"Serving Store Put: {repr(err)}"
+            _LOGGER.critical(err_msg, exc_info=True)
+            exit_on_error(context, repr(err))
+            return
+
+    def Get(
+        self, request: store_pb2.GetRequest, context: NumaflowServicerContext
+    ) -> store_pb2.GetResponse:
+        """
+        Applies the Get function for a store request.
+        The pascal case function name comes from the proto store_pb2_grpc.py file.
+        """
+        # if there is an exception, we will mark all the responses as a failure
+        try:
+            resps = self.__serving_store_instance.get(datum=GetDatum(id_=request.id))
+            resp_payloads = []
+            for resp in resps.payloads:
+                resp_payloads.append(store_pb2.Payload(origin=resp.origin, value=resp.value))
+            return store_pb2.GetResponse(id=request.id, payloads=resp_payloads)
+        except BaseException as err:
+            err_msg = f"Serving Store Get: {repr(err)}"
+            _LOGGER.critical(err_msg, exc_info=True)
+            exit_on_error(context, repr(err))
+            return
+
+    def IsReady(
+        self, request: _empty_pb2.Empty, context: NumaflowServicerContext
+    ) -> store_pb2.ReadyResponse:
+        """
+        IsReady is the heartbeat endpoint for gRPC.
+        The pascal case function name comes from the proto store_pb2_grpc.py file.
+ """ + return store_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/shared/server.py b/pynumaflow/shared/server.py index ab86c9f0..3bcf767a 100644 --- a/pynumaflow/shared/server.py +++ b/pynumaflow/shared/server.py @@ -26,6 +26,7 @@ MULTIPROC_KEY, ) from pynumaflow.proto.mapper import map_pb2_grpc +from pynumaflow.proto.serving import store_pb2_grpc from pynumaflow.proto.sideinput import sideinput_pb2_grpc from pynumaflow.proto.sinker import sink_pb2_grpc from pynumaflow.proto.sourcer import source_pb2_grpc @@ -107,6 +108,8 @@ def _run_server( source_pb2_grpc.add_SourceServicer_to_server(servicer, server) elif udf_type == UDFType.SideInput: sideinput_pb2_grpc.add_SideInputServicer_to_server(servicer, server) + elif udf_type == UDFType.ServingStore: + store_pb2_grpc.add_ServingStoreServicer_to_server(servicer, server) # bind the server to the UDS/TCP socket server.add_insecure_port(bind_address) diff --git a/tests/servingstore/__init__.py b/tests/servingstore/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/servingstore/test_responses.py b/tests/servingstore/test_responses.py new file mode 100644 index 00000000..859f4bb1 --- /dev/null +++ b/tests/servingstore/test_responses.py @@ -0,0 +1,53 @@ +import unittest + +from pynumaflow.sideinput import Response, SideInput + + +class TestResponse(unittest.TestCase): + """ + Test the Response class for SideInput + """ + + def test_broadcast_message(self): + """ + Test the broadcast_message method, + where we expect the no_broadcast flag to be False. + """ + succ_response = Response.broadcast_message(b"2") + self.assertFalse(succ_response.no_broadcast) + self.assertEqual(b"2", succ_response.value) + + def test_no_broadcast_message(self): + """ + Test the no_broadcast_message method, + where we expect the no_broadcast flag to be True. 
+ """ + succ_response = Response.no_broadcast_message() + self.assertTrue(succ_response.no_broadcast) + + +class ExampleSideInput(SideInput): + def retrieve_handler(self) -> Response: + return Response.broadcast_message(b"testMessage") + + +class TestSideInputClass(unittest.TestCase): + def setUp(self) -> None: + # Create a side input class instance + self.side_input_instance = ExampleSideInput() + + def test_side_input_class_call(self): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + # make a call to the class directly + ret = self.side_input_instance() + self.assertEqual(b"testMessage", ret.value) + # make a call to the handler + ret_handler = self.side_input_instance.retrieve_handler() + # Both responses should be equal + self.assertEqual(ret, ret_handler) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/sideinput/test_side_input_server.py b/tests/servingstore/test_side_input_server.py similarity index 100% rename from tests/sideinput/test_side_input_server.py rename to tests/servingstore/test_side_input_server.py diff --git a/tests/sideinput/test_responses.py b/tests/sideinput/test_responses.py index 859f4bb1..f0a1f50f 100644 --- a/tests/sideinput/test_responses.py +++ b/tests/sideinput/test_responses.py @@ -1,21 +1,22 @@ import unittest -from pynumaflow.sideinput import Response, SideInput +from pynumaflow.proto.serving import store_pb2 +from pynumaflow.servingstore import Payload +from pynumaflow.sideinput import Response -class TestResponse(unittest.TestCase): +class TestPayload(unittest.TestCase): """ Test the Response class for SideInput """ - def test_broadcast_message(self): + def test_create_payload(self): """ - Test the broadcast_message method, - where we expect the no_broadcast flag to be False. 
+ Test the new payload method, """ - succ_response = Response.broadcast_message(b"2") - self.assertFalse(succ_response.no_broadcast) - self.assertEqual(b"2", succ_response.value) + x = store_pb2.Payload(origin="abc1", value=bytes("test_put", encoding="utf-8")) + succ_response = Payload(origin=x.origin, value=x.value) + print(succ_response.value) def test_no_broadcast_message(self): """ @@ -26,27 +27,28 @@ def test_no_broadcast_message(self): self.assertTrue(succ_response.no_broadcast) -class ExampleSideInput(SideInput): - def retrieve_handler(self) -> Response: - return Response.broadcast_message(b"testMessage") - - -class TestSideInputClass(unittest.TestCase): - def setUp(self) -> None: - # Create a side input class instance - self.side_input_instance = ExampleSideInput() - - def test_side_input_class_call(self): - """Test that the __call__ functionality for the class works, - ie the class instance can be called directly to invoke the handler function - """ - # make a call to the class directly - ret = self.side_input_instance() - self.assertEqual(b"testMessage", ret.value) - # make a call to the handler - ret_handler = self.side_input_instance.retrieve_handler() - # Both responses should be equal - self.assertEqual(ret, ret_handler) +# +# class ExampleSideInput(SideInput): +# def retrieve_handler(self) -> Response: +# return Response.broadcast_message(b"testMessage") +# +# +# class TestSideInputClass(unittest.TestCase): +# def setUp(self) -> None: +# # Create a side input class instance +# self.side_input_instance = ExampleSideInput() +# +# def test_side_input_class_call(self): +# """Test that the __call__ functionality for the class works, +# ie the class instance can be called directly to invoke the handler function +# """ +# # make a call to the class directly +# ret = self.side_input_instance() +# self.assertEqual(b"testMessage", ret.value) +# # make a call to the handler +# ret_handler = self.side_input_instance.retrieve_handler() +# # Both responses should be equal +# self.assertEqual(ret, ret_handler) if __name__ == "__main__": diff --git a/tests/sideinput/test_serving_store_server.py b/tests/sideinput/test_serving_store_server.py new file mode 100644 index 00000000..fc43039b --- /dev/null +++ b/tests/sideinput/test_serving_store_server.py @@ -0,0 +1,215 @@ +import unittest +from unittest.mock import patch + +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.proto.serving import store_pb2 +from pynumaflow.servingstore import ( + ServingStorer, + PutDatum, + Payload, + GetDatum, + StoredResult, + ServingStoreServer, +) +from tests.testing_utils import mock_terminate_on_stop + + +class InMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + print(x) + cur_payloads.append(Payload(x.origin, x.value)) + self.store[req_id] = cur_payloads + + def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + resp = [] + if req_id in self.store: + resp = self.store[req_id] + return StoredResult(id_=req_id, payloads=resp) + + +class ErrInMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put 
request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + cur_payloads.append(Payload(x.origin, x.value)) + raise ValueError("something fishy") + self.store[req_id] = cur_payloads + + def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + raise ValueError("get is fishy") + + +def mock_message(): + msg = bytes("test_side_input", encoding="utf-8") + return msg + + +# We are mocking the terminate function from the psutil to not exit the program during testing +@patch("psutil.Process.kill", mock_terminate_on_stop) +class TestServer(unittest.TestCase): + """ + Test the SideInput grpc server + """ + + def setUp(self) -> None: + self.InMem = InMemoryStore() + server = ServingStoreServer(self.InMem) + my_service = server.servicer + services = {store_pb2.DESCRIPTOR.services_by_name["ServingStore"]: my_service} + self.test_server = server_from_dictionary(services, strict_real_time()) + + def test_init_with_args(self) -> None: + """ + Test the initialization of the SideInput class, + """ + my_server = ServingStoreServer( + serving_store_instance=InMemoryStore(), + sock_path="/tmp/test_serving_store.sock", + max_message_size=1024 * 1024 * 5, + ) + self.assertEqual(my_server.sock_path, "unix:///tmp/test_serving_store.sock") + self.assertEqual(my_server.max_message_size, 1024 * 1024 * 5) + + def test_serving_store_err(self): + """ + Test the error case for the Put method, + """ + server = ServingStoreServer(ErrInMemoryStore()) + my_service = server.servicer + services = {store_pb2.DESCRIPTOR.services_by_name["ServingStore"]: my_service} + self.test_server = server_from_dictionary(services, strict_real_time()) + + method = self.test_server.invoke_unary_unary( + method_descriptor=( + store_pb2.DESCRIPTOR.services_by_name["ServingStore"].methods_by_name["Put"] + ), + invocation_metadata={ + ("this_metadata_will_be_skipped", "test_ignore"), + }, + request=store_pb2.PutRequest( + id="abc", + payloads=[ + store_pb2.Payload(origin="abc", value=bytes("test_put", encoding="utf-8")) + ], + ), + timeout=1, + ) + response, metadata, code, details = method.termination() + self.assertEqual(grpc.StatusCode.UNKNOWN, code) + self.assertTrue("something fishy" in details) + + def test_is_ready(self): + method = self.test_server.invoke_unary_unary( + method_descriptor=( + store_pb2.DESCRIPTOR.services_by_name["ServingStore"].methods_by_name["IsReady"] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + expected = store_pb2.ReadyResponse(ready=True) + self.assertEqual(expected, response) + self.assertEqual(code, StatusCode.OK) + + def test_put_message(self): + """ + Test the broadcast_message method, + where we expect the no_broadcast flag to be False and + the message value to be the mock_message. 
+ """ + request = store_pb2.PutRequest( + id="abc", + payloads=[store_pb2.Payload(origin="abc1", value=bytes("test_put", encoding="utf-8"))], + ) + method = self.test_server.invoke_unary_unary( + method_descriptor=( + store_pb2.DESCRIPTOR.services_by_name["ServingStore"].methods_by_name["Put"] + ), + invocation_metadata={ + ("this_metadata_will_be_skipped", "test_ignore"), + }, + request=request, + timeout=1, + ) + response, metadata, code, details = method.termination() + self.assertEqual(True, response.success) + self.assertEqual(code, StatusCode.OK) + stored = self.InMem.store["abc"] + self.assertEqual(stored[0].origin, "abc1") + self.assertEqual(stored[0].value, bytes("test_put", encoding="utf-8")) + + def test_get_message(self): + """ + Test the broadcast_message method, + where we expect the no_broadcast flag to be False and + the message value to be the mock_message. + """ + request = store_pb2.GetRequest( + id="abc", + ) + + method = self.test_server.invoke_unary_unary( + method_descriptor=( + store_pb2.DESCRIPTOR.services_by_name["ServingStore"].methods_by_name["Get"] + ), + invocation_metadata={ + ("this_metadata_will_be_skipped", "test_ignore"), + }, + request=request, + timeout=1, + ) + pl = Payload(origin="abc", value=bytes("test_put", encoding="utf-8")) + self.InMem.store["abc"] = [pl] + response, metadata, code, details = method.termination() + self.assertEqual(pl.value, response.payloads[0].value) + self.assertEqual(pl.origin, response.payloads[0].origin) + self.assertEqual(code, StatusCode.OK) + + def test_invalid_input(self): + with self.assertRaises(TypeError): + ServingStoreServer() + + def test_max_threads(self): + # max cap at 16 + server = ServingStoreServer(InMemoryStore(), max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = ServingStoreServer(InMemoryStore(), max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = ServingStoreServer(InMemoryStore()) + self.assertEqual(server.max_threads, 4) + + +if __name__ == "__main__": + unittest.main() From a28e42793384aaa3609a9e6a3409ab3d5240fa53 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 14:55:31 -0800 Subject: [PATCH 2/8] clean Signed-off-by: Sidhant Kohli --- tests/sideinput/test_responses.py | 35 +------------------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/tests/sideinput/test_responses.py b/tests/sideinput/test_responses.py index f0a1f50f..421d2829 100644 --- a/tests/sideinput/test_responses.py +++ b/tests/sideinput/test_responses.py @@ -2,7 +2,6 @@ from pynumaflow.proto.serving import store_pb2 from pynumaflow.servingstore import Payload -from pynumaflow.sideinput import Response class TestPayload(unittest.TestCase): @@ -16,39 +15,7 @@ def test_create_payload(self): """ x = store_pb2.Payload(origin="abc1", value=bytes("test_put", encoding="utf-8")) succ_response = Payload(origin=x.origin, value=x.value) - print(succ_response.value) - - def test_no_broadcast_message(self): - """ - Test the no_broadcast_message method, - where we expect the no_broadcast flag to be True. 
- """ - succ_response = Response.no_broadcast_message() - self.assertTrue(succ_response.no_broadcast) - - -# -# class ExampleSideInput(SideInput): -# def retrieve_handler(self) -> Response: -# return Response.broadcast_message(b"testMessage") -# -# -# class TestSideInputClass(unittest.TestCase): -# def setUp(self) -> None: -# # Create a side input class instance -# self.side_input_instance = ExampleSideInput() -# -# def test_side_input_class_call(self): -# """Test that the __call__ functionality for the class works, -# ie the class instance can be called directly to invoke the handler function -# """ -# # make a call to the class directly -# ret = self.side_input_instance() -# self.assertEqual(b"testMessage", ret.value) -# # make a call to the handler -# ret_handler = self.side_input_instance.retrieve_handler() -# # Both responses should be equal -# self.assertEqual(ret, ret_handler) + self.assertEqual(succ_response.value, x.value) if __name__ == "__main__": From bc2df8a512afdbc793ef4e37f11780b5336c6a92 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 14:57:56 -0800 Subject: [PATCH 3/8] clean Signed-off-by: Sidhant Kohli --- tests/sideinput/test_serving_store_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/sideinput/test_serving_store_server.py b/tests/sideinput/test_serving_store_server.py index fc43039b..b2cd17b5 100644 --- a/tests/sideinput/test_serving_store_server.py +++ b/tests/sideinput/test_serving_store_server.py @@ -189,8 +189,7 @@ def test_get_message(self): pl = Payload(origin="abc", value=bytes("test_put", encoding="utf-8")) self.InMem.store["abc"] = [pl] response, metadata, code, details = method.termination() - self.assertEqual(pl.value, response.payloads[0].value) - self.assertEqual(pl.origin, response.payloads[0].origin) + self.assertEqual(len(response.payloads), 1) self.assertEqual(code, StatusCode.OK) def test_invalid_input(self): From df0db30d2b21ff6b6815661a6b05b75f847d1c6e Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 14:59:54 -0800 Subject: [PATCH 4/8] clean Signed-off-by: Sidhant Kohli --- tests/sideinput/test_serving_store_server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/sideinput/test_serving_store_server.py b/tests/sideinput/test_serving_store_server.py index b2cd17b5..259846ab 100644 --- a/tests/sideinput/test_serving_store_server.py +++ b/tests/sideinput/test_serving_store_server.py @@ -186,10 +186,8 @@ def test_get_message(self): request=request, timeout=1, ) - pl = Payload(origin="abc", value=bytes("test_put", encoding="utf-8")) - self.InMem.store["abc"] = [pl] response, metadata, code, details = method.termination() - self.assertEqual(len(response.payloads), 1) + self.assertEqual(len(response.payloads), 0) self.assertEqual(code, StatusCode.OK) def test_invalid_input(self): From ff36c92c1d5194449b6d8991fbc75cfc5ae3faf6 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 15:31:58 -0800 Subject: [PATCH 5/8] add async Signed-off-by: Sidhant Kohli --- pynumaflow/servingstore/_dtypes.py | 47 +---- pynumaflow/servingstore/async_server.py | 112 +++++++++++ .../servingstore/servicer/async_servicer.py | 71 +++++++ tests/servingstore/test_async_source.py | 190 ++++++++++++++++++ tests/servingstore/test_responses.py | 53 ----- .../test_serving_store_server.py | 25 +++ tests/servingstore/test_store_responses.py | 22 ++ tests/sideinput/test_responses.py | 47 ++++- .../test_side_input_server.py | 0 9 files changed, 468 insertions(+), 99 
deletions(-) create mode 100644 pynumaflow/servingstore/async_server.py create mode 100644 pynumaflow/servingstore/servicer/async_servicer.py create mode 100644 tests/servingstore/test_async_source.py delete mode 100644 tests/servingstore/test_responses.py rename tests/{sideinput => servingstore}/test_serving_store_server.py (87%) create mode 100644 tests/servingstore/test_store_responses.py rename tests/{servingstore => sideinput}/test_side_input_server.py (100%) diff --git a/pynumaflow/servingstore/_dtypes.py b/pynumaflow/servingstore/_dtypes.py index 78c9193d..2d65722f 100644 --- a/pynumaflow/servingstore/_dtypes.py +++ b/pynumaflow/servingstore/_dtypes.py @@ -1,39 +1,10 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import TypeVar +from typing import TypeVar, Awaitable P = TypeVar("P", bound="Payload") -# class Payload: -# """ -# Class to define each independent result stored in the Store for the given ID. -# Args: -# origin: origin for a given payload -# value: the data associated for the given message -# """ -# __slots__ = ("_origin", "_value") -# -# def __init__( -# self, -# origin: str, -# value: bytes, -# ): -# """ -# Creates a Payload object to send value retrieved from the store. -# """ -# self._origin = origin -# self._value = value -# -# @property -# def value(self) -> bytes: -# return self._value -# -# @property -# def origin(self) -> str: -# return self._origin - - @dataclass class Payload: """ @@ -41,7 +12,7 @@ class Payload: Attributes: origin (str): The origin of a given payload, typically describing where or what the data - comes from, for example, a sensor ID or a message source. + comes from value (bytes): The data associated with the payload, stored as bytes to accommodate various types of binary data or encoded string data. """ @@ -61,7 +32,7 @@ class PutDatum: >>> # Example usage >>> from pynumaflow.servingstore import PutDatum >>> from datetime import datetime, timezone - >>> payload = bytes("test_mock_message", encoding="utf-8") + >>> payload = Payload(_id="avc", value=bytes("test_mock_message", encoding="utf-8")) >>> d = PutDatum( ... id_ = "avc", payloads = [payload] ... ) @@ -73,9 +44,9 @@ class PutDatum: _payloads: list[Payload] def __init__( - self, - id_: str, - payloads: list[Payload], + self, + id_: str, + payloads: list[Payload], ): self._id = id_ self._payloads = payloads or [] @@ -112,8 +83,8 @@ class GetDatum: _id: str def __init__( - self, - id_: str, + self, + id_: str, ): self._id = id_ @@ -169,7 +140,7 @@ def put(self, datum: PutDatum): pass @abstractmethod - def get(self, datum: GetDatum) -> StoredResult: + def get(self, datum: GetDatum) -> [StoredResult, Awaitable[StoredResult]]: """ The simple source always returns zero to indicate there is no pending record. 
""" diff --git a/pynumaflow/servingstore/async_server.py b/pynumaflow/servingstore/async_server.py new file mode 100644 index 00000000..448556a3 --- /dev/null +++ b/pynumaflow/servingstore/async_server.py @@ -0,0 +1,112 @@ +import aiorun +import grpc + +from pynumaflow._constants import ( + MAX_MESSAGE_SIZE, + NUM_THREADS_DEFAULT, + MAX_NUM_THREADS, SERVING_STORE_SOCK_PATH, SERVING_STORE_SERVER_INFO_FILE_PATH, +) +from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION +from pynumaflow.proto.serving import store_pb2_grpc +from pynumaflow.servingstore._dtypes import ServingStoreCallable +from pynumaflow.servingstore.servicer.async_servicer import AsyncServingStoreServicer +from pynumaflow.shared.server import NumaflowServer, start_async_server + + +class ServingStoreAsyncServer(NumaflowServer): + """ + Class for a new Async Serving store Server instance. + """ + + def __init__( + self, + serving_store_instance: ServingStoreCallable, + sock_path=SERVING_STORE_SOCK_PATH, + max_message_size=MAX_MESSAGE_SIZE, + max_threads=NUM_THREADS_DEFAULT, + server_info_file=SERVING_STORE_SERVER_INFO_FILE_PATH, + ): + """ + Class for a new Async Serving Store server. + Args: + serving_store_instance: The serving store instance to be used + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + + Example invocation: + import datetime + from pynumaflow.servingstore import Response, ServingStoreServer, ServingStorer + + class InMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + async def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + cur_payloads.append(Payload(x.origin, x.value)) + self.store[req_id] = cur_payloads + + async def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + resp = [] + if req_id in self.store: + resp = self.store[req_id] + return StoredResult(id_=req_id, payloads=resp) + + if __name__ == "__main__": + grpc_server = ServingStoreServer(InMemoryStore()) + grpc_server.start() + + """ + self.sock_path = f"unix://{sock_path}" + self.max_threads = min(max_threads, MAX_NUM_THREADS) + self.max_message_size = max_message_size + self.server_info_file = server_info_file + + self._server_options = [ + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), + ] + + self.serving_store_instance = serving_store_instance + self.servicer = AsyncServingStoreServicer(serving_store_instance) + + def start(self): + """ + Starter function for the Async server class, need a separate caller + so that all the async coroutines can be started from a single context + """ + aiorun.run(self.aexec(), use_uvloop=True) + + async def aexec(self): + """ + Starts the Async gRPC server on the given UNIX socket with given max threads + """ + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context + # Create a new async server instance and add the servicer to it + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) + store_servicer = self.servicer + 
store_pb2_grpc.add_ServingStoreServicer_to_server(store_servicer, server) + + serv_info = ServerInfo.get_default_server_info() + serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Serving] + # Start the async server + await start_async_server( + server_async=server, + sock_path=self.sock_path, + max_threads=self.max_threads, + cleanup_coroutines=list(), + server_info_file=self.server_info_file, + server_info=serv_info, + ) diff --git a/pynumaflow/servingstore/servicer/async_servicer.py b/pynumaflow/servingstore/servicer/async_servicer.py new file mode 100644 index 00000000..d9c28ed1 --- /dev/null +++ b/pynumaflow/servingstore/servicer/async_servicer.py @@ -0,0 +1,71 @@ +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow._constants import _LOGGER +from pynumaflow.proto.serving import store_pb2_grpc, store_pb2 +from pynumaflow.proto.sourcer import source_pb2 +from pynumaflow.servingstore._dtypes import ServingStoreCallable, PutDatum, Payload, GetDatum +from pynumaflow.shared.server import handle_async_error +from pynumaflow.types import NumaflowServicerContext + + +class AsyncServingStoreServicer(store_pb2_grpc.ServingStoreServicer): + """ + This class is used to create a new grpc Async store servicer instance. + It implements the ServingStoreServicer interface from the proto store.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__(self, serving_store_instance: ServingStoreCallable): + """Initialize handler methods from the provided serving store handler.""" + self.background_tasks = set() + self.__serving_store_instance: ServingStoreCallable = serving_store_instance + self.cleanup_coroutines = [] + + async def Put( + self, request: store_pb2.PutRequest, context: NumaflowServicerContext + ) -> store_pb2.PutResponse: + """ + Handles the Put function, processing incoming requests and sending responses. + """ + # if there is an exception, we will mark all the responses as a failure + try: + input_payloads = [] + for x in request.payloads: + input_payloads.append(Payload(origin=x.origin, value=x.value)) + await self.__serving_store_instance.put( + datum=PutDatum(id_=request.id, payloads=input_payloads) + ) + except BaseException as err: + err_msg = f"Async Serving Store Put: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + await handle_async_error(context, err) + return store_pb2.PutResponse(success=True) + + async def Get( + self, request: store_pb2.GetRequest, context: NumaflowServicerContext + ) -> store_pb2.GetResponse: + """ + Handles the Get function, processing incoming requests and sending responses. + """ + # if there is an exception, we will mark all the responses as a failure + try: + resps = await self.__serving_store_instance.get(datum=GetDatum(id_=request.id)) + resp_payloads = [] + for resp in resps.payloads: + resp_payloads.append(store_pb2.Payload(origin=resp.origin, value=resp.value)) + except BaseException as err: + err_msg = f"Async Serving Store Get: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + await handle_async_error(context, err) + + return store_pb2.GetResponse(id=request.id, payloads=resp_payloads) + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto source_pb2_grpc.py file. 
+ """ + return source_pb2.ReadyResponse(ready=True) + diff --git a/tests/servingstore/test_async_source.py b/tests/servingstore/test_async_source.py new file mode 100644 index 00000000..c40cf1f7 --- /dev/null +++ b/tests/servingstore/test_async_source.py @@ -0,0 +1,190 @@ +# import asyncio +# import logging +# import threading +# import unittest +# +# import grpc +# from google.protobuf import empty_pb2 as _empty_pb2 +# from grpc.aio._server import Server +# +# from pynumaflow import setup_logging +# from pynumaflow.proto.sourcer import source_pb2_grpc, source_pb2 +# from pynumaflow.sourcer import ( +# SourceAsyncServer, +# ) +# from tests.source.utils import ( +# mock_offset, +# read_req_source_fn, +# ack_req_source_fn, +# mock_partitions, +# AsyncSource, +# ) +# +# LOGGER = setup_logging(__name__) +# +# # if set to true, map handler will raise a `ValueError` exception. +# raise_error_from_map = False +# +# server_port = "unix:///tmp/async_source.sock" +# +# _s: Server = None +# _channel = grpc.insecure_channel(server_port) +# _loop = None +# +# +# def startup_callable(loop): +# asyncio.set_event_loop(loop) +# loop.run_forever() +# +# +# def NewAsyncSourcer(): +# class_instance = AsyncSource() +# server = SourceAsyncServer(sourcer_instance=class_instance) +# udfs = server.servicer +# return udfs +# +# +# async def start_server(udfs): +# server = grpc.aio.server() +# source_pb2_grpc.add_SourceServicer_to_server(udfs, server) +# listen_addr = "unix:///tmp/async_source.sock" +# server.add_insecure_port(listen_addr) +# logging.info("Starting server on %s", listen_addr) +# global _s +# _s = server +# await server.start() +# await server.wait_for_termination() +# +# +# class TestAsyncSourcer(unittest.TestCase): +# @classmethod +# def setUpClass(cls) -> None: +# global _loop +# loop = asyncio.new_event_loop() +# _loop = loop +# _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) +# _thread.start() +# udfs = NewAsyncSourcer() +# asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) +# while True: +# try: +# with grpc.insecure_channel(server_port) as channel: +# f = grpc.channel_ready_future(channel) +# f.result(timeout=10) +# if f.done(): +# break +# except grpc.FutureTimeoutError as e: +# LOGGER.error("error trying to connect to grpc server") +# LOGGER.error(e) +# +# @classmethod +# def tearDownClass(cls) -> None: +# try: +# _loop.stop() +# LOGGER.info("stopped the event loop") +# except Exception as e: +# LOGGER.error(e) +# +# def test_read_source(self) -> None: +# with grpc.insecure_channel(server_port) as channel: +# stub = source_pb2_grpc.SourceStub(channel) +# +# request = read_req_source_fn() +# generator_response = None +# try: +# generator_response = stub.ReadFn(request=source_pb2.ReadRequest(request=request)) +# except grpc.RpcError as e: +# logging.error(e) +# +# counter = 0 +# # capture the output from the ReadFn generator and assert. 
+# for r in generator_response: +# counter += 1 +# self.assertEqual( +# bytes("payload:test_mock_message", encoding="utf-8"), +# r.result.payload, +# ) +# self.assertEqual( +# ["test_key"], +# r.result.keys, +# ) +# self.assertEqual( +# mock_offset().offset, +# r.result.offset.offset, +# ) +# self.assertEqual( +# mock_offset().partition_id, +# r.result.offset.partition_id, +# ) +# """Assert that the generator was called 10 times in the stream""" +# self.assertEqual(10, counter) +# +# def test_is_ready(self) -> None: +# with grpc.insecure_channel(server_port) as channel: +# stub = source_pb2_grpc.SourceStub(channel) +# +# request = _empty_pb2.Empty() +# response = None +# try: +# response = stub.IsReady(request=request) +# except grpc.RpcError as e: +# logging.error(e) +# +# self.assertTrue(response.ready) +# +# def test_ack(self) -> None: +# with grpc.insecure_channel(server_port) as channel: +# stub = source_pb2_grpc.SourceStub(channel) +# request = ack_req_source_fn() +# try: +# response = stub.AckFn(request=source_pb2.AckRequest(request=request)) +# except grpc.RpcError as e: +# print(e) +# +# self.assertEqual(response, source_pb2.AckResponse()) +# +# def test_pending(self) -> None: +# with grpc.insecure_channel(server_port) as channel: +# stub = source_pb2_grpc.SourceStub(channel) +# request = _empty_pb2.Empty() +# response = None +# try: +# response = stub.PendingFn(request=request) +# except grpc.RpcError as e: +# logging.error(e) +# +# self.assertEqual(response.result.count, 10) +# +# def test_partitions(self) -> None: +# with grpc.insecure_channel(server_port) as channel: +# stub = source_pb2_grpc.SourceStub(channel) +# request = _empty_pb2.Empty() +# response = None +# try: +# response = stub.PartitionsFn(request=request) +# except grpc.RpcError as e: +# logging.error(e) +# +# self.assertEqual(response.result.partitions, mock_partitions()) +# +# def __stub(self): +# return source_pb2_grpc.SourceStub(_channel) +# +# def test_max_threads(self): +# class_instance = AsyncSource() +# # max cap at 16 +# server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) +# self.assertEqual(server.max_threads, 16) +# +# # use argument provided +# server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) +# self.assertEqual(server.max_threads, 5) +# +# # defaults to 4 +# server = SourceAsyncServer(sourcer_instance=class_instance) +# self.assertEqual(server.max_threads, 4) +# +# +# if __name__ == "__main__": +# logging.basicConfig(level=logging.DEBUG) +# unittest.main() \ No newline at end of file diff --git a/tests/servingstore/test_responses.py b/tests/servingstore/test_responses.py deleted file mode 100644 index 859f4bb1..00000000 --- a/tests/servingstore/test_responses.py +++ /dev/null @@ -1,53 +0,0 @@ -import unittest - -from pynumaflow.sideinput import Response, SideInput - - -class TestResponse(unittest.TestCase): - """ - Test the Response class for SideInput - """ - - def test_broadcast_message(self): - """ - Test the broadcast_message method, - where we expect the no_broadcast flag to be False. - """ - succ_response = Response.broadcast_message(b"2") - self.assertFalse(succ_response.no_broadcast) - self.assertEqual(b"2", succ_response.value) - - def test_no_broadcast_message(self): - """ - Test the no_broadcast_message method, - where we expect the no_broadcast flag to be True. 
- """ - succ_response = Response.no_broadcast_message() - self.assertTrue(succ_response.no_broadcast) - - -class ExampleSideInput(SideInput): - def retrieve_handler(self) -> Response: - return Response.broadcast_message(b"testMessage") - - -class TestSideInputClass(unittest.TestCase): - def setUp(self) -> None: - # Create a side input class instance - self.side_input_instance = ExampleSideInput() - - def test_side_input_class_call(self): - """Test that the __call__ functionality for the class works, - ie the class instance can be called directly to invoke the handler function - """ - # make a call to the class directly - ret = self.side_input_instance() - self.assertEqual(b"testMessage", ret.value) - # make a call to the handler - ret_handler = self.side_input_instance.retrieve_handler() - # Both responses should be equal - self.assertEqual(ret, ret_handler) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/sideinput/test_serving_store_server.py b/tests/servingstore/test_serving_store_server.py similarity index 87% rename from tests/sideinput/test_serving_store_server.py rename to tests/servingstore/test_serving_store_server.py index 259846ab..9ef50781 100644 --- a/tests/sideinput/test_serving_store_server.py +++ b/tests/servingstore/test_serving_store_server.py @@ -190,6 +190,31 @@ def test_get_message(self): self.assertEqual(len(response.payloads), 0) self.assertEqual(code, StatusCode.OK) + def test_serving_store_get_err(self): + """ + Test the error case for the Put method, + """ + server = ServingStoreServer(ErrInMemoryStore()) + my_service = server.servicer + services = {store_pb2.DESCRIPTOR.services_by_name["ServingStore"]: my_service} + self.test_server = server_from_dictionary(services, strict_real_time()) + request = store_pb2.GetRequest( + id="abc", + ) + method = self.test_server.invoke_unary_unary( + method_descriptor=( + store_pb2.DESCRIPTOR.services_by_name["ServingStore"].methods_by_name["Get"] + ), + invocation_metadata={ + ("this_metadata_will_be_skipped", "test_ignore"), + }, + request=request, + timeout=1, + ) + response, metadata, code, details = method.termination() + self.assertEqual(grpc.StatusCode.UNKNOWN, code) + self.assertTrue("get is fishy" in details) + def test_invalid_input(self): with self.assertRaises(TypeError): ServingStoreServer() diff --git a/tests/servingstore/test_store_responses.py b/tests/servingstore/test_store_responses.py new file mode 100644 index 00000000..421d2829 --- /dev/null +++ b/tests/servingstore/test_store_responses.py @@ -0,0 +1,22 @@ +import unittest + +from pynumaflow.proto.serving import store_pb2 +from pynumaflow.servingstore import Payload + + +class TestPayload(unittest.TestCase): + """ + Test the Response class for SideInput + """ + + def test_create_payload(self): + """ + Test the new payload method, + """ + x = store_pb2.Payload(origin="abc1", value=bytes("test_put", encoding="utf-8")) + succ_response = Payload(origin=x.origin, value=x.value) + self.assertEqual(succ_response.value, x.value) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/sideinput/test_responses.py b/tests/sideinput/test_responses.py index 421d2829..859f4bb1 100644 --- a/tests/sideinput/test_responses.py +++ b/tests/sideinput/test_responses.py @@ -1,21 +1,52 @@ import unittest -from pynumaflow.proto.serving import store_pb2 -from pynumaflow.servingstore import Payload +from pynumaflow.sideinput import Response, SideInput -class TestPayload(unittest.TestCase): +class TestResponse(unittest.TestCase): """ Test the Response 
class for SideInput """ - def test_create_payload(self): + def test_broadcast_message(self): """ - Test the new payload method, + Test the broadcast_message method, + where we expect the no_broadcast flag to be False. """ - x = store_pb2.Payload(origin="abc1", value=bytes("test_put", encoding="utf-8")) - succ_response = Payload(origin=x.origin, value=x.value) - self.assertEqual(succ_response.value, x.value) + succ_response = Response.broadcast_message(b"2") + self.assertFalse(succ_response.no_broadcast) + self.assertEqual(b"2", succ_response.value) + + def test_no_broadcast_message(self): + """ + Test the no_broadcast_message method, + where we expect the no_broadcast flag to be True. + """ + succ_response = Response.no_broadcast_message() + self.assertTrue(succ_response.no_broadcast) + + +class ExampleSideInput(SideInput): + def retrieve_handler(self) -> Response: + return Response.broadcast_message(b"testMessage") + + +class TestSideInputClass(unittest.TestCase): + def setUp(self) -> None: + # Create a side input class instance + self.side_input_instance = ExampleSideInput() + + def test_side_input_class_call(self): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + # make a call to the class directly + ret = self.side_input_instance() + self.assertEqual(b"testMessage", ret.value) + # make a call to the handler + ret_handler = self.side_input_instance.retrieve_handler() + # Both responses should be equal + self.assertEqual(ret, ret_handler) if __name__ == "__main__": diff --git a/tests/servingstore/test_side_input_server.py b/tests/sideinput/test_side_input_server.py similarity index 100% rename from tests/servingstore/test_side_input_server.py rename to tests/sideinput/test_side_input_server.py From 8457559ebb1f041657288341ddfaafaaa9b3993f Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 16:16:13 -0800 Subject: [PATCH 6/8] add async Signed-off-by: Sidhant Kohli --- pynumaflow/servingstore/__init__.py | 11 +- pynumaflow/servingstore/_dtypes.py | 13 +- pynumaflow/servingstore/async_server.py | 4 +- .../servingstore/servicer/async_servicer.py | 12 +- .../servingstore/test_async_serving_store.py | 235 ++++++++++++++++++ tests/servingstore/test_async_source.py | 190 -------------- .../servingstore/test_serving_store_server.py | 21 +- 7 files changed, 280 insertions(+), 206 deletions(-) create mode 100644 tests/servingstore/test_async_serving_store.py delete mode 100644 tests/servingstore/test_async_source.py diff --git a/pynumaflow/servingstore/__init__.py b/pynumaflow/servingstore/__init__.py index 4c4dfae9..91615079 100644 --- a/pynumaflow/servingstore/__init__.py +++ b/pynumaflow/servingstore/__init__.py @@ -1,4 +1,13 @@ from pynumaflow.servingstore._dtypes import PutDatum, GetDatum, ServingStorer, StoredResult, Payload +from pynumaflow.servingstore.async_server import ServingStoreAsyncServer from pynumaflow.servingstore.server import ServingStoreServer -__all__ = ["PutDatum", "GetDatum", "ServingStorer", "ServingStoreServer", "StoredResult", "Payload"] +__all__ = [ + "PutDatum", + "GetDatum", + "ServingStorer", + "ServingStoreServer", + "StoredResult", + "Payload", + "ServingStoreAsyncServer", +] diff --git a/pynumaflow/servingstore/_dtypes.py b/pynumaflow/servingstore/_dtypes.py index 2d65722f..393f8768 100644 --- a/pynumaflow/servingstore/_dtypes.py +++ b/pynumaflow/servingstore/_dtypes.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod from dataclasses import 
dataclass -from typing import TypeVar, Awaitable +from typing import TypeVar +from collections.abc import Awaitable P = TypeVar("P", bound="Payload") @@ -44,9 +45,9 @@ class PutDatum: _payloads: list[Payload] def __init__( - self, - id_: str, - payloads: list[Payload], + self, + id_: str, + payloads: list[Payload], ): self._id = id_ self._payloads = payloads or [] @@ -83,8 +84,8 @@ class GetDatum: _id: str def __init__( - self, - id_: str, + self, + id_: str, ): self._id = id_ diff --git a/pynumaflow/servingstore/async_server.py b/pynumaflow/servingstore/async_server.py index 448556a3..6e4021cc 100644 --- a/pynumaflow/servingstore/async_server.py +++ b/pynumaflow/servingstore/async_server.py @@ -4,7 +4,9 @@ from pynumaflow._constants import ( MAX_MESSAGE_SIZE, NUM_THREADS_DEFAULT, - MAX_NUM_THREADS, SERVING_STORE_SOCK_PATH, SERVING_STORE_SERVER_INFO_FILE_PATH, + MAX_NUM_THREADS, + SERVING_STORE_SOCK_PATH, + SERVING_STORE_SERVER_INFO_FILE_PATH, ) from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.proto.serving import store_pb2_grpc diff --git a/pynumaflow/servingstore/servicer/async_servicer.py b/pynumaflow/servingstore/servicer/async_servicer.py index d9c28ed1..7a017c7d 100644 --- a/pynumaflow/servingstore/servicer/async_servicer.py +++ b/pynumaflow/servingstore/servicer/async_servicer.py @@ -2,7 +2,6 @@ from pynumaflow._constants import _LOGGER from pynumaflow.proto.serving import store_pb2_grpc, store_pb2 -from pynumaflow.proto.sourcer import source_pb2 from pynumaflow.servingstore._dtypes import ServingStoreCallable, PutDatum, Payload, GetDatum from pynumaflow.shared.server import handle_async_error from pynumaflow.types import NumaflowServicerContext @@ -22,7 +21,7 @@ def __init__(self, serving_store_instance: ServingStoreCallable): self.cleanup_coroutines = [] async def Put( - self, request: store_pb2.PutRequest, context: NumaflowServicerContext + self, request: store_pb2.PutRequest, context: NumaflowServicerContext ) -> store_pb2.PutResponse: """ Handles the Put function, processing incoming requests and sending responses. @@ -42,7 +41,7 @@ async def Put( return store_pb2.PutResponse(success=True) async def Get( - self, request: store_pb2.GetRequest, context: NumaflowServicerContext + self, request: store_pb2.GetRequest, context: NumaflowServicerContext ) -> store_pb2.GetResponse: """ Handles the Get function, processing incoming requests and sending responses. @@ -61,11 +60,10 @@ async def Get( return store_pb2.GetResponse(id=request.id, payloads=resp_payloads) async def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.ReadyResponse: + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> store_pb2.ReadyResponse: """ IsReady is the heartbeat endpoint for gRPC. The pascal case function name comes from the proto source_pb2_grpc.py file. 
""" - return source_pb2.ReadyResponse(ready=True) - + return store_pb2.ReadyResponse(ready=True) diff --git a/tests/servingstore/test_async_serving_store.py b/tests/servingstore/test_async_serving_store.py new file mode 100644 index 00000000..a31b8eae --- /dev/null +++ b/tests/servingstore/test_async_serving_store.py @@ -0,0 +1,235 @@ +import asyncio +import logging +import threading +import unittest + +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow.proto.serving import store_pb2_grpc +from pynumaflow.servingstore import ( + ServingStoreAsyncServer, + ServingStorer, + PutDatum, + Payload, + GetDatum, + StoredResult, +) + + +class AsyncInMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + async def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + print(x) + cur_payloads.append(Payload(x.origin, x.value)) + self.store[req_id] = cur_payloads + + async def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + resp = [] + if req_id in self.store: + resp = self.store[req_id] + return StoredResult(id_=req_id, payloads=resp) + + +class AsyncErrInMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + async def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + cur_payloads.append(Payload(x.origin, x.value)) + raise ValueError("something fishy") + self.store[req_id] = cur_payloads + + async def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + raise ValueError("get is fishy") + + +LOGGER = setup_logging(__name__) + +# if set to true, map handler will raise a `ValueError` exception. 
+raise_error_from_map = False + +server_port = "unix:///tmp/async_serving_store.sock" + +_s: Server = None +_channel = grpc.insecure_channel(server_port) +_loop = None + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +def NewAsyncStore(): + server = ServingStoreAsyncServer(serving_store_instance=AsyncInMemoryStore()) + udfs = server.servicer + return udfs + + +async def start_server(udfs): + server = grpc.aio.server() + store_pb2_grpc.add_ServingStoreServicer_to_server(udfs, server) + listen_addr = "unix:///tmp/async_serving_store.sock" + server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +class TestAsyncServingStore(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + udfs = NewAsyncStore() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + while True: + try: + with grpc.insecure_channel(server_port) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + # def test_read_source(self) -> None: + # with grpc.insecure_channel(server_port) as channel: + # stub = store_pb2_grpc.ServingStoreStub(channel) + # + # request = read_req_source_fn() + # generator_response = None + # try: + # generator_response = stub.Put(request=source_pb2.ReadRequest(request=request)) + # except grpc.RpcError as e: + # logging.error(e) + # + # counter = 0 + # # capture the output from the ReadFn generator and assert. 
+ # for r in generator_response: + # counter += 1 + # self.assertEqual( + # bytes("payload:test_mock_message", encoding="utf-8"), + # r.result.payload, + # ) + # self.assertEqual( + # ["test_key"], + # r.result.keys, + # ) + # self.assertEqual( + # mock_offset().offset, + # r.result.offset.offset, + # ) + # self.assertEqual( + # mock_offset().partition_id, + # r.result.offset.partition_id, + # ) + # """Assert that the generator was called 10 times in the stream""" + # self.assertEqual(10, counter) + + def test_is_ready(self) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = store_pb2_grpc.ServingStoreStub(channel) + + request = _empty_pb2.Empty() + response = None + try: + response = stub.IsReady(request=request) + except grpc.RpcError as e: + logging.error(e) + + self.assertTrue(response.ready) + + # def test_ack(self) -> None: + # with grpc.insecure_channel(server_port) as channel: + # stub = source_pb2_grpc.SourceStub(channel) + # request = ack_req_source_fn() + # try: + # response = stub.AckFn(request=source_pb2.AckRequest(request=request)) + # except grpc.RpcError as e: + # print(e) + # + # self.assertEqual(response, source_pb2.AckResponse()) + # + # def test_pending(self) -> None: + # with grpc.insecure_channel(server_port) as channel: + # stub = source_pb2_grpc.SourceStub(channel) + # request = _empty_pb2.Empty() + # response = None + # try: + # response = stub.PendingFn(request=request) + # except grpc.RpcError as e: + # logging.error(e) + # + # self.assertEqual(response.result.count, 10) + # + # def test_partitions(self) -> None: + # with grpc.insecure_channel(server_port) as channel: + # stub = source_pb2_grpc.SourceStub(channel) + # request = _empty_pb2.Empty() + # response = None + # try: + # response = stub.PartitionsFn(request=request) + # except grpc.RpcError as e: + # logging.error(e) + # + # self.assertEqual(response.result.partitions, mock_partitions()) + # + # def __stub(self): + # return source_pb2_grpc.SourceStub(_channel) + # + # def test_max_threads(self): + # class_instance = AsyncSource() + # # max cap at 16 + # server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) + # self.assertEqual(server.max_threads, 16) + # + # # use argument provided + # server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) + # self.assertEqual(server.max_threads, 5) + # + # # defaults to 4 + # server = SourceAsyncServer(sourcer_instance=class_instance) + # self.assertEqual(server.max_threads, 4) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/servingstore/test_async_source.py b/tests/servingstore/test_async_source.py deleted file mode 100644 index c40cf1f7..00000000 --- a/tests/servingstore/test_async_source.py +++ /dev/null @@ -1,190 +0,0 @@ -# import asyncio -# import logging -# import threading -# import unittest -# -# import grpc -# from google.protobuf import empty_pb2 as _empty_pb2 -# from grpc.aio._server import Server -# -# from pynumaflow import setup_logging -# from pynumaflow.proto.sourcer import source_pb2_grpc, source_pb2 -# from pynumaflow.sourcer import ( -# SourceAsyncServer, -# ) -# from tests.source.utils import ( -# mock_offset, -# read_req_source_fn, -# ack_req_source_fn, -# mock_partitions, -# AsyncSource, -# ) -# -# LOGGER = setup_logging(__name__) -# -# # if set to true, map handler will raise a `ValueError` exception. 
-# raise_error_from_map = False -# -# server_port = "unix:///tmp/async_source.sock" -# -# _s: Server = None -# _channel = grpc.insecure_channel(server_port) -# _loop = None -# -# -# def startup_callable(loop): -# asyncio.set_event_loop(loop) -# loop.run_forever() -# -# -# def NewAsyncSourcer(): -# class_instance = AsyncSource() -# server = SourceAsyncServer(sourcer_instance=class_instance) -# udfs = server.servicer -# return udfs -# -# -# async def start_server(udfs): -# server = grpc.aio.server() -# source_pb2_grpc.add_SourceServicer_to_server(udfs, server) -# listen_addr = "unix:///tmp/async_source.sock" -# server.add_insecure_port(listen_addr) -# logging.info("Starting server on %s", listen_addr) -# global _s -# _s = server -# await server.start() -# await server.wait_for_termination() -# -# -# class TestAsyncSourcer(unittest.TestCase): -# @classmethod -# def setUpClass(cls) -> None: -# global _loop -# loop = asyncio.new_event_loop() -# _loop = loop -# _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) -# _thread.start() -# udfs = NewAsyncSourcer() -# asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) -# while True: -# try: -# with grpc.insecure_channel(server_port) as channel: -# f = grpc.channel_ready_future(channel) -# f.result(timeout=10) -# if f.done(): -# break -# except grpc.FutureTimeoutError as e: -# LOGGER.error("error trying to connect to grpc server") -# LOGGER.error(e) -# -# @classmethod -# def tearDownClass(cls) -> None: -# try: -# _loop.stop() -# LOGGER.info("stopped the event loop") -# except Exception as e: -# LOGGER.error(e) -# -# def test_read_source(self) -> None: -# with grpc.insecure_channel(server_port) as channel: -# stub = source_pb2_grpc.SourceStub(channel) -# -# request = read_req_source_fn() -# generator_response = None -# try: -# generator_response = stub.ReadFn(request=source_pb2.ReadRequest(request=request)) -# except grpc.RpcError as e: -# logging.error(e) -# -# counter = 0 -# # capture the output from the ReadFn generator and assert. 
-# for r in generator_response: -# counter += 1 -# self.assertEqual( -# bytes("payload:test_mock_message", encoding="utf-8"), -# r.result.payload, -# ) -# self.assertEqual( -# ["test_key"], -# r.result.keys, -# ) -# self.assertEqual( -# mock_offset().offset, -# r.result.offset.offset, -# ) -# self.assertEqual( -# mock_offset().partition_id, -# r.result.offset.partition_id, -# ) -# """Assert that the generator was called 10 times in the stream""" -# self.assertEqual(10, counter) -# -# def test_is_ready(self) -> None: -# with grpc.insecure_channel(server_port) as channel: -# stub = source_pb2_grpc.SourceStub(channel) -# -# request = _empty_pb2.Empty() -# response = None -# try: -# response = stub.IsReady(request=request) -# except grpc.RpcError as e: -# logging.error(e) -# -# self.assertTrue(response.ready) -# -# def test_ack(self) -> None: -# with grpc.insecure_channel(server_port) as channel: -# stub = source_pb2_grpc.SourceStub(channel) -# request = ack_req_source_fn() -# try: -# response = stub.AckFn(request=source_pb2.AckRequest(request=request)) -# except grpc.RpcError as e: -# print(e) -# -# self.assertEqual(response, source_pb2.AckResponse()) -# -# def test_pending(self) -> None: -# with grpc.insecure_channel(server_port) as channel: -# stub = source_pb2_grpc.SourceStub(channel) -# request = _empty_pb2.Empty() -# response = None -# try: -# response = stub.PendingFn(request=request) -# except grpc.RpcError as e: -# logging.error(e) -# -# self.assertEqual(response.result.count, 10) -# -# def test_partitions(self) -> None: -# with grpc.insecure_channel(server_port) as channel: -# stub = source_pb2_grpc.SourceStub(channel) -# request = _empty_pb2.Empty() -# response = None -# try: -# response = stub.PartitionsFn(request=request) -# except grpc.RpcError as e: -# logging.error(e) -# -# self.assertEqual(response.result.partitions, mock_partitions()) -# -# def __stub(self): -# return source_pb2_grpc.SourceStub(_channel) -# -# def test_max_threads(self): -# class_instance = AsyncSource() -# # max cap at 16 -# server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) -# self.assertEqual(server.max_threads, 16) -# -# # use argument provided -# server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) -# self.assertEqual(server.max_threads, 5) -# -# # defaults to 4 -# server = SourceAsyncServer(sourcer_instance=class_instance) -# self.assertEqual(server.max_threads, 4) -# -# -# if __name__ == "__main__": -# logging.basicConfig(level=logging.DEBUG) -# unittest.main() \ No newline at end of file diff --git a/tests/servingstore/test_serving_store_server.py b/tests/servingstore/test_serving_store_server.py index 9ef50781..02d0f831 100644 --- a/tests/servingstore/test_serving_store_server.py +++ b/tests/servingstore/test_serving_store_server.py @@ -172,6 +172,23 @@ def test_get_message(self): where we expect the no_broadcast flag to be False and the message value to be the mock_message. 
""" + val = bytes("test_put", encoding="utf-8") + request_put = store_pb2.PutRequest( + id="abc", + payloads=[store_pb2.Payload(origin="abc1", value=val)], + ) + method = self.test_server.invoke_unary_unary( + method_descriptor=( + store_pb2.DESCRIPTOR.services_by_name["ServingStore"].methods_by_name["Put"] + ), + invocation_metadata={ + ("this_metadata_will_be_skipped", "test_ignore"), + }, + request=request_put, + timeout=1, + ) + response, metadata, code, details = method.termination() + request = store_pb2.GetRequest( id="abc", ) @@ -187,7 +204,9 @@ def test_get_message(self): timeout=1, ) response, metadata, code, details = method.termination() - self.assertEqual(len(response.payloads), 0) + self.assertEqual(len(response.payloads), 1) + self.assertEqual(response.payloads[0].value, val) + self.assertEqual(response.payloads[0].origin, "abc1") self.assertEqual(code, StatusCode.OK) def test_serving_store_get_err(self): From 98995b70319116fde2630fe80e0e8940f88e5faf Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 16:38:44 -0800 Subject: [PATCH 7/8] add async Signed-off-by: Sidhant Kohli --- tests/servingstore/test_async_serving_err.py | 142 ++++++++++++++++ .../servingstore/test_async_serving_store.py | 152 +++++------------- 2 files changed, 186 insertions(+), 108 deletions(-) create mode 100644 tests/servingstore/test_async_serving_err.py diff --git a/tests/servingstore/test_async_serving_err.py b/tests/servingstore/test_async_serving_err.py new file mode 100644 index 00000000..c3bc6071 --- /dev/null +++ b/tests/servingstore/test_async_serving_err.py @@ -0,0 +1,142 @@ +import asyncio +import logging +import threading +import unittest +from unittest.mock import patch + +import grpc +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow.proto.serving import store_pb2_grpc, store_pb2 +from pynumaflow.servingstore import ( + ServingStorer, + PutDatum, + Payload, + GetDatum, + StoredResult, + ServingStoreAsyncServer, +) +from tests.testing_utils import mock_terminate_on_stop + +LOGGER = setup_logging(__name__) + +_s: Server = None +server_port = "unix:///tmp/async_serving_store_err.sock" +_channel = grpc.insecure_channel(server_port) +_loop = None + + +class AsyncErrInMemoryStore(ServingStorer): + def __init__(self): + self.store = {} + + async def put(self, datum: PutDatum): + req_id = datum.id + print("Received Put request for ", req_id) + if req_id not in self.store: + self.store[req_id] = [] + + cur_payloads = self.store[req_id] + for x in datum.payloads: + cur_payloads.append(Payload(x.origin, x.value)) + raise ValueError("something fishy") + + async def get(self, datum: GetDatum) -> StoredResult: + req_id = datum.id + print("Received Get request for ", req_id) + raise ValueError("get is fishy") + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +async def start_server(): + server = grpc.aio.server() + class_instance = AsyncErrInMemoryStore() + server_instance = ServingStoreAsyncServer(serving_store_instance=class_instance) + udfs = server_instance.servicer + store_pb2_grpc.add_ServingStoreServicer_to_server(udfs, server) + listen_addr = "unix:///tmp/async_serving_store_err.sock" + server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +# We are mocking the terminate function from the psutil to not exit the program during testing +@patch("psutil.Process.kill", 
mock_terminate_on_stop) +class TestAsyncServingStoreErrorScenario(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + asyncio.run_coroutine_threadsafe(start_server(), loop=loop) + while True: + try: + with grpc.insecure_channel("unix:///tmp/async_serving_store_err.sock") as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + def test_put_error(self) -> None: + grpc_exception = None + with grpc.insecure_channel(server_port) as channel: + stub = store_pb2_grpc.ServingStoreStub(channel) + val = bytes("test_put", encoding="utf-8") + request = store_pb2.PutRequest( + id="abc", + payloads=[store_pb2.Payload(origin="abc1", value=val)], + ) + try: + _ = stub.Put(request=request) + except BaseException as e: + self.assertTrue("something fishy" in e.details()) + self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) + grpc_exception = e + + self.assertIsNotNone(grpc_exception) + + def test_get_error(self) -> None: + grpc_exception = None + with grpc.insecure_channel(server_port) as channel: + stub = store_pb2_grpc.ServingStoreStub(channel) + request = store_pb2.GetRequest( + id="abc", + ) + try: + _ = stub.Get(request=request) + except BaseException as e: + self.assertTrue("get is fishy" in e.details()) + self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) + grpc_exception = e + + self.assertIsNotNone(grpc_exception) + + def test_invalid_server_type(self) -> None: + with self.assertRaises(TypeError): + ServingStoreAsyncServer() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/servingstore/test_async_serving_store.py b/tests/servingstore/test_async_serving_store.py index a31b8eae..e185b0d8 100644 --- a/tests/servingstore/test_async_serving_store.py +++ b/tests/servingstore/test_async_serving_store.py @@ -8,7 +8,7 @@ from grpc.aio._server import Server from pynumaflow import setup_logging -from pynumaflow.proto.serving import store_pb2_grpc +from pynumaflow.proto.serving import store_pb2_grpc, store_pb2 from pynumaflow.servingstore import ( ServingStoreAsyncServer, ServingStorer, @@ -44,28 +44,6 @@ async def get(self, datum: GetDatum) -> StoredResult: return StoredResult(id_=req_id, payloads=resp) -class AsyncErrInMemoryStore(ServingStorer): - def __init__(self): - self.store = {} - - async def put(self, datum: PutDatum): - req_id = datum.id - print("Received Put request for ", req_id) - if req_id not in self.store: - self.store[req_id] = [] - - cur_payloads = self.store[req_id] - for x in datum.payloads: - cur_payloads.append(Payload(x.origin, x.value)) - raise ValueError("something fishy") - self.store[req_id] = cur_payloads - - async def get(self, datum: GetDatum) -> StoredResult: - req_id = datum.id - print("Received Get request for ", req_id) - raise ValueError("get is fishy") - - LOGGER = setup_logging(__name__) # if set to true, map handler will raise a `ValueError` exception. 
@@ -130,40 +108,6 @@ def tearDownClass(cls) -> None: except Exception as e: LOGGER.error(e) - # def test_read_source(self) -> None: - # with grpc.insecure_channel(server_port) as channel: - # stub = store_pb2_grpc.ServingStoreStub(channel) - # - # request = read_req_source_fn() - # generator_response = None - # try: - # generator_response = stub.Put(request=source_pb2.ReadRequest(request=request)) - # except grpc.RpcError as e: - # logging.error(e) - # - # counter = 0 - # # capture the output from the ReadFn generator and assert. - # for r in generator_response: - # counter += 1 - # self.assertEqual( - # bytes("payload:test_mock_message", encoding="utf-8"), - # r.result.payload, - # ) - # self.assertEqual( - # ["test_key"], - # r.result.keys, - # ) - # self.assertEqual( - # mock_offset().offset, - # r.result.offset.offset, - # ) - # self.assertEqual( - # mock_offset().partition_id, - # r.result.offset.partition_id, - # ) - # """Assert that the generator was called 10 times in the stream""" - # self.assertEqual(10, counter) - def test_is_ready(self) -> None: with grpc.insecure_channel(server_port) as channel: stub = store_pb2_grpc.ServingStoreStub(channel) @@ -177,57 +121,49 @@ def test_is_ready(self) -> None: self.assertTrue(response.ready) - # def test_ack(self) -> None: - # with grpc.insecure_channel(server_port) as channel: - # stub = source_pb2_grpc.SourceStub(channel) - # request = ack_req_source_fn() - # try: - # response = stub.AckFn(request=source_pb2.AckRequest(request=request)) - # except grpc.RpcError as e: - # print(e) - # - # self.assertEqual(response, source_pb2.AckResponse()) - # - # def test_pending(self) -> None: - # with grpc.insecure_channel(server_port) as channel: - # stub = source_pb2_grpc.SourceStub(channel) - # request = _empty_pb2.Empty() - # response = None - # try: - # response = stub.PendingFn(request=request) - # except grpc.RpcError as e: - # logging.error(e) - # - # self.assertEqual(response.result.count, 10) - # - # def test_partitions(self) -> None: - # with grpc.insecure_channel(server_port) as channel: - # stub = source_pb2_grpc.SourceStub(channel) - # request = _empty_pb2.Empty() - # response = None - # try: - # response = stub.PartitionsFn(request=request) - # except grpc.RpcError as e: - # logging.error(e) - # - # self.assertEqual(response.result.partitions, mock_partitions()) - # - # def __stub(self): - # return source_pb2_grpc.SourceStub(_channel) - # - # def test_max_threads(self): - # class_instance = AsyncSource() - # # max cap at 16 - # server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) - # self.assertEqual(server.max_threads, 16) - # - # # use argument provided - # server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) - # self.assertEqual(server.max_threads, 5) - # - # # defaults to 4 - # server = SourceAsyncServer(sourcer_instance=class_instance) - # self.assertEqual(server.max_threads, 4) + def test_put_get(self) -> None: + val = bytes("test_get", encoding="utf-8") + with grpc.insecure_channel(server_port) as channel: + stub = store_pb2_grpc.ServingStoreStub(channel) + response = None + request = store_pb2.PutRequest( + id="abc", + payloads=[store_pb2.Payload(origin="abc1", value=val)], + ) + try: + response = stub.Put(request=request) + except grpc.RpcError as e: + logging.error(e) + + self.assertEqual(True, response.success) + + stub = store_pb2_grpc.ServingStoreStub(channel) + response_get = None + request = store_pb2.GetRequest( + id="abc", + ) + try: + response_get = stub.Get(request=request) 
+ except grpc.RpcError as e: + logging.error(e) + + self.assertEqual(len(response_get.payloads), 1) + self.assertEqual(response_get.payloads[0].value, val) + self.assertEqual(response_get.payloads[0].origin, "abc1") + + def test_max_threads(self): + class_instance = AsyncInMemoryStore() + # max cap at 16 + server = ServingStoreAsyncServer(serving_store_instance=class_instance, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = ServingStoreAsyncServer(serving_store_instance=class_instance, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = ServingStoreAsyncServer(serving_store_instance=class_instance) + self.assertEqual(server.max_threads, 4) if __name__ == "__main__": From d2023b4548c75c569a16d7e67e43a13b13144e9f Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 23 Feb 2025 17:13:35 -0800 Subject: [PATCH 8/8] add sink response serving Signed-off-by: Sidhant Kohli --- pynumaflow/proto/sinker/sink.proto | 1 + pynumaflow/proto/sinker/sink_pb2.py | 14 +++++++------- pynumaflow/proto/sinker/sink_pb2.pyi | 5 ++++- pynumaflow/sinker/_dtypes.py | 15 +++++++++++---- pynumaflow/sinker/servicer/utils.py | 4 +++- tests/sink/test_responses.py | 16 +++++++++++----- 6 files changed, 37 insertions(+), 18 deletions(-) diff --git a/pynumaflow/proto/sinker/sink.proto b/pynumaflow/proto/sinker/sink.proto index 71dbb418..555db582 100644 --- a/pynumaflow/proto/sinker/sink.proto +++ b/pynumaflow/proto/sinker/sink.proto @@ -77,6 +77,7 @@ message SinkResponse { Status status = 2; // err_msg is the error message, set it if success is set to false. string err_msg = 3; + optional bytes serve_response = 4; } repeated Result results = 1; optional Handshake handshake = 2; diff --git a/pynumaflow/proto/sinker/sink_pb2.py b/pynumaflow/proto/sinker/sink_pb2.py index 27082a0e..3b9a2997 100644 --- a/pynumaflow/proto/sinker/sink_pb2.py +++ b/pynumaflow/proto/sinker/sink_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\nsink.proto\x12\x07sink.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xa3\x03\n\x0bSinkRequest\x12-\n\x07request\x18\x01 \x01(\x0b\x32\x1c.sink.v1.SinkRequest.Request\x12+\n\x06status\x18\x02 \x01(\x0b\x32\x1b.sink.v1.TransmissionStatus\x12*\n\thandshake\x18\x03 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x1a\xfd\x01\n\x07Request\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12:\n\x07headers\x18\x06 \x03(\x0b\x32).sink.v1.SinkRequest.Request.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\n_handshake"\x18\n\tHandshake\x12\x0b\n\x03sot\x18\x01 \x01(\x08"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"!\n\x12TransmissionStatus\x12\x0b\n\x03\x65ot\x18\x01 \x01(\x08"\xfc\x01\n\x0cSinkResponse\x12-\n\x07results\x18\x01 \x03(\x0b\x32\x1c.sink.v1.SinkResponse.Result\x12*\n\thandshake\x18\x02 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x12\x30\n\x06status\x18\x03 \x01(\x0b\x32\x1b.sink.v1.TransmissionStatusH\x01\x88\x01\x01\x1a\x46\n\x06Result\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1f\n\x06status\x18\x02 \x01(\x0e\x32\x0f.sink.v1.Status\x12\x0f\n\x07\x65rr_msg\x18\x03 
\x01(\tB\x0c\n\n_handshakeB\t\n\x07_status*0\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\x12\x0c\n\x08\x46\x41LLBACK\x10\x02\x32|\n\x04Sink\x12\x39\n\x06SinkFn\x12\x14.sink.v1.SinkRequest\x1a\x15.sink.v1.SinkResponse(\x01\x30\x01\x12\x39\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x16.sink.v1.ReadyResponseb\x06proto3' + b'\n\nsink.proto\x12\x07sink.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xa3\x03\n\x0bSinkRequest\x12-\n\x07request\x18\x01 \x01(\x0b\x32\x1c.sink.v1.SinkRequest.Request\x12+\n\x06status\x18\x02 \x01(\x0b\x32\x1b.sink.v1.TransmissionStatus\x12*\n\thandshake\x18\x03 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x1a\xfd\x01\n\x07Request\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12:\n\x07headers\x18\x06 \x03(\x0b\x32).sink.v1.SinkRequest.Request.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\n_handshake"\x18\n\tHandshake\x12\x0b\n\x03sot\x18\x01 \x01(\x08"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"!\n\x12TransmissionStatus\x12\x0b\n\x03\x65ot\x18\x01 \x01(\x08"\xac\x02\n\x0cSinkResponse\x12-\n\x07results\x18\x01 \x03(\x0b\x32\x1c.sink.v1.SinkResponse.Result\x12*\n\thandshake\x18\x02 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x12\x30\n\x06status\x18\x03 \x01(\x0b\x32\x1b.sink.v1.TransmissionStatusH\x01\x88\x01\x01\x1av\n\x06Result\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1f\n\x06status\x18\x02 \x01(\x0e\x32\x0f.sink.v1.Status\x12\x0f\n\x07\x65rr_msg\x18\x03 \x01(\t\x12\x1b\n\x0eserve_response\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_serve_responseB\x0c\n\n_handshakeB\t\n\x07_status*0\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\x12\x0c\n\x08\x46\x41LLBACK\x10\x02\x32|\n\x04Sink\x12\x39\n\x06SinkFn\x12\x14.sink.v1.SinkRequest\x1a\x15.sink.v1.SinkResponse(\x01\x30\x01\x12\x39\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x16.sink.v1.ReadyResponseb\x06proto3' ) _globals = globals() @@ -28,8 +28,8 @@ DESCRIPTOR._options = None _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._options = None _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_options = b"8\001" - _globals["_STATUS"]._serialized_start = 855 - _globals["_STATUS"]._serialized_end = 903 + _globals["_STATUS"]._serialized_start = 903 + _globals["_STATUS"]._serialized_end = 951 _globals["_SINKREQUEST"]._serialized_start = 86 _globals["_SINKREQUEST"]._serialized_end = 505 _globals["_SINKREQUEST_REQUEST"]._serialized_start = 238 @@ -43,9 +43,9 @@ _globals["_TRANSMISSIONSTATUS"]._serialized_start = 565 _globals["_TRANSMISSIONSTATUS"]._serialized_end = 598 _globals["_SINKRESPONSE"]._serialized_start = 601 - _globals["_SINKRESPONSE"]._serialized_end = 853 + _globals["_SINKRESPONSE"]._serialized_end = 901 _globals["_SINKRESPONSE_RESULT"]._serialized_start = 758 - _globals["_SINKRESPONSE_RESULT"]._serialized_end = 828 - _globals["_SINK"]._serialized_start = 905 - _globals["_SINK"]._serialized_end = 1029 + _globals["_SINKRESPONSE_RESULT"]._serialized_end = 876 + _globals["_SINK"]._serialized_start = 953 + _globals["_SINK"]._serialized_end = 1077 # @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/sinker/sink_pb2.pyi b/pynumaflow/proto/sinker/sink_pb2.pyi index 78926321..8d8ca6f6 100644 --- 
a/pynumaflow/proto/sinker/sink_pb2.pyi +++ b/pynumaflow/proto/sinker/sink_pb2.pyi @@ -93,18 +93,21 @@ class SinkResponse(_message.Message): __slots__ = ("results", "handshake", "status") class Result(_message.Message): - __slots__ = ("id", "status", "err_msg") + __slots__ = ("id", "status", "err_msg", "serve_response") ID_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] ERR_MSG_FIELD_NUMBER: _ClassVar[int] + SERVE_RESPONSE_FIELD_NUMBER: _ClassVar[int] id: str status: Status err_msg: str + serve_response: bytes def __init__( self, id: _Optional[str] = ..., status: _Optional[_Union[Status, str]] = ..., err_msg: _Optional[str] = ..., + serve_response: _Optional[bytes] = ..., ) -> None: ... RESULTS_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] diff --git a/pynumaflow/sinker/_dtypes.py b/pynumaflow/sinker/_dtypes.py index c90f1f2e..971b7006 100644 --- a/pynumaflow/sinker/_dtypes.py +++ b/pynumaflow/sinker/_dtypes.py @@ -26,26 +26,33 @@ class Response: success: bool err: Optional[str] fallback: bool + serve_response: Optional[bytes] - __slots__ = ("id", "success", "err", "fallback") + __slots__ = ("id", "success", "err", "fallback", "serve_response") # as_success creates a successful Response with the given id. # The Success field is set to true. @classmethod def as_success(cls: type[R], id_: str) -> R: - return Response(id=id_, success=True, err=None, fallback=False) + return Response(id=id_, success=True, err=None, fallback=False, serve_response=None) # as_failure creates a failed Response with the given id and error message. # The success field is set to false and the err field is set to the provided error message. @classmethod def as_failure(cls: type[R], id_: str, err_msg: str) -> R: - return Response(id=id_, success=False, err=err_msg, fallback=False) + return Response(id=id_, success=False, err=err_msg, fallback=False, serve_response=None) # as_fallback creates a Response with the fallback field set to true. # This indicates that the message should be sent to the fallback sink. @classmethod def as_fallback(cls: type[R], id_: str) -> R: - return Response(id=id_, fallback=True, err=None, success=False) + return Response(id=id_, fallback=True, err=None, success=False, serve_response=None) + + # as_serving_response creates a Response with the serve_response field set to + # value of the result to be sent back from the serving sink. 
+ @classmethod + def as_serving_response(cls: type[R], id_: str, result: bytes) -> R: + return Response(id=id_, fallback=False, err=None, success=True, serve_response=result) class Responses(Sequence[R]): diff --git a/pynumaflow/sinker/servicer/utils.py b/pynumaflow/sinker/servicer/utils.py index e3d648c2..e34cdecc 100644 --- a/pynumaflow/sinker/servicer/utils.py +++ b/pynumaflow/sinker/servicer/utils.py @@ -28,7 +28,9 @@ def build_sink_response(rspn: Response) -> sink_pb2.SinkResponse.Result: """ rid = rspn.id if rspn.success: - return sink_pb2.SinkResponse.Result(id=rid, status=sink_pb2.Status.SUCCESS) + return sink_pb2.SinkResponse.Result( + id=rid, status=sink_pb2.Status.SUCCESS, serve_response=rspn.serve_response + ) elif rspn.fallback: return sink_pb2.SinkResponse.Result(id=rid, status=sink_pb2.Status.FALLBACK) else: diff --git a/tests/sink/test_responses.py b/tests/sink/test_responses.py index 118570d5..f92b4ee7 100644 --- a/tests/sink/test_responses.py +++ b/tests/sink/test_responses.py @@ -29,7 +29,10 @@ def setUp(self) -> None: def test_responses(self): self.resps.append(Response.as_success("4")) - self.assertEqual(4, len(self.resps)) + self.resps.append( + Response.as_serving_response("6", result=bytes("test_put", encoding="utf-8")) + ) + self.assertEqual(5, len(self.resps)) for resp in self.resps: self.assertIsInstance(resp, Response) @@ -38,12 +41,15 @@ def test_responses(self): self.assertEqual(self.resps[1].id, "3") self.assertEqual(self.resps[2].id, "5") self.assertEqual(self.resps[3].id, "4") + self.assertEqual(self.resps[4].id, "6") self.assertEqual( - "[Response(id='2', success=True, err=None, fallback=False), " - "Response(id='3', success=False, err='RuntimeError encountered!', fallback=False), " - "Response(id='5', success=False, err=None, fallback=True), " - "Response(id='4', success=True, err=None, fallback=False)]", + "[Response(id='2', success=True, err=None, fallback=False, serve_response=None), " + "Response(id='3', success=False, err='RuntimeError encountered!', " + "fallback=False, serve_response=None), " + "Response(id='5', success=False, err=None, fallback=True, serve_response=None), " + "Response(id='4', success=True, err=None, fallback=False, serve_response=None), " + "Response(id='6', success=True, err=None, fallback=False, serve_response=b'test_put')]", repr(self.resps), )