Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compatibility for protobuf 4.x #62

Merged
merged 6 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion alfred/fm/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

logger = logging.getLogger(__name__)

from transformers import LlamaPreTrainedModel, MistralPreTrainedModel
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.models.mistral.modeling_mistral import MistralPreTrainedModel

dtype_match = {
"auto": "auto",
Expand Down
10 changes: 5 additions & 5 deletions alfred/fm/remote/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from PIL import Image
from tqdm.auto import tqdm

from alfred.fm.query import Query, RankedQuery, CompletionQuery
from alfred.fm.remote.protos import query_pb2
from alfred.fm.remote.protos import query_pb2_grpc
from alfred.fm.remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder
from alfred.fm.response import RankedResponse, CompletionResponse
from ..query import Query, RankedQuery, CompletionQuery
from ..remote.protos import query_pb2
from ..remote.protos import query_pb2_grpc
from ..remote.utils import get_ip, tensor_to_bytes, bytes_to_tensor, port_finder
from ..response import RankedResponse, CompletionResponse

logger = logging.getLogger(__name__)

Expand Down
28 changes: 15 additions & 13 deletions alfred/fm/remote/protos/query_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 34 additions & 39 deletions alfred/fm/remote/protos/query_pb2.pyi
Original file line number Diff line number Diff line change
@@ -1,46 +1,17 @@
from typing import ClassVar as _ClassVar, Optional as _Optional

from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class EncodeRequest(_message.Message):
__slots__ = ["kwargs", "message", "reduction"]
KWARGS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
REDUCTION_FIELD_NUMBER: _ClassVar[int]
kwargs: str
message: str
reduction: str

def __init__(
self,
message: _Optional[str] = ...,
reduction: _Optional[str] = ...,
kwargs: _Optional[str] = ...,
) -> None: ...

class EncodeResponse(_message.Message):
__slots__ = ["embedding", "success"]
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
SUCCESS_FIELD_NUMBER: _ClassVar[int]
embedding: bytes
success: bool

def __init__(
self, embedding: _Optional[bytes] = ..., success: bool = ...
) -> None: ...

class RunRequest(_message.Message):
__slots__ = ["candidate", "kwargs", "message"]
__slots__ = ("message", "candidate", "kwargs")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
CANDIDATE_FIELD_NUMBER: _ClassVar[int]
KWARGS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
candidate: str
kwargs: str
message: str

def __init__(
self,
message: _Optional[str] = ...,
Expand All @@ -49,18 +20,17 @@ class RunRequest(_message.Message):
) -> None: ...

class RunResponse(_message.Message):
__slots__ = ["embedding", "logit", "message", "ranked", "success"]
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
LOGIT_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("message", "ranked", "success", "logit", "embedding")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
RANKED_FIELD_NUMBER: _ClassVar[int]
SUCCESS_FIELD_NUMBER: _ClassVar[int]
embedding: bytes
logit: str
LOGIT_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
message: str
ranked: bool
success: bool

logit: str
embedding: bytes
def __init__(
self,
message: _Optional[str] = ...,
Expand All @@ -69,3 +39,28 @@ class RunResponse(_message.Message):
logit: _Optional[str] = ...,
embedding: _Optional[bytes] = ...,
) -> None: ...

class EncodeRequest(_message.Message):
__slots__ = ("message", "reduction", "kwargs")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
REDUCTION_FIELD_NUMBER: _ClassVar[int]
KWARGS_FIELD_NUMBER: _ClassVar[int]
message: str
reduction: str
kwargs: str
def __init__(
self,
message: _Optional[str] = ...,
reduction: _Optional[str] = ...,
kwargs: _Optional[str] = ...,
) -> None: ...

class EncodeResponse(_message.Message):
__slots__ = ("embedding", "success")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
SUCCESS_FIELD_NUMBER: _ClassVar[int]
embedding: bytes
success: bool
def __init__(
self, embedding: _Optional[bytes] = ..., success: bool = ...
) -> None: ...
8 changes: 1 addition & 7 deletions alfred/fm/remote/protos/query_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

try:
import alfred.fm.remote.protos.query_pb2 as query__pb2
except ImportError:
try:
import query_pb2 as query__pb2
except ModuleNotFoundError:
from . import query_pb2 as query__pb2
from . import query_pb2 as query__pb2


class QueryServiceStub(object):
Expand Down
2 changes: 1 addition & 1 deletion docs/alfred/fm/huggingface.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Huggingface

## HuggingFaceModel

[Show source in huggingface.py:43](../../../alfred/fm/huggingface.py#L43)
[Show source in huggingface.py:44](../../../alfred/fm/huggingface.py#L44)

The HuggingFaceModel class is a wrapper for HuggingFace models,
including both Seq2Seq (Encoder-Decoder, e.g. T5, T0) and Causal
Expand Down
16 changes: 8 additions & 8 deletions docs/alfred/fm/remote/protos/query_pb2_grpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Query Pb2 Grpc

## QueryService

[Show source in query_pb2_grpc.py:71](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L71)
[Show source in query_pb2_grpc.py:65](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L65)

Missing associated documentation comment in .proto file.

Expand All @@ -34,7 +34,7 @@ class QueryService(object):

### QueryService.Encode

[Show source in query_pb2_grpc.py:74](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L74)
[Show source in query_pb2_grpc.py:68](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L68)

#### Signature

Expand All @@ -57,7 +57,7 @@ def Encode(

### QueryService.Run

[Show source in query_pb2_grpc.py:103](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L103)
[Show source in query_pb2_grpc.py:97](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L97)

#### Signature

Expand All @@ -82,7 +82,7 @@ def Run(

## QueryServiceServicer

[Show source in query_pb2_grpc.py:35](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L35)
[Show source in query_pb2_grpc.py:29](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L29)

Missing associated documentation comment in .proto file.

Expand All @@ -95,7 +95,7 @@ class QueryServiceServicer(object):

### QueryServiceServicer().Encode

[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38)
[Show source in query_pb2_grpc.py:32](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L32)

Missing associated documentation comment in .proto file.

Expand All @@ -108,7 +108,7 @@ def Encode(self, request_iterator, context):

### QueryServiceServicer().Run

[Show source in query_pb2_grpc.py:44](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L44)
[Show source in query_pb2_grpc.py:38](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L38)

Missing associated documentation comment in .proto file.

Expand All @@ -123,7 +123,7 @@ def Run(self, request_iterator, context):

## QueryServiceStub

[Show source in query_pb2_grpc.py:14](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L14)
[Show source in query_pb2_grpc.py:8](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L8)

Missing associated documentation comment in .proto file.

Expand All @@ -139,7 +139,7 @@ class QueryServiceStub(object):

## add_QueryServiceServicer_to_server

[Show source in query_pb2_grpc.py:51](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L51)
[Show source in query_pb2_grpc.py:45](../../../../../alfred/fm/remote/protos/query_pb2_grpc.py#L45)

#### Signature

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
datasets~=2.4.0
paramiko>=2.7.2
pyarrow>=3.0.0
torch>=1.8.0

datasets
transformers
accelerate
numpy>=1.21.0
Expand All @@ -12,8 +12,8 @@ arrow>=0.13.1
setuptools>=58.0.4
scipy>=1.7.1

grpcio==1.48.1
protobuf==3.20.0
grpcio==1.62.0
protobuf==4.25.2
sentencepiece
deepspeed>=0.8.2

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
author_email='peilin_yu@brown.edu',
description='Toolkit for Prompted Weak Supervisions',
packages=find_packages(),
install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio==1.48.1", "protobuf==3.20.0"],
install_requires=['numpy', 'scipy', 'torch', 'tqdm', 'torchvision', "accelerate", "paramiko>=2.7.2", "pyarrow>=3.0.0", "grpcio", "protobuf<5"],
)
2 changes: 2 additions & 0 deletions test/unit_test/test_grpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
import time
import unittest

import grpc
Expand Down Expand Up @@ -36,6 +37,7 @@ def server_starter(server, client, port):

self.channel = grpc.insecure_channel(f"localhost:{self.port}")
self.stub = query_pb2_grpc.QueryServiceStub(self.channel)
time.sleep(2)

def test_run_single_query(self):
# Test running a single query using the client
Expand Down
Loading