diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py
index 6b67c67c61..dcb9dbedaa 100644
--- a/xinference/deploy/cmdline.py
+++ b/xinference/deploy/cmdline.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import configparser
 import logging
 import os
@@ -24,7 +23,6 @@
 
 from .. import __version__
 from ..client import (
-    Client,
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulClient,
@@ -36,7 +34,6 @@
     XINFERENCE_DEFAULT_LOCAL_HOST,
     XINFERENCE_ENV_ENDPOINT,
 )
-from ..isolation import Isolation
 from ..types import ChatCompletionMessage
 
 try:
@@ -353,68 +350,39 @@ def model_generate(
     stream: bool,
 ):
     endpoint = get_endpoint(endpoint)
-    if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
-        async def generate_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("Prompt: ")
-                if prompt == "":
-                    break
-                print(f"Completion: {prompt}", end="", file=sys.stdout)
-                async for chunk in model.generate(
-                    prompt=prompt,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    choice = chunk["choices"][0]
-                    if "text" not in choice:
-                        continue
-                    else:
-                        print(choice["text"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-
-        client = Client(endpoint=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = generate_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)):
+        raise ValueError(f"model {model_uid} has no generate method")
+
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("Prompt: ")
+        if prompt.lower() == "exit" or prompt.lower() == "quit":
+            break
+        print(f"Completion: {prompt}", end="", file=sys.stdout)
+
+        if stream:
+            iter = model.generate(
+                prompt=prompt,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            )
+            assert not isinstance(iter, dict)
+            for chunk in iter:
+                choice = chunk["choices"][0]
+                if "text" not in choice:
+                    continue
+                else:
+                    print(choice["text"], end="", flush=True, file=sys.stdout)
         else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no generate method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            print(f"Assistant: {prompt}", end="", file=sys.stdout)
-            response = restful_model.generate(
+            response = model.generate(
                 prompt=prompt,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
-            if not isinstance(response, dict):
-                raise ValueError("generate result is not valid")
-            print(f"{response['choices'][0]['text']}\n", file=sys.stdout)
+            assert isinstance(response, dict)
+            print(f"{response['choices'][0]['text']}", file=sys.stdout)
+        print("\n", file=sys.stdout)
 
 
 @cli.command("chat")
@@ -434,82 +402,52 @@ def model_chat(
 ):
     # TODO: chat model roles may not be user and assistant.
     endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(
+        model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+    ):
+        raise ValueError(f"model {model_uid} has no chat method")
+
     chat_history: "List[ChatCompletionMessage]" = []
-    if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
-        async def chat_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("User: ")
-                if prompt == "":
-                    break
-                chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-                print("Assistant: ", end="", file=sys.stdout)
-                response_content = ""
-                async for chunk in model.chat(
-                    prompt=prompt,
-                    chat_history=chat_history,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    delta = chunk["choices"][0]["delta"]
-                    if "content" not in delta:
-                        continue
-                    else:
-                        response_content += delta["content"]
-                        print(delta["content"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-                chat_history.append(
-                    ChatCompletionMessage(role="assistant", content=response_content)
-                )
-
-        client = Client(endpoint=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = chat_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("User: ")
+        if prompt == "":
+            break
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))
+        print("Assistant: ", end="", file=sys.stdout)
+
+        response_content = ""
+        if stream:
+            iter = model.chat(
+                prompt=prompt,
+                chat_history=chat_history,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            )
+            assert not isinstance(iter, dict)
+            for chunk in iter:
+                delta = chunk["choices"][0]["delta"]
+                if "content" not in delta:
+                    continue
+                else:
+                    response_content += delta["content"]
+                    print(delta["content"], end="", flush=True, file=sys.stdout)
        else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no chat method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-            print("Assistant: ", end="", file=sys.stdout)
-            response = restful_model.chat(
+            response = model.chat(
                 prompt=prompt,
                 chat_history=chat_history,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
-            if not isinstance(response, dict):
-                raise ValueError("chat result is not valid")
+            assert isinstance(response, dict)
             response_content = response["choices"][0]["message"]["content"]
-            print(f"{response_content}\n", file=sys.stdout)
-            chat_history.append(
-                ChatCompletionMessage(role="assistant", content=response_content)
-            )
+            print(f"{response_content}", file=sys.stdout)
+
+        chat_history.append(
+            ChatCompletionMessage(role="assistant", content=response_content)
+        )
+        print("\n", file=sys.stdout)
 
 
 if __name__ == "__main__":
diff --git a/xinference/deploy/test/test_cmdline.py b/xinference/deploy/test/test_cmdline.py
index 2d9be3be32..dbf28bed1c 100644
--- a/xinference/deploy/test/test_cmdline.py
+++ b/xinference/deploy/test/test_cmdline.py
@@ -18,7 +18,7 @@
 import pytest
 from click.testing import CliRunner
 
-from ...client import Client
+from ...client import RESTfulClient
 from ..cmdline import (
     list_model_registrations,
     model_chat,
@@ -59,7 +59,7 @@ def test_cmdline(setup, stream):
     """
     # if use `model_launch` command to launch model, CI will fail.
     # So use client to launch model in temporary
-    client = Client(endpoint)
+    client = RESTfulClient(endpoint)
     model_uid = client.launch_model(
         model_name="orca", model_size_in_billions=3, quantization="q4_0"
     )
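Note: the streaming path this patch moves onto RESTfulClient can also be exercised outside the CLI. The snippet below is a minimal sketch, assuming a running Xinference server at http://127.0.0.1:9997 and an already-launched model with UID "my-model" (both are placeholders, not values from this patch); the call shapes mirror the new model_generate branches above.

    from xinference.client import RESTfulClient

    client = RESTfulClient(base_url="http://127.0.0.1:9997")  # assumed endpoint
    model = client.get_model(model_uid="my-model")  # assumed, already-launched model UID

    # With stream=True the handle yields completion chunks one by one,
    # matching the new `if stream:` branch of model_generate.
    for chunk in model.generate(
        prompt="Hello",
        generate_config={"stream": True, "max_tokens": 32},
    ):
        choice = chunk["choices"][0]
        if "text" in choice:
            print(choice["text"], end="", flush=True)
    print()

    # With stream=False a plain dict is returned, matching the non-stream branch.
    response = model.generate(
        prompt="Hello", generate_config={"stream": False, "max_tokens": 32}
    )
    print(response["choices"][0]["text"])

The chat path follows the same pattern through model.chat(prompt=..., chat_history=..., generate_config=...), with streamed deltas under chunk["choices"][0]["delta"] as in the new model_chat code.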