diff --git a/examples/vis_chatbot_gradio.py b/examples/vis_chatbot_gradio.py
index e10241abb..f01083448 100644
--- a/examples/vis_chatbot_gradio.py
+++ b/examples/vis_chatbot_gradio.py
@@ -13,7 +13,6 @@
 import numpy as np
 import os
 import sys
-sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
 import torch
 import warnings
 import gradio as gr
@@ -31,6 +30,7 @@
 logging.disable(logging.ERROR)
 warnings.filterwarnings("ignore")
+torch.multiprocessing.set_start_method('spawn', force=True)
 title = """
 LMFlow-CHAT
@@ -114,52 +114,6 @@ class ChatbotArguments:
         }
     )
 
-pipeline_name = "inferencer"
-PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
-
-parser = HfArgumentParser((
-    VisModelArguments,
-    PipelineArguments,
-    ChatbotArguments,
-))
-model_args, pipeline_args, chatbot_args = (
-    parser.parse_args_into_dataclasses()
-)
-
-with open(pipeline_args.deepspeed, "r") as f:
-    ds_config = json.load(f)
-
-model = AutoModel.get_model(
-    model_args,
-    tune_strategy='none',
-    ds_config=ds_config,
-    device=pipeline_args.device,
-    custom_model=model_args.custom_model,
-)
-
-data_args = DatasetArguments(dataset_path=None)
-dataset = Dataset(data_args, backend="dict")
-
-inferencer = AutoPipeline.get_pipeline(
-    pipeline_name=pipeline_name,
-    model_args=model_args,
-    data_args=data_args,
-    pipeline_args=pipeline_args,
-)
-
-# Chats
-model_name = model_args.model_name_or_path
-if model_args.lora_model_path is not None:
-    model_name += f" + {model_args.lora_model_path}"
-
-
-end_string = chatbot_args.end_string
-prompt_structure = chatbot_args.prompt_structure
-
-title = """
-Demo of Multi-modality chatbot from LMFlow
-"""
-description = """
-This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!
-"""
-# article = """
-# """
 
 
 def gradio_reset(chat_state, img_list):
     if chat_state is not None:
@@ -233,37 +187,42 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
     chatbot[-1][1] = ''
     print_index = 0
-    token_per_step = 20 # 48
-    max_new_tokens = 1024
+    token_per_step = 4 # 48
+    max_new_tokens = -1
     temperature = 0.7
-
-    for response, flag_break in inferencer.stream_inference(
-        context=chatbot,
-        model=model,
-        max_new_tokens=max_new_tokens,
-        token_per_step=token_per_step,
-        temperature=temperature,
-        end_string=end_string,
-        input_dataset=input_dataset,
-        remove_image_flag=remove_image_flag,
-    ):
-        # Prints characters in the buffer
-        new_print_index = print_index
-        for char in response[print_index:]:
-            if end_string is not None and char == end_string[0]:
-                if new_print_index + len(end_string) >= len(response):
-                    break
-
-            new_print_index += 1
-            chatbot[-1][1] += char
-            chat_state += char
-            yield chatbot, chat_state, image_list
-            # await asyncio.sleep(1)
-
-        print_index = new_print_index
-
-        if flag_break:
-            break
+    context = chatbot
+
+    request_queue.put((
+        context,
+        max_new_tokens,
+        token_per_step,
+        temperature,
+        end_string,
+        input_dataset,
+        remove_image_flag
+    ))
+
+    while True:
+        if not response_queue.empty():
+            response, flag_break = response_queue.get()
+
+            # Prints characters in the buffer
+            new_print_index = print_index
+            for char in response[print_index:]:
+                if end_string is not None and char == end_string[0]:
+                    if new_print_index + len(end_string) >= len(response):
+                        break
+
+                new_print_index += 1
+                chatbot[-1][1] += char
+                chat_state += char
+                time.sleep(0.03)
+                yield chatbot, chat_state, image_list
+
+            print_index = new_print_index
+
+            if flag_break:
+                break
 
     char = "\n"
     chatbot[-1][1] += char
@@ -271,54 +230,162 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
     yield chatbot, chat_state, image_list
 
 
