 from .. import __version__
 from ..client import (
-    Client,
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulClient,
@@ -354,9 +353,7 @@ def model_generate(
 ):
     endpoint = get_endpoint(endpoint)
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def generate_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -365,7 +362,7 @@ async def generate_internal():
                 if prompt == "":
                     break
                 print(f"Completion: {prompt}", end="", file=sys.stdout)
-                async for chunk in model.generate(
+                for chunk in model.generate(
                     prompt=prompt,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
                 ):
@@ -376,7 +373,7 @@ async def generate_internal():
                     print(choice["text"], end="", flush=True, file=sys.stdout)
                 print("\n", file=sys.stdout)

-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)

         loop = asyncio.get_event_loop()
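For context on what this hunk changes: `RESTfulClient.get_model()` returns a handle whose `generate()` yields chunk dicts from a regular synchronous iterator when `stream=True`, which is why the loop drops `async for` for a plain `for`. A minimal standalone sketch of that pattern (the endpoint URL and model UID below are placeholders, not values from this PR):

```python
# Minimal sketch of the streaming-generate pattern this diff adopts.
# Assumes a running xinference server; base_url and model_uid are placeholders.
from xinference.client import RESTfulClient

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model-uid")  # placeholder UID

# With stream=True the handle yields completion chunks one by one,
# as plain dicts, consumable without an event loop.
for chunk in model.generate(
    prompt="Hello",
    generate_config={"stream": True, "max_tokens": 16},
):
    choice = chunk["choices"][0]
    if "text" in choice:
        print(choice["text"], end="", flush=True)
```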
@@ -436,9 +433,7 @@ def model_chat(
     endpoint = get_endpoint(endpoint)
     chat_history: "List[ChatCompletionMessage]" = []
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def chat_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -449,7 +444,7 @@ async def chat_internal():
                 chat_history.append(ChatCompletionMessage(role="user", content=prompt))
                 print("Assistant: ", end="", file=sys.stdout)
                 response_content = ""
-                async for chunk in model.chat(
+                for chunk in model.chat(
                     prompt=prompt,
                     chat_history=chat_history,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
@@ -465,7 +460,7 @@ async def chat_internal():
                     ChatCompletionMessage(role="assistant", content=response_content)
                 )

-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)

         loop = asyncio.get_event_loop()
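The chat path follows the same shape. A sketch of the streaming-chat loop, assuming chat chunks carry OpenAI-style `delta` payloads (the endpoint and UID are again placeholders):

```python
# Sketch of the streaming-chat pattern; base_url/model_uid are placeholders,
# and the chunk["choices"][0]["delta"]["content"] layout is an assumption
# based on the OpenAI-compatible chat-completion chunk format.
from xinference.client import RESTfulClient
from xinference.types import ChatCompletionMessage

client = RESTfulClient(base_url="http://127.0.0.1:9997")
model = client.get_model(model_uid="my-model-uid")

chat_history: list = []
prompt = "Hi there"
response_content = ""
for chunk in model.chat(
    prompt=prompt,
    chat_history=chat_history,
    generate_config={"stream": True, "max_tokens": 16},
):
    delta = chunk["choices"][0].get("delta", {})
    if "content" in delta:
        response_content += delta["content"]
        print(delta["content"], end="", flush=True)

# Keep history in sync so follow-up turns see the full conversation.
chat_history.append(ChatCompletionMessage(role="user", content=prompt))
chat_history.append(ChatCompletionMessage(role="assistant", content=response_content))
```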