diff --git a/alfred/fm/huggingface.py b/alfred/fm/huggingface.py index 9b27fa6..4c5c0f2 100644 --- a/alfred/fm/huggingface.py +++ b/alfred/fm/huggingface.py @@ -18,7 +18,8 @@ logger = logging.getLogger(__name__) -from transformers import LlamaPreTrainedModel, MistralPreTrainedModel +from transformers.models.llama.modeling_llama import LlamaPreTrainedModel +from transformers.models.mistral.modeling_mistral import MistralPreTrainedModel dtype_match = { "auto": "auto", diff --git a/alfred/fm/remote/grpc.py b/alfred/fm/remote/grpc.py index a5af04c..24faa48 100644 --- a/alfred/fm/remote/grpc.py +++ b/alfred/fm/remote/grpc.py @@ -11,11 +11,11 @@ from PIL import Image from tqdm.auto import tqdm -from alfred.fm.query import Query, RankedQuery, CompletionQuery -from alfred.fm.remote.protos import query_pb2 -from alfred.fm.remote.protos import query_pb2_grpc -from alfred.fm.remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder -from alfred.fm.response import RankedResponse, CompletionResponse +from ..query import Query, RankedQuery, CompletionQuery +from ..remote.protos import query_pb2 +from ..remote.protos import query_pb2_grpc +from ..remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder +from ..response import RankedResponse, CompletionResponse logger = logging.getLogger(__name__) diff --git a/alfred/fm/remote/protos/query_pb2.py b/alfred/fm/remote/protos/query_pb2.py index 278e33e..55a2b1e 100644 --- a/alfred/fm/remote/protos/query_pb2.py +++ b/alfred/fm/remote/protos/query_pb2.py @@ -1,32 +1,34 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: query.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( b'\n\x0bquery.proto\x12\x05unary"c\n\nRunRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x16\n\tcandidate\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06kwargs\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_candidateB\t\n\x07_kwargs"\x83\x01\n\x0bRunResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0e\n\x06ranked\x18\x02 \x01(\x08\x12\x0f\n\x07success\x18\x03 \x01(\x08\x12\x12\n\x05logit\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x16\n\tembedding\x18\x05 \x01(\x0cH\x01\x88\x01\x01\x42\x08\n\x06_logitB\x0c\n\n_embedding"S\n\rEncodeRequest\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x11\n\treduction\x18\x03 \x01(\t\x12\x13\n\x06kwargs\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_kwargs"4\n\x0e\x45ncodeResponse\x12\x11\n\tembedding\x18\x01 \x01(\x0c\x12\x0f\n\x07success\x18\x02 \x01(\x08\x32\x7f\n\x0cQueryService\x12;\n\x06\x45ncode\x12\x14.unary.EncodeRequest\x1a\x15.unary.EncodeResponse"\x00(\x01\x30\x01\x12\x32\n\x03Run\x12\x11.unary.RunRequest\x1a\x12.unary.RunResponse"\x00(\x01\x30\x01\x62\x06proto3' ) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "query_pb2", globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "query_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _RUNREQUEST._serialized_start = 22 - _RUNREQUEST._serialized_end = 121 - _RUNRESPONSE._serialized_start = 124 - _RUNRESPONSE._serialized_end = 255 - _ENCODEREQUEST._serialized_start = 257 - _ENCODEREQUEST._serialized_end = 340 - _ENCODERESPONSE._serialized_start = 342 - _ENCODERESPONSE._serialized_end = 394 - _QUERYSERVICE._serialized_start = 396 - _QUERYSERVICE._serialized_end = 523 + _globals["_RUNREQUEST"]._serialized_start = 22 + _globals["_RUNREQUEST"]._serialized_end = 121 + _globals["_RUNRESPONSE"]._serialized_start = 124 + _globals["_RUNRESPONSE"]._serialized_end = 255 + _globals["_ENCODEREQUEST"]._serialized_start = 257 + _globals["_ENCODEREQUEST"]._serialized_end = 340 + _globals["_ENCODERESPONSE"]._serialized_start = 342 + _globals["_ENCODERESPONSE"]._serialized_end = 394 + _globals["_QUERYSERVICE"]._serialized_start = 396 + _globals["_QUERYSERVICE"]._serialized_end = 523 # @@protoc_insertion_point(module_scope) diff --git a/alfred/fm/remote/protos/query_pb2.pyi b/alfred/fm/remote/protos/query_pb2.pyi index 8991e02..3f012ba 100644 --- a/alfred/fm/remote/protos/query_pb2.pyi +++ b/alfred/fm/remote/protos/query_pb2.pyi @@ -1,46 +1,17 @@ -from typing import ClassVar as _ClassVar, Optional as _Optional - from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor -class EncodeRequest(_message.Message): - __slots__ = ["kwargs", "message", "reduction"] - KWARGS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - REDUCTION_FIELD_NUMBER: _ClassVar[int] - kwargs: str - message: str - reduction: str - - def __init__( - self, - message: _Optional[str] = ..., - reduction: _Optional[str] = ..., - kwargs: _Optional[str] = ..., - ) -> None: ... - -class EncodeResponse(_message.Message): - __slots__ = ["embedding", "success"] - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - SUCCESS_FIELD_NUMBER: _ClassVar[int] - embedding: bytes - success: bool - - def __init__( - self, embedding: _Optional[bytes] = ..., success: bool = ... - ) -> None: ... - class RunRequest(_message.Message): - __slots__ = ["candidate", "kwargs", "message"] + __slots__ = ("message", "candidate", "kwargs") + MESSAGE_FIELD_NUMBER: _ClassVar[int] CANDIDATE_FIELD_NUMBER: _ClassVar[int] KWARGS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] + message: str candidate: str kwargs: str - message: str - def __init__( self, message: _Optional[str] = ..., @@ -49,18 +20,17 @@ class RunRequest(_message.Message): ) -> None: ... class RunResponse(_message.Message): - __slots__ = ["embedding", "logit", "message", "ranked", "success"] - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - LOGIT_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("message", "ranked", "success", "logit", "embedding") MESSAGE_FIELD_NUMBER: _ClassVar[int] RANKED_FIELD_NUMBER: _ClassVar[int] SUCCESS_FIELD_NUMBER: _ClassVar[int] - embedding: bytes - logit: str + LOGIT_FIELD_NUMBER: _ClassVar[int] + EMBEDDING_FIELD_NUMBER: _ClassVar[int] message: str ranked: bool success: bool - + logit: str + embedding: bytes def __init__( self, message: _Optional[str] = ..., @@ -69,3 +39,28 @@ class RunResponse(_message.Message): logit: _Optional[str] = ..., embedding: _Optional[bytes] = ..., ) -> None: ... + +class EncodeRequest(_message.Message): + __slots__ = ("message", "reduction", "kwargs") + MESSAGE_FIELD_NUMBER: _ClassVar[int] + REDUCTION_FIELD_NUMBER: _ClassVar[int] + KWARGS_FIELD_NUMBER: _ClassVar[int] + message: str + reduction: str + kwargs: str + def __init__( + self, + message: _Optional[str] = ..., + reduction: _Optional[str] = ..., + kwargs: _Optional[str] = ..., + ) -> None: ... + +class EncodeResponse(_message.Message): + __slots__ = ("embedding", "success") + EMBEDDING_FIELD_NUMBER: _ClassVar[int] + SUCCESS_FIELD_NUMBER: _ClassVar[int] + embedding: bytes + success: bool + def __init__( + self, embedding: _Optional[bytes] = ..., success: bool = ... + ) -> None: ... diff --git a/alfred/fm/remote/protos/query_pb2_grpc.py b/alfred/fm/remote/protos/query_pb2_grpc.py index eda37fd..4f25019 100644 --- a/alfred/fm/remote/protos/query_pb2_grpc.py +++ b/alfred/fm/remote/protos/query_pb2_grpc.py @@ -2,13 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -try: - import alfred.fm.remote.protos.query_pb2 as query__pb2 -except ImportError: - try: - import query_pb2 as query__pb2 - except ModuleNotFoundError: - from . import query_pb2 as query__pb2 +from . import query_pb2 as query__pb2 class QueryServiceStub(object): diff --git a/docs/alfred/fm/huggingface.md b/docs/alfred/fm/huggingface.md index d77d031..b6f977b 100644 --- a/docs/alfred/fm/huggingface.md +++ b/docs/alfred/fm/huggingface.md @@ -12,7 +12,7 @@ Huggingface ## HuggingFaceModel -[Show source in huggingface.py:43](../../../alfred/fm/huggingface.py#L43) +[Show source in huggingface.py:44](../../../alfred/fm/huggingface.py#L44) The HuggingFaceModel class is a wrapper for HuggingFace models, including both Seq2Seq (Encoder-Decoder, e.g. T5, T0) and Causal diff --git a/docs/alfred/fm/remote/protos/query_pb2_grpc.md b/docs/alfred/fm/remote/protos/query_pb2_grpc.md index ae81705..733b894 100644 --- a/docs/alfred/fm/remote/protos/query_pb2_grpc.md +++ b/docs/alfred/fm/remote/protos/query_pb2_grpc.md @@ -21,7 +21,7 @@ Query Pb2 Grpc ## QueryService -[Show source in query_pb2_grpc.py:71](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L71) +[Show source in query_pb2_grpc.py:65](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L65) Missing associated documentation comment in .proto file. @@ -34,7 +34,7 @@ class QueryService(object): ### QueryService.Encode -[Show source in query_pb2_grpc.py:74](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L74) +[Show source in query_pb2_grpc.py:68](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L68) #### Signature @@ -57,7 +57,7 @@ def Encode( ### QueryService.Run -[Show source in query_pb2_grpc.py:103](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L103) +[Show source in query_pb2_grpc.py:97](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L97) #### Signature @@ -82,7 +82,7 @@ def Run( ## QueryServiceServicer -[Show source in query_pb2_grpc.py:35](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L35) +[Show source in query_pb2_grpc.py:29](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L29) Missing associated documentation comment in .proto file. @@ -95,7 +95,7 @@ class QueryServiceServicer(object): ### QueryServiceServicer().Encode -[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38) +[Show source in query_pb2_grpc.py:32](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L32) Missing associated documentation comment in .proto file. @@ -108,7 +108,7 @@ def Encode(self, request_iterator, context): ### QueryServiceServicer().Run -[Show source in query_pb2_grpc.py:44](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L44) +[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38) Missing associated documentation comment in .proto file. @@ -123,7 +123,7 @@ def Run(self, request_iterator, context): ## QueryServiceStub -[Show source in query_pb2_grpc.py:14](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L14) +[Show source in query_pb2_grpc.py:8](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L8) Missing associated documentation comment in .proto file. @@ -139,7 +139,7 @@ class QueryServiceStub(object): ## add_QueryServiceServicer_to_server -[Show source in query_pb2_grpc.py:51](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L51) +[Show source in query_pb2_grpc.py:45](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L45) #### Signature diff --git a/requirements.txt b/requirements.txt index 035d1ce..e5de385 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -datasets~=2.4.0 paramiko>=2.7.2 pyarrow>=3.0.0 torch>=1.8.0 +datasets transformers accelerate numpy>=1.21.0 @@ -12,8 +12,8 @@ arrow>=0.13.1 setuptools>=58.0.4 scipy>=1.7.1 -grpcio==1.48.1 -protobuf==3.20.0 +grpcio==1.62.0 +protobuf==4.25.2 sentencepiece deepspeed>=0.8.2 diff --git a/setup.py b/setup.py index 58c4e3f..128014f 100644 --- a/setup.py +++ b/setup.py @@ -8,5 +8,5 @@ author_email='peilin_yu@brown.edu', description='Toolkit for Prompted Weak Supervisions', packages=find_packages(), - install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio==1.48.1", "protobuf==3.20.0"], + install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "accelerate", "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio", "protobuf<5"], ) diff --git a/test/unit_test/test_grpc.py b/test/unit_test/test_grpc.py index b5f6a38..3ea15a6 100644 --- a/test/unit_test/test_grpc.py +++ b/test/unit_test/test_grpc.py @@ -1,4 +1,5 @@ import threading +import time import unittest import grpc @@ -36,6 +37,7 @@ def server_starter(server, client, port): self.channel = grpc.insecure_channel(f"localhost:{self.port}") self.stub = query_pb2_grpc.QueryServiceStub(self.channel) + time.sleep(2) def test_run_single_query(self): # Test running a single query using the client