-with gr.Blocks() as demo:
-    gr.Markdown(title)
-    gr.Markdown(description)
-    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
-
-    with gr.Row():
-        chat_state = gr.State()
-        image_list = gr.State()
-
-        with gr.Column(scale=0.1, min_width=0):
-            clear = gr.Button("Restart")
-
-        with gr.Column(scale=0.8):
-            text_input = gr.Textbox(
-                show_label=False,
-                placeholder="Enter text and press enter, or upload an image",
-            ).style(container=False)
-
-        with gr.Column(scale=0.1, min_width=0):
-            upload_button = gr.UploadButton("📁", file_types=["image"])
-
-    txt_msg = text_input.submit(
-        fn=gradio_ask,
-        inputs=[text_input, chatbot, chat_state],
-        outputs=[text_input, chatbot, chat_state],
-        queue=False,
-    ).then(
-        fn=gradio_answer,
-        inputs=[chatbot, chat_state, image_list],
-        outputs=[chatbot, chat_state, image_list],
-    )
-    txt_msg.then(
-        lambda: gr.update(interactive=True), None, [text_input], queue=False
+def start_inferencer(
+    request_queue,
+    response_queue,
+    model_args,
+    pipeline_name,
+    pipeline_args,
+    data_args,
+    dataset,
+    chatbot_args,
+):
+    with open(pipeline_args.deepspeed, "r") as f:
+        ds_config = json.load(f)
+
+    model = AutoModel.get_model(
+        model_args,
+        tune_strategy='none',
+        ds_config=ds_config,
+        device=pipeline_args.device,
+        custom_model=model_args.custom_model,
     )
 
-    file_msg = upload_button.upload(
-        fn=upload_image,
-        inputs=[upload_button, chatbot, text_input, chat_state, image_list],
-        outputs=[text_input, chatbot, chat_state, image_list],
-        queue=False,
+    inferencer = AutoPipeline.get_pipeline(
+        pipeline_name=pipeline_name,
+        model_args=model_args,
+        data_args=data_args,
+        pipeline_args=pipeline_args,
     )
 
-    clear.click(
-        fn=gradio_reset,
-        inputs=[chat_state, image_list],
-        outputs=[chatbot, text_input, upload_button, chat_state, image_list],
-        queue=False,
-    )
+    while True:
+        if not request_queue.empty():
+            request = request_queue.get()
+
+            context = request[0]
+            max_new_tokens = request[1]
+            token_per_step = request[2]
+            temperature = request[3]
+            end_string = request[4]
+            input_dataset = request[5]
+            remove_image_flag = request[6]
+
+            break_in_the_middle = False
+            for response_text, flag_break in inferencer.stream_inference(
+                context=context,
+                model=model,
+                max_new_tokens=max_new_tokens,
+                token_per_step=token_per_step,
+                temperature=temperature,
+                end_string=end_string,
+                input_dataset=input_dataset,
+                remove_image_flag=remove_image_flag,
+            ):
+                response_queue.put((response_text, flag_break))
+                if flag_break:
+                    break_in_the_middle = True
+                    break
+            if not break_in_the_middle:
+                response_text = ''
+                flag_break = True
+                response_queue.put((response_text, flag_break))
+
+        time.sleep(0.001)
+
+
+if __name__ == "__main__":
+    pipeline_name = "inferencer"
+    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+
+    parser = HfArgumentParser((
+        VisModelArguments,
+        PipelineArguments,
+        ChatbotArguments,
+    ))
+    model_args, pipeline_args, chatbot_args = (
+        parser.parse_args_into_dataclasses()
+    )
+    data_args = DatasetArguments(dataset_path=None)
+    dataset = Dataset(data_args, backend="dict")
+
+    request_queue = torch.multiprocessing.Queue()
+    response_queue = torch.multiprocessing.Queue()
+    inferencer_process = torch.multiprocessing.Process(
+        target=start_inferencer,
+        args=(
+            request_queue,
+            response_queue,
+            model_args,
+            pipeline_name,
+            pipeline_args,
+            data_args,
+            dataset,
+            chatbot_args,
+        ),
+    )
+    inferencer_process.start()
+
+    # Chats
+    model_name = model_args.model_name_or_path
+    if model_args.lora_model_path is not None:
+        model_name += f" + {model_args.lora_model_path}"
+
+    end_string = chatbot_args.end_string
+    prompt_structure = chatbot_args.prompt_structure
+
+    title = """
+    Demo of Multi-modality chatbot from LMFlow
+    """
+    description = """
+    This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!
+    """
+
+    with gr.Blocks() as demo:
+        gr.Markdown(title)
+        gr.Markdown(description)
+        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
+
+        with gr.Row():
+            chat_state = gr.State()
+            image_list = gr.State()
+
+            with gr.Column(scale=0.1, min_width=0):
+                clear = gr.Button("Restart")
+
+            with gr.Column(scale=0.8):
+                text_input = gr.Textbox(
+                    show_label=False,
+                    placeholder="Enter text and press enter, or upload an image",
+                ).style(container=False)
+
+            with gr.Column(scale=0.1, min_width=0):
+                upload_button = gr.UploadButton("📁", file_types=["image"])
+
+        txt_msg = text_input.submit(
+            fn=gradio_ask,
+            inputs=[text_input, chatbot, chat_state],
+            outputs=[text_input, chatbot, chat_state],
+            queue=False,
+        ).then(
+            fn=gradio_answer,
+            inputs=[chatbot, chat_state, image_list],
+            outputs=[chatbot, chat_state, image_list],
+        )
+        txt_msg.then(
+            lambda: gr.update(interactive=True), None, [text_input], queue=False
+        )
+
+        file_msg = upload_button.upload(
+            fn=upload_image,
+            inputs=[upload_button, chatbot, text_input, chat_state, image_list],
+            outputs=[text_input, chatbot, chat_state, image_list],
+            queue=False,
+        )
+
+        clear.click(
+            fn=gradio_reset,
+            inputs=[chat_state, image_list],
+            outputs=[chatbot, text_input, upload_button, chat_state, image_list],
+            queue=False,
+        )
+
+    demo.launch(share=True, enable_queue=True)
+    inferencer_process.join()
 
-demo.launch(share=True, enable_queue=True)
diff --git a/src/lmflow/models/hf_encoder_decoder_model.py b/src/lmflow/models/hf_encoder_decoder_model.py
index 40ac3e6b0..e63ed2d43 100644
--- a/src/lmflow/models/hf_encoder_decoder_model.py
+++ b/src/lmflow/models/hf_encoder_decoder_model.py
@@ -270,11 +270,7 @@ def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> Union[List[i
         if isinstance(input, dict):
             # TODO refactor the input type to make it elegant.
             kwargs.update(input)
-            import time
-            start = time.time()
             tokens = self.tokenizer(*args, **kwargs)
-            end = time.time()
-            print('#######', end - start)
             return tokens
         elif isinstance(input, list):
             return self.tokenizer(text=input, *args, **kwargs)#batch encode,will automatically do left padding
diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py
index d034fbeb0..02e230b41 100644
--- a/src/lmflow/pipeline/inferencer.py
+++ b/src/lmflow/pipeline/inferencer.py
@@ -223,11 +223,19 @@ def inference(
             # print(f"{current_time}: inferencer.inference: model.inference end", flush=True)
             # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)
 
-            text_out = model.decode(outputs[0], skip_special_tokens=True)
             # only return the generation, trucating the input
             if self.model_args.arch_type != "vision_encoder_decoder":
+                text_out = model.decode(outputs[0], skip_special_tokens=True)
                 prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,))
                 text_out = text_out[prompt_length:]
+            else:
+                # to avoid redundant/missing leading space problem, we use a
+                # part of the input text
+                input_text = inputs['input_ids'][0][-1:]
+                text_out = model.decode(torch.cat([input_text, outputs[0]]), skip_special_tokens=True)
+                prompt_length = len(model.decode(input_text, skip_special_tokens=True,))
+                text_out = text_out[prompt_length:]
+
             output_dict["instances"].append({ "text": text_out })
             # current_time = time.strftime("%H:%M:%S", time.localtime())