Skip to content

Commit

Permalink
Merge pull request #62 from BatsResearch/upgrade-protobuf
Browse files Browse the repository at this point in the history
add compatibility for protobuf 4.x
  • Loading branch information
dotpyu authored Feb 22, 2024
2 parents 46851ba + 93fa694 commit 3c2f3f0
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 78 deletions.
3 changes: 2 additions & 1 deletion alfred/fm/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

logger = logging.getLogger(__name__)

from transformers import LlamaPreTrainedModel, MistralPreTrainedModel
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.models.mistral.modeling_mistral import MistralPreTrainedModel

dtype_match = {
"auto": "auto",
Expand Down
10 changes: 5 additions & 5 deletions alfred/fm/remote/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from PIL import Image
from tqdm.auto import tqdm

from alfred.fm.query import Query, RankedQuery, CompletionQuery
from alfred.fm.remote.protos import query_pb2
from alfred.fm.remote.protos import query_pb2_grpc
from alfred.fm.remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder
from alfred.fm.response import RankedResponse, CompletionResponse
from ..query import Query, RankedQuery, CompletionQuery
from ..remote.protos import query_pb2
from ..remote.protos import query_pb2_grpc
from ..remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder
from ..response import RankedResponse, CompletionResponse

logger = logging.getLogger(__name__)

Expand Down
28 changes: 15 additions & 13 deletions alfred/fm/remote/protos/query_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 34 additions & 39 deletions alfred/fm/remote/protos/query_pb2.pyi
Original file line number Diff line number Diff line change
@@ -1,46 +1,17 @@
from typing import ClassVar as _ClassVar, Optional as _Optional

from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class EncodeRequest(_message.Message):
__slots__ = ["kwargs", "message", "reduction"]
KWARGS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
REDUCTION_FIELD_NUMBER: _ClassVar[int]
kwargs: str
message: str
reduction: str

def __init__(
self,
message: _Optional[str] = ...,
reduction: _Optional[str] = ...,
kwargs: _Optional[str] = ...,
) -> None: ...

class EncodeResponse(_message.Message):
__slots__ = ["embedding", "success"]
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
SUCCESS_FIELD_NUMBER: _ClassVar[int]
embedding: bytes
success: bool

def __init__(
self, embedding: _Optional[bytes] = ..., success: bool = ...
) -> None: ...

class RunRequest(_message.Message):
__slots__ = ["candidate", "kwargs", "message"]
__slots__ = ("message", "candidate", "kwargs")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
CANDIDATE_FIELD_NUMBER: _ClassVar[int]
KWARGS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
candidate: str
kwargs: str
message: str

def __init__(
self,
message: _Optional[str] = ...,
Expand All @@ -49,18 +20,17 @@ class RunRequest(_message.Message):
) -> None: ...

class RunResponse(_message.Message):
__slots__ = ["embedding", "logit", "message", "ranked", "success"]
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
LOGIT_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("message", "ranked", "success", "logit", "embedding")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
RANKED_FIELD_NUMBER: _ClassVar[int]
SUCCESS_FIELD_NUMBER: _ClassVar[int]
embedding: bytes
logit: str
LOGIT_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
message: str
ranked: bool
success: bool

logit: str
embedding: bytes
def __init__(
self,
message: _Optional[str] = ...,
Expand All @@ -69,3 +39,28 @@ class RunResponse(_message.Message):
logit: _Optional[str] = ...,
embedding: _Optional[bytes] = ...,
) -> None: ...

class EncodeRequest(_message.Message):
__slots__ = ("message", "reduction", "kwargs")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
REDUCTION_FIELD_NUMBER: _ClassVar[int]
KWARGS_FIELD_NUMBER: _ClassVar[int]
message: str
reduction: str
kwargs: str
def __init__(
self,
message: _Optional[str] = ...,
reduction: _Optional[str] = ...,
kwargs: _Optional[str] = ...,
) -> None: ...

class EncodeResponse(_message.Message):
__slots__ = ("embedding", "success")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
SUCCESS_FIELD_NUMBER: _ClassVar[int]
embedding: bytes
success: bool
def __init__(
self, embedding: _Optional[bytes] = ..., success: bool = ...
) -> None: ...
8 changes: 1 addition & 7 deletions alfred/fm/remote/protos/query_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

try:
import alfred.fm.remote.protos.query_pb2 as query__pb2
except ImportError:
try:
import query_pb2 as query__pb2
except ModuleNotFoundError:
from . import query_pb2 as query__pb2
from . import query_pb2 as query__pb2


class QueryServiceStub(object):
Expand Down
2 changes: 1 addition & 1 deletion docs/alfred/fm/huggingface.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Huggingface

## HuggingFaceModel

[Show source in huggingface.py:43](../../../alfred/fm/huggingface.py#L43)
[Show source in huggingface.py:44](../../../alfred/fm/huggingface.py#L44)

The HuggingFaceModel class is a wrapper for HuggingFace models,
including both Seq2Seq (Encoder-Decoder, e.g. T5, T0) and Causal
Expand Down
16 changes: 8 additions & 8 deletions docs/alfred/fm/remote/protos/query_pb2_grpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Query Pb2 Grpc

## QueryService

[Show source in query_pb2_grpc.py:71](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L71)
[Show source in query_pb2_grpc.py:65](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L65)

Missing associated documentation comment in .proto file.

Expand All @@ -34,7 +34,7 @@ class QueryService(object):

### QueryService.Encode

[Show source in query_pb2_grpc.py:74](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L74)
[Show source in query_pb2_grpc.py:68](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L68)

#### Signature

Expand All @@ -57,7 +57,7 @@ def Encode(

### QueryService.Run

[Show source in query_pb2_grpc.py:103](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L103)
[Show source in query_pb2_grpc.py:97](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L97)

#### Signature

Expand All @@ -82,7 +82,7 @@ def Run(

## QueryServiceServicer

[Show source in query_pb2_grpc.py:35](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L35)
[Show source in query_pb2_grpc.py:29](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L29)

Missing associated documentation comment in .proto file.

Expand All @@ -95,7 +95,7 @@ class QueryServiceServicer(object):

### QueryServiceServicer().Encode

[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38)
[Show source in query_pb2_grpc.py:32](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L32)

Missing associated documentation comment in .proto file.

Expand All @@ -108,7 +108,7 @@ def Encode(self, request_iterator, context):

### QueryServiceServicer().Run

[Show source in query_pb2_grpc.py:44](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L44)
[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38)

Missing associated documentation comment in .proto file.

Expand All @@ -123,7 +123,7 @@ def Run(self, request_iterator, context):

## QueryServiceStub

[Show source in query_pb2_grpc.py:14](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L14)
[Show source in query_pb2_grpc.py:8](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L8)

Missing associated documentation comment in .proto file.

Expand All @@ -139,7 +139,7 @@ class QueryServiceStub(object):

## add_QueryServiceServicer_to_server

[Show source in query_pb2_grpc.py:51](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L51)
[Show source in query_pb2_grpc.py:45](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L45)

#### Signature

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
datasets~=2.4.0
paramiko>=2.7.2
pyarrow>=3.0.0
torch>=1.8.0

datasets
transformers
accelerate
numpy>=1.21.0
Expand All @@ -12,8 +12,8 @@ arrow>=0.13.1
setuptools>=58.0.4
scipy>=1.7.1

grpcio==1.48.1
protobuf==3.20.0
grpcio==1.62.0
protobuf==4.25.2
sentencepiece
deepspeed>=0.8.2

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
author_email='peilin_yu@brown.edu',
description='Toolkit for Prompted Weak Supervisions',
packages=find_packages(),
install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio==1.48.1", "protobuf==3.20.0"],
install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "accelerate", "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio", "protobuf<5"],
)
2 changes: 2 additions & 0 deletions test/unit_test/test_grpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
import time
import unittest

import grpc
Expand Down Expand Up @@ -36,6 +37,7 @@ def server_starter(server, client, port):

self.channel = grpc.insecure_channel(f"localhost:{self.port}")
self.stub = query_pb2_grpc.QueryServiceStub(self.channel)
time.sleep(2)

def test_run_single_query(self):
# Test running a single query using the client
Expand Down

0 comments on commit 3c2f3f0

Please sign in to comment.