
Commit 12e89a8

use restful client
1 parent 88372e0 commit 12e89a8
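
For context, the switch amounts to two things: building the HTTP-based RESTfulClient instead of the actor-based Client, and consuming the stream with a plain for loop instead of async for. A minimal sketch of the resulting generate path, assuming a running xinference server; the base URL and model UID below are placeholders, and the chunk layout is inferred from the print calls in this diff:

from xinference.client import RESTfulClient

# Placeholder endpoint and model UID for illustration only.
client = RESTfulClient(base_url="http://127.0.0.1:9997")
model = client.get_model(model_uid="my-model-uid")

# RESTfulClient yields stream chunks from a synchronous generator,
# hence `for` rather than `async for` in the updated CLI loops.
for chunk in model.generate(
    prompt="Once upon a time",
    generate_config={"stream": True, "max_tokens": 64},
):
    # Chunk layout inferred from the diff's print(choice["text"], ...) call.
    for choice in chunk["choices"]:
        print(choice["text"], end="", flush=True)
print()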


xinference/deploy/cmdline.py

Lines changed: 6 additions & 11 deletions
@@ -24,7 +24,6 @@
 
 from .. import __version__
 from ..client import (
-    Client,
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulClient,
@@ -354,9 +353,7 @@ def model_generate(
 ):
     endpoint = get_endpoint(endpoint)
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def generate_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -365,7 +362,7 @@ async def generate_internal():
                 if prompt == "":
                     break
                 print(f"Completion: {prompt}", end="", file=sys.stdout)
-                async for chunk in model.generate(
+                for chunk in model.generate(
                     prompt=prompt,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
                 ):
@@ -376,7 +373,7 @@ async def generate_internal():
                     print(choice["text"], end="", flush=True, file=sys.stdout)
                 print("\n", file=sys.stdout)
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()
@@ -436,9 +433,7 @@ def model_chat(
     endpoint = get_endpoint(endpoint)
     chat_history: "List[ChatCompletionMessage]" = []
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def chat_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -449,7 +444,7 @@ async def chat_internal():
                 chat_history.append(ChatCompletionMessage(role="user", content=prompt))
                 print("Assistant: ", end="", file=sys.stdout)
                 response_content = ""
-                async for chunk in model.chat(
+                for chunk in model.chat(
                     prompt=prompt,
                     chat_history=chat_history,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
@@ -465,7 +460,7 @@ async def chat_internal():
                     ChatCompletionMessage(role="assistant", content=response_content)
                 )
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()
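
The chat path follows the same pattern. A sketch of the new loop, again with a placeholder endpoint and model UID; plain dicts stand in for the ChatCompletionMessage typed dict used in cmdline.py, and the delta-style chunk handling is an assumption, since this diff elides how response_content is accumulated:

from xinference.client import RESTfulClient

# Placeholder endpoint and model UID for illustration only.
client = RESTfulClient(base_url="http://127.0.0.1:9997")
model = client.get_model(model_uid="my-chat-model-uid")

prompt = "Hello!"
chat_history = [{"role": "user", "content": prompt}]  # stands in for ChatCompletionMessage

response_content = ""
# Synchronous iteration, matching the `for chunk in model.chat(...)` change above.
for chunk in model.chat(
    prompt=prompt,
    chat_history=chat_history,
    generate_config={"stream": True, "max_tokens": 64},
):
    # Assumed OpenAI-style streaming chunk; the diff does not show the exact schema.
    piece = chunk["choices"][0].get("delta", {}).get("content", "")
    print(piece, end="", flush=True)
    response_content += piece
print()
chat_history.append({"role": "assistant", "content": response_content})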
