From 4c6cecdaad5015ee59fe2552b5ec9485881054c3 Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 22 Feb 2024 13:15:39 -0500 Subject: [PATCH 1/6] Add: Compatibility with Newer Protobuf Version --- alfred/fm/remote/protos/query_pb2.py | 36 +++++++++++++++------------- requirements.txt | 6 ++--- setup.py | 2 +- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/alfred/fm/remote/protos/query_pb2.py b/alfred/fm/remote/protos/query_pb2.py index 278e33e..531774f 100644 --- a/alfred/fm/remote/protos/query_pb2.py +++ b/alfred/fm/remote/protos/query_pb2.py @@ -1,32 +1,34 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: query.proto +# Protobuf Python Version: 4.25.2 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0bquery.proto\x12\x05unary"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse"\x00(\x01\x30\x01\x62\x06proto3' -) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "query_pb2", globals()) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bquery.proto\x12\x05unary\"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs\"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding\"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs\"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse\"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse\"\x00(\x01\x30\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'query_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RUNREQUEST._serialized_start = 22 - _RUNREQUEST._serialized_end = 121 - _RUNRESPONSE._serialized_start = 124 - _RUNRESPONSE._serialized_end = 255 - _ENCODEREQUEST._serialized_start = 257 - _ENCODEREQUEST._serialized_end = 340 - _ENCODERESPONSE._serialized_start = 342 - _ENCODERESPONSE._serialized_end = 394 - _QUERYSERVICE._serialized_start = 396 - _QUERYSERVICE._serialized_end = 523 + DESCRIPTOR._options = None + _globals['_RUNREQUEST']._serialized_start=22 + _globals['_RUNREQUEST']._serialized_end=121 + _globals['_RUNRESPONSE']._serialized_start=124 + _globals['_RUNRESPONSE']._serialized_end=255 + _globals['_ENCODEREQUEST']._serialized_start=257 + _globals['_ENCODEREQUEST']._serialized_end=340 + _globals['_ENCODERESPONSE']._serialized_start=342 + _globals['_ENCODERESPONSE']._serialized_end=394 + _globals['_QUERYSERVICE']._serialized_start=396 + _globals['_QUERYSERVICE']._serialized_end=523 # @@protoc_insertion_point(module_scope) diff --git a/requirements.txt b/requirements.txt index 035d1ce..e5de385 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -datasets~=2.4.0 paramiko>=2.7.2 pyarrow>=3.0.0 torch>=1.8.0 +datasets transformers accelerate numpy>=1.21.0 @@ -12,8 +12,8 @@ arrow>=0.13.1 setuptools>=58.0.4 scipy>=1.7.1 -grpcio==1.48.1 -protobuf==3.20.0 +grpcio==1.62.0 +protobuf==4.25.2 sentencepiece deepspeed>=0.8.2 diff --git a/setup.py b/setup.py index 58c4e3f..e2902e8 100644 --- a/setup.py +++ b/setup.py @@ -8,5 +8,5 @@ author_email='peilin_yu@brown.edu', description='Toolkit for Prompted Weak Supervisions', packages=find_packages(), - install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio==1.48.1", "protobuf==3.20.0"], + install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio", "protobuf<5"], ) From e5447161ca82651382d44c9fa5da58b7cd95db44 Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 22 Feb 2024 13:40:29 -0500 Subject: [PATCH 2/6] Add: Compatibility with Newer Protobuf Version --- alfred/fm/remote/protos/query_pb2.py | 29 +++-- alfred/fm/remote/protos/query_pb2.pyi | 34 +----- alfred/fm/remote/protos/query_pb2_grpc.py | 129 ++++++++-------------- 3 files changed, 68 insertions(+), 124 deletions(-) diff --git a/alfred/fm/remote/protos/query_pb2.py b/alfred/fm/remote/protos/query_pb2.py index 531774f..ab4624c 100644 --- a/alfred/fm/remote/protos/query_pb2.py +++ b/alfred/fm/remote/protos/query_pb2.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: query.proto -# Protobuf Python Version: 4.25.2 """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,19 +15,19 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bquery.proto\x12\x05unary\"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs\"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding\"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs\"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse\"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse\"\x00(\x01\x30\x01\x62\x06proto3') -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'query_pb2', _globals) +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'query_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None - _globals['_RUNREQUEST']._serialized_start=22 - _globals['_RUNREQUEST']._serialized_end=121 - _globals['_RUNRESPONSE']._serialized_start=124 - _globals['_RUNRESPONSE']._serialized_end=255 - _globals['_ENCODEREQUEST']._serialized_start=257 - _globals['_ENCODEREQUEST']._serialized_end=340 - _globals['_ENCODERESPONSE']._serialized_start=342 - _globals['_ENCODERESPONSE']._serialized_end=394 - _globals['_QUERYSERVICE']._serialized_start=396 - _globals['_QUERYSERVICE']._serialized_end=523 + _RUNREQUEST._serialized_start=22 + _RUNREQUEST._serialized_end=121 + _RUNRESPONSE._serialized_start=124 + _RUNRESPONSE._serialized_end=255 + _ENCODEREQUEST._serialized_start=257 + _ENCODEREQUEST._serialized_end=340 + _ENCODERESPONSE._serialized_start=342 + _ENCODERESPONSE._serialized_end=394 + _QUERYSERVICE._serialized_start=396 + _QUERYSERVICE._serialized_end=523 # @@protoc_insertion_point(module_scope) diff --git a/alfred/fm/remote/protos/query_pb2.pyi b/alfred/fm/remote/protos/query_pb2.pyi index 8991e02..a473209 100644 --- a/alfred/fm/remote/protos/query_pb2.pyi +++ b/alfred/fm/remote/protos/query_pb2.pyi @@ -1,7 +1,6 @@ -from typing import ClassVar as _ClassVar, Optional as _Optional - from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor @@ -13,13 +12,7 @@ class EncodeRequest(_message.Message): kwargs: str message: str reduction: str - - def __init__( - self, - message: _Optional[str] = ..., - reduction: _Optional[str] = ..., - kwargs: _Optional[str] = ..., - ) -> None: ... + def __init__(self, message: _Optional[str] = ..., reduction: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... class EncodeResponse(_message.Message): __slots__ = ["embedding", "success"] @@ -27,10 +20,7 @@ class EncodeResponse(_message.Message): SUCCESS_FIELD_NUMBER: _ClassVar[int] embedding: bytes success: bool - - def __init__( - self, embedding: _Optional[bytes] = ..., success: bool = ... - ) -> None: ... + def __init__(self, embedding: _Optional[bytes] = ..., success: bool = ...) -> None: ... class RunRequest(_message.Message): __slots__ = ["candidate", "kwargs", "message"] @@ -40,13 +30,7 @@ class RunRequest(_message.Message): candidate: str kwargs: str message: str - - def __init__( - self, - message: _Optional[str] = ..., - candidate: _Optional[str] = ..., - kwargs: _Optional[str] = ..., - ) -> None: ... + def __init__(self, message: _Optional[str] = ..., candidate: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... class RunResponse(_message.Message): __slots__ = ["embedding", "logit", "message", "ranked", "success"] @@ -60,12 +44,4 @@ class RunResponse(_message.Message): message: str ranked: bool success: bool - - def __init__( - self, - message: _Optional[str] = ..., - ranked: bool = ..., - success: bool = ..., - logit: _Optional[str] = ..., - embedding: _Optional[bytes] = ..., - ) -> None: ... + def __init__(self, message: _Optional[str] = ..., ranked: bool = ..., success: bool = ..., logit: _Optional[str] = ..., embedding: _Optional[bytes] = ...) -> None: ... diff --git a/alfred/fm/remote/protos/query_pb2_grpc.py b/alfred/fm/remote/protos/query_pb2_grpc.py index eda37fd..78d3b36 100644 --- a/alfred/fm/remote/protos/query_pb2_grpc.py +++ b/alfred/fm/remote/protos/query_pb2_grpc.py @@ -2,13 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -try: - import alfred.fm.remote.protos.query_pb2 as query__pb2 -except ImportError: - try: - import query_pb2 as query__pb2 - except ModuleNotFoundError: - from . import query_pb2 as query__pb2 +import query_pb2 as query__pb2 class QueryServiceStub(object): @@ -21,15 +15,15 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Encode = channel.stream_stream( - "/unary.QueryService/Encode", - request_serializer=query__pb2.EncodeRequest.SerializeToString, - response_deserializer=query__pb2.EncodeResponse.FromString, - ) + '/unary.QueryService/Encode', + request_serializer=query__pb2.EncodeRequest.SerializeToString, + response_deserializer=query__pb2.EncodeResponse.FromString, + ) self.Run = channel.stream_stream( - "/unary.QueryService/Run", - request_serializer=query__pb2.RunRequest.SerializeToString, - response_deserializer=query__pb2.RunResponse.FromString, - ) + '/unary.QueryService/Run', + request_serializer=query__pb2.RunRequest.SerializeToString, + response_deserializer=query__pb2.RunResponse.FromString, + ) class QueryServiceServicer(object): @@ -38,93 +32,68 @@ class QueryServiceServicer(object): def Encode(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def Run(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_QueryServiceServicer_to_server(servicer, server): rpc_method_handlers = { - "Encode": grpc.stream_stream_rpc_method_handler( - servicer.Encode, - request_deserializer=query__pb2.EncodeRequest.FromString, - response_serializer=query__pb2.EncodeResponse.SerializeToString, - ), - "Run": grpc.stream_stream_rpc_method_handler( - servicer.Run, - request_deserializer=query__pb2.RunRequest.FromString, - response_serializer=query__pb2.RunResponse.SerializeToString, - ), + 'Encode': grpc.stream_stream_rpc_method_handler( + servicer.Encode, + request_deserializer=query__pb2.EncodeRequest.FromString, + response_serializer=query__pb2.EncodeResponse.SerializeToString, + ), + 'Run': grpc.stream_stream_rpc_method_handler( + servicer.Run, + request_deserializer=query__pb2.RunRequest.FromString, + response_serializer=query__pb2.RunResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - "unary.QueryService", rpc_method_handlers - ) + 'unary.QueryService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class QueryService(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Encode( - request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.stream_stream( - request_iterator, + def Encode(request_iterator, target, - "/unary.QueryService/Encode", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream(request_iterator, target, '/unary.QueryService/Encode', query__pb2.EncodeRequest.SerializeToString, query__pb2.EncodeResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def Run( - request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.stream_stream( - request_iterator, + def Run(request_iterator, target, - "/unary.QueryService/Run", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream(request_iterator, target, '/unary.QueryService/Run', query__pb2.RunRequest.SerializeToString, query__pb2.RunResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) From 6530720ded00cbaaf88dd6eeb122d2922f51bf45 Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 22 Feb 2024 13:41:35 -0500 Subject: [PATCH 3/6] Add: Compatibility with Newer Protobuf Version --- alfred/fm/remote/protos/query_pb2_grpc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alfred/fm/remote/protos/query_pb2_grpc.py b/alfred/fm/remote/protos/query_pb2_grpc.py index 78d3b36..5c7d2ef 100644 --- a/alfred/fm/remote/protos/query_pb2_grpc.py +++ b/alfred/fm/remote/protos/query_pb2_grpc.py @@ -2,7 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -import query_pb2 as query__pb2 +import .query_pb2 as query__pb2 class QueryServiceStub(object): From 9c2c206d68a3296af346823e218e5a7d7fedf38f Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 22 Feb 2024 13:42:57 -0500 Subject: [PATCH 4/6] Add: Compatibility with Newer Protobuf Version --- alfred/fm/remote/protos/query_pb2_grpc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alfred/fm/remote/protos/query_pb2_grpc.py b/alfred/fm/remote/protos/query_pb2_grpc.py index 5c7d2ef..3b642a5 100644 --- a/alfred/fm/remote/protos/query_pb2_grpc.py +++ b/alfred/fm/remote/protos/query_pb2_grpc.py @@ -2,7 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -import .query_pb2 as query__pb2 +from . import query_pb2 as query__pb2 class QueryServiceStub(object): From aeb4928fe307be5518263a6edb3e8396c2f39216 Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 22 Feb 2024 14:25:34 -0500 Subject: [PATCH 5/6] Add: Compatibility with Newer Protobuf Version --- alfred/fm/remote/protos/query_pb2.py | 29 +++++++-------- alfred/fm/remote/protos/query_pb2.pyi | 52 +++++++++++++-------------- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/alfred/fm/remote/protos/query_pb2.py b/alfred/fm/remote/protos/query_pb2.py index ab4624c..12cf1d5 100644 --- a/alfred/fm/remote/protos/query_pb2.py +++ b/alfred/fm/remote/protos/query_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: query.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -15,19 +16,19 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bquery.proto\x12\x05unary\"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs\"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding\"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs\"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse\"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse\"\x00(\x01\x30\x01\x62\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'query_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'query_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RUNREQUEST._serialized_start=22 - _RUNREQUEST._serialized_end=121 - _RUNRESPONSE._serialized_start=124 - _RUNRESPONSE._serialized_end=255 - _ENCODEREQUEST._serialized_start=257 - _ENCODEREQUEST._serialized_end=340 - _ENCODERESPONSE._serialized_start=342 - _ENCODERESPONSE._serialized_end=394 - _QUERYSERVICE._serialized_start=396 - _QUERYSERVICE._serialized_end=523 + _globals['_RUNREQUEST']._serialized_start=22 + _globals['_RUNREQUEST']._serialized_end=121 + _globals['_RUNRESPONSE']._serialized_start=124 + _globals['_RUNRESPONSE']._serialized_end=255 + _globals['_ENCODEREQUEST']._serialized_start=257 + _globals['_ENCODEREQUEST']._serialized_end=340 + _globals['_ENCODERESPONSE']._serialized_start=342 + _globals['_ENCODERESPONSE']._serialized_end=394 + _globals['_QUERYSERVICE']._serialized_start=396 + _globals['_QUERYSERVICE']._serialized_end=523 # @@protoc_insertion_point(module_scope) diff --git a/alfred/fm/remote/protos/query_pb2.pyi b/alfred/fm/remote/protos/query_pb2.pyi index a473209..8efcfb0 100644 --- a/alfred/fm/remote/protos/query_pb2.pyi +++ b/alfred/fm/remote/protos/query_pb2.pyi @@ -4,44 +4,44 @@ from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor -class EncodeRequest(_message.Message): - __slots__ = ["kwargs", "message", "reduction"] - KWARGS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - REDUCTION_FIELD_NUMBER: _ClassVar[int] - kwargs: str - message: str - reduction: str - def __init__(self, message: _Optional[str] = ..., reduction: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... - -class EncodeResponse(_message.Message): - __slots__ = ["embedding", "success"] - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - SUCCESS_FIELD_NUMBER: _ClassVar[int] - embedding: bytes - success: bool - def __init__(self, embedding: _Optional[bytes] = ..., success: bool = ...) -> None: ... - class RunRequest(_message.Message): - __slots__ = ["candidate", "kwargs", "message"] + __slots__ = ("message", "candidate", "kwargs") + MESSAGE_FIELD_NUMBER: _ClassVar[int] CANDIDATE_FIELD_NUMBER: _ClassVar[int] KWARGS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] + message: str candidate: str kwargs: str - message: str def __init__(self, message: _Optional[str] = ..., candidate: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... class RunResponse(_message.Message): - __slots__ = ["embedding", "logit", "message", "ranked", "success"] - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - LOGIT_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("message", "ranked", "success", "logit", "embedding") MESSAGE_FIELD_NUMBER: _ClassVar[int] RANKED_FIELD_NUMBER: _ClassVar[int] SUCCESS_FIELD_NUMBER: _ClassVar[int] - embedding: bytes - logit: str + LOGIT_FIELD_NUMBER: _ClassVar[int] + EMBEDDING_FIELD_NUMBER: _ClassVar[int] message: str ranked: bool success: bool + logit: str + embedding: bytes def __init__(self, message: _Optional[str] = ..., ranked: bool = ..., success: bool = ..., logit: _Optional[str] = ..., embedding: _Optional[bytes] = ...) -> None: ... + +class EncodeRequest(_message.Message): + __slots__ = ("message", "reduction", "kwargs") + MESSAGE_FIELD_NUMBER: _ClassVar[int] + REDUCTION_FIELD_NUMBER: _ClassVar[int] + KWARGS_FIELD_NUMBER: _ClassVar[int] + message: str + reduction: str + kwargs: str + def __init__(self, message: _Optional[str] = ..., reduction: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... + +class EncodeResponse(_message.Message): + __slots__ = ("embedding", "success") + EMBEDDING_FIELD_NUMBER: _ClassVar[int] + SUCCESS_FIELD_NUMBER: _ClassVar[int] + embedding: bytes + success: bool + def __init__(self, embedding: _Optional[bytes] = ..., success: bool = ...) -> None: ... From 93fa694590ed56dac9eb990051aa1fcd1c7ce03c Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 22 Feb 2024 15:37:20 -0500 Subject: [PATCH 6/6] Add: Compatibility with Newer Protobuf Version --- alfred/fm/huggingface.py | 3 +- alfred/fm/remote/grpc.py | 10 +- alfred/fm/remote/protos/query_pb2.py | 30 ++--- alfred/fm/remote/protos/query_pb2.pyi | 27 +++- alfred/fm/remote/protos/query_pb2_grpc.py | 121 +++++++++++------- docs/alfred/fm/huggingface.md | 2 +- .../alfred/fm/remote/protos/query_pb2_grpc.md | 16 +-- setup.py | 2 +- test/unit_test/test_grpc.py | 2 + 9 files changed, 130 insertions(+), 83 deletions(-) diff --git a/alfred/fm/huggingface.py b/alfred/fm/huggingface.py index 9b27fa6..4c5c0f2 100644 --- a/alfred/fm/huggingface.py +++ b/alfred/fm/huggingface.py @@ -18,7 +18,8 @@ logger = logging.getLogger(__name__) -from transformers import LlamaPreTrainedModel, MistralPreTrainedModel +from transformers.models.llama.modeling_llama import LlamaPreTrainedModel +from transformers.models.mistral.modeling_mistral import MistralPreTrainedModel dtype_match = { "auto": "auto", diff --git a/alfred/fm/remote/grpc.py b/alfred/fm/remote/grpc.py index a5af04c..24faa48 100644 --- a/alfred/fm/remote/grpc.py +++ b/alfred/fm/remote/grpc.py @@ -11,11 +11,11 @@ from PIL import Image from tqdm.auto import tqdm -from alfred.fm.query import Query, RankedQuery, CompletionQuery -from alfred.fm.remote.protos import query_pb2 -from alfred.fm.remote.protos import query_pb2_grpc -from alfred.fm.remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder -from alfred.fm.response import RankedResponse, CompletionResponse +from ..query import Query, RankedQuery, CompletionQuery +from ..remote.protos import query_pb2 +from ..remote.protos import query_pb2_grpc +from ..remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder +from ..response import RankedResponse, CompletionResponse logger = logging.getLogger(__name__) diff --git a/alfred/fm/remote/protos/query_pb2.py b/alfred/fm/remote/protos/query_pb2.py index 12cf1d5..55a2b1e 100644 --- a/alfred/fm/remote/protos/query_pb2.py +++ b/alfred/fm/remote/protos/query_pb2.py @@ -12,23 +12,23 @@ _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bquery.proto\x12\x05unary\"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs\"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding\"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs\"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse\"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse\"\x00(\x01\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0bquery.proto\x12\x05unary"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse"\x00(\x01\x30\x01\x62\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'query_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "query_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_RUNREQUEST']._serialized_start=22 - _globals['_RUNREQUEST']._serialized_end=121 - _globals['_RUNRESPONSE']._serialized_start=124 - _globals['_RUNRESPONSE']._serialized_end=255 - _globals['_ENCODEREQUEST']._serialized_start=257 - _globals['_ENCODEREQUEST']._serialized_end=340 - _globals['_ENCODERESPONSE']._serialized_start=342 - _globals['_ENCODERESPONSE']._serialized_end=394 - _globals['_QUERYSERVICE']._serialized_start=396 - _globals['_QUERYSERVICE']._serialized_end=523 + DESCRIPTOR._options = None + _globals["_RUNREQUEST"]._serialized_start = 22 + _globals["_RUNREQUEST"]._serialized_end = 121 + _globals["_RUNRESPONSE"]._serialized_start = 124 + _globals["_RUNRESPONSE"]._serialized_end = 255 + _globals["_ENCODEREQUEST"]._serialized_start = 257 + _globals["_ENCODEREQUEST"]._serialized_end = 340 + _globals["_ENCODERESPONSE"]._serialized_start = 342 + _globals["_ENCODERESPONSE"]._serialized_end = 394 + _globals["_QUERYSERVICE"]._serialized_start = 396 + _globals["_QUERYSERVICE"]._serialized_end = 523 # @@protoc_insertion_point(module_scope) diff --git a/alfred/fm/remote/protos/query_pb2.pyi b/alfred/fm/remote/protos/query_pb2.pyi index 8efcfb0..3f012ba 100644 --- a/alfred/fm/remote/protos/query_pb2.pyi +++ b/alfred/fm/remote/protos/query_pb2.pyi @@ -12,7 +12,12 @@ class RunRequest(_message.Message): message: str candidate: str kwargs: str - def __init__(self, message: _Optional[str] = ..., candidate: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... + def __init__( + self, + message: _Optional[str] = ..., + candidate: _Optional[str] = ..., + kwargs: _Optional[str] = ..., + ) -> None: ... class RunResponse(_message.Message): __slots__ = ("message", "ranked", "success", "logit", "embedding") @@ -26,7 +31,14 @@ class RunResponse(_message.Message): success: bool logit: str embedding: bytes - def __init__(self, message: _Optional[str] = ..., ranked: bool = ..., success: bool = ..., logit: _Optional[str] = ..., embedding: _Optional[bytes] = ...) -> None: ... + def __init__( + self, + message: _Optional[str] = ..., + ranked: bool = ..., + success: bool = ..., + logit: _Optional[str] = ..., + embedding: _Optional[bytes] = ..., + ) -> None: ... class EncodeRequest(_message.Message): __slots__ = ("message", "reduction", "kwargs") @@ -36,7 +48,12 @@ class EncodeRequest(_message.Message): message: str reduction: str kwargs: str - def __init__(self, message: _Optional[str] = ..., reduction: _Optional[str] = ..., kwargs: _Optional[str] = ...) -> None: ... + def __init__( + self, + message: _Optional[str] = ..., + reduction: _Optional[str] = ..., + kwargs: _Optional[str] = ..., + ) -> None: ... class EncodeResponse(_message.Message): __slots__ = ("embedding", "success") @@ -44,4 +61,6 @@ class EncodeResponse(_message.Message): SUCCESS_FIELD_NUMBER: _ClassVar[int] embedding: bytes success: bool - def __init__(self, embedding: _Optional[bytes] = ..., success: bool = ...) -> None: ... + def __init__( + self, embedding: _Optional[bytes] = ..., success: bool = ... + ) -> None: ... diff --git a/alfred/fm/remote/protos/query_pb2_grpc.py b/alfred/fm/remote/protos/query_pb2_grpc.py index 3b642a5..4f25019 100644 --- a/alfred/fm/remote/protos/query_pb2_grpc.py +++ b/alfred/fm/remote/protos/query_pb2_grpc.py @@ -15,15 +15,15 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Encode = channel.stream_stream( - '/unary.QueryService/Encode', - request_serializer=query__pb2.EncodeRequest.SerializeToString, - response_deserializer=query__pb2.EncodeResponse.FromString, - ) + "/unary.QueryService/Encode", + request_serializer=query__pb2.EncodeRequest.SerializeToString, + response_deserializer=query__pb2.EncodeResponse.FromString, + ) self.Run = channel.stream_stream( - '/unary.QueryService/Run', - request_serializer=query__pb2.RunRequest.SerializeToString, - response_deserializer=query__pb2.RunResponse.FromString, - ) + "/unary.QueryService/Run", + request_serializer=query__pb2.RunRequest.SerializeToString, + response_deserializer=query__pb2.RunResponse.FromString, + ) class QueryServiceServicer(object): @@ -32,68 +32,93 @@ class QueryServiceServicer(object): def Encode(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def Run(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_QueryServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'Encode': grpc.stream_stream_rpc_method_handler( - servicer.Encode, - request_deserializer=query__pb2.EncodeRequest.FromString, - response_serializer=query__pb2.EncodeResponse.SerializeToString, - ), - 'Run': grpc.stream_stream_rpc_method_handler( - servicer.Run, - request_deserializer=query__pb2.RunRequest.FromString, - response_serializer=query__pb2.RunResponse.SerializeToString, - ), + "Encode": grpc.stream_stream_rpc_method_handler( + servicer.Encode, + request_deserializer=query__pb2.EncodeRequest.FromString, + response_serializer=query__pb2.EncodeResponse.SerializeToString, + ), + "Run": grpc.stream_stream_rpc_method_handler( + servicer.Run, + request_deserializer=query__pb2.RunRequest.FromString, + response_serializer=query__pb2.RunResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'unary.QueryService', rpc_method_handlers) + "unary.QueryService", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class QueryService(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Encode(request_iterator, + def Encode( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/unary.QueryService/Encode', + "/unary.QueryService/Encode", query__pb2.EncodeRequest.SerializeToString, query__pb2.EncodeResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def Run(request_iterator, + def Run( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/unary.QueryService/Run', + "/unary.QueryService/Run", query__pb2.RunRequest.SerializeToString, query__pb2.RunResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/docs/alfred/fm/huggingface.md b/docs/alfred/fm/huggingface.md index d77d031..b6f977b 100644 --- a/docs/alfred/fm/huggingface.md +++ b/docs/alfred/fm/huggingface.md @@ -12,7 +12,7 @@ Huggingface ## HuggingFaceModel -[Show source in huggingface.py:43](../../../alfred/fm/huggingface.py#L43) +[Show source in huggingface.py:44](../../../alfred/fm/huggingface.py#L44) The HuggingFaceModel class is a wrapper for HuggingFace models, including both Seq2Seq (Encoder-Decoder, e.g. T5, T0) and Causal diff --git a/docs/alfred/fm/remote/protos/query_pb2_grpc.md b/docs/alfred/fm/remote/protos/query_pb2_grpc.md index ae81705..733b894 100644 --- a/docs/alfred/fm/remote/protos/query_pb2_grpc.md +++ b/docs/alfred/fm/remote/protos/query_pb2_grpc.md @@ -21,7 +21,7 @@ Query Pb2 Grpc ## QueryService -[Show source in query_pb2_grpc.py:71](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L71) +[Show source in query_pb2_grpc.py:65](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L65) Missing associated documentation comment in .proto file. @@ -34,7 +34,7 @@ class QueryService(object): ### QueryService.Encode -[Show source in query_pb2_grpc.py:74](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L74) +[Show source in query_pb2_grpc.py:68](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L68) #### Signature @@ -57,7 +57,7 @@ def Encode( ### QueryService.Run -[Show source in query_pb2_grpc.py:103](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L103) +[Show source in query_pb2_grpc.py:97](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L97) #### Signature @@ -82,7 +82,7 @@ def Run( ## QueryServiceServicer -[Show source in query_pb2_grpc.py:35](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L35) +[Show source in query_pb2_grpc.py:29](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L29) Missing associated documentation comment in .proto file. @@ -95,7 +95,7 @@ class QueryServiceServicer(object): ### QueryServiceServicer().Encode -[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38) +[Show source in query_pb2_grpc.py:32](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L32) Missing associated documentation comment in .proto file. @@ -108,7 +108,7 @@ def Encode(self, request_iterator, context): ### QueryServiceServicer().Run -[Show source in query_pb2_grpc.py:44](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L44) +[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38) Missing associated documentation comment in .proto file. @@ -123,7 +123,7 @@ def Run(self, request_iterator, context): ## QueryServiceStub -[Show source in query_pb2_grpc.py:14](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L14) +[Show source in query_pb2_grpc.py:8](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L8) Missing associated documentation comment in .proto file. @@ -139,7 +139,7 @@ class QueryServiceStub(object): ## add_QueryServiceServicer_to_server -[Show source in query_pb2_grpc.py:51](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L51) +[Show source in query_pb2_grpc.py:45](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L45) #### Signature diff --git a/setup.py b/setup.py index e2902e8..128014f 100644 --- a/setup.py +++ b/setup.py @@ -8,5 +8,5 @@ author_email='peilin_yu@brown.edu', description='Toolkit for Prompted Weak Supervisions', packages=find_packages(), - install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio", "protobuf<5"], + install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "accelerate", "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio", "protobuf<5"], ) diff --git a/test/unit_test/test_grpc.py b/test/unit_test/test_grpc.py index b5f6a38..3ea15a6 100644 --- a/test/unit_test/test_grpc.py +++ b/test/unit_test/test_grpc.py @@ -1,4 +1,5 @@ import threading +import time import unittest import grpc @@ -36,6 +37,7 @@ def server_starter(server, client, port): self.channel = grpc.insecure_channel(f"localhost:{self.port}") self.stub = query_pb2_grpc.QueryServiceStub(self.channel) + time.sleep(2) def test_run_single_query(self): # Test running a single query using the client