diff --git a/examples/vis_chatbot_gradio.py b/examples/vis_chatbot_gradio.py
index e10241abb..f01083448 100644
--- a/examples/vis_chatbot_gradio.py
+++ b/examples/vis_chatbot_gradio.py
@@ -13,7 +13,6 @@
import numpy as np
import os
import sys
-sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import torch
import warnings
import gradio as gr
@@ -31,6 +30,7 @@
logging.disable(logging.ERROR)
warnings.filterwarnings("ignore")
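+# CUDA cannot be re-initialized in a forked subprocess, so the model-worker
+# process started later in this script must use the 'spawn' start method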
+torch.multiprocessing.set_start_method('spawn', force=True)
title = """
LMFlow-CHAT
@@ -114,52 +114,6 @@ class ChatbotArguments:
}
)
-pipeline_name = "inferencer"
-PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
-
-parser = HfArgumentParser((
- VisModelArguments,
- PipelineArguments,
- ChatbotArguments,
-))
-model_args, pipeline_args, chatbot_args = (
- parser.parse_args_into_dataclasses()
-)
-
-with open(pipeline_args.deepspeed, "r") as f:
- ds_config = json.load(f)
-
-model = AutoModel.get_model(
- model_args,
- tune_strategy='none',
- ds_config=ds_config,
- device=pipeline_args.device,
- custom_model=model_args.custom_model,
-)
-
-data_args = DatasetArguments(dataset_path=None)
-dataset = Dataset(data_args, backend="dict")
-
-inferencer = AutoPipeline.get_pipeline(
- pipeline_name=pipeline_name,
- model_args=model_args,
- data_args=data_args,
- pipeline_args=pipeline_args,
-)
-
-# Chats
-model_name = model_args.model_name_or_path
-if model_args.lora_model_path is not None:
- model_name += f" + {model_args.lora_model_path}"
-
-
-end_string = chatbot_args.end_string
-prompt_structure = chatbot_args.prompt_structure
-
-title = """Demo of Multi-modality chatbot from LMFlow
"""
-description = """This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!
"""
-# article = """
-# """
def gradio_reset(chat_state, img_list):
if chat_state is not None:
@@ -233,37 +187,42 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
chatbot[-1][1] = ''
print_index = 0
- token_per_step = 20 # 48
- max_new_tokens = 1024
+ token_per_step = 4 # 48
+ max_new_tokens = -1
temperature = 0.7
-
- for response, flag_break in inferencer.stream_inference(
- context=chatbot,
- model=model,
- max_new_tokens=max_new_tokens,
- token_per_step=token_per_step,
- temperature=temperature,
- end_string=end_string,
- input_dataset=input_dataset,
- remove_image_flag=remove_image_flag,
- ):
- # Prints characters in the buffer
- new_print_index = print_index
- for char in response[print_index:]:
- if end_string is not None and char == end_string[0]:
- if new_print_index + len(end_string) >= len(response):
- break
-
- new_print_index += 1
- chatbot[-1][1] += char
- chat_state += char
- yield chatbot, chat_state, image_list
- # await asyncio.sleep(1)
-
- print_index = new_print_index
-
- if flag_break:
- break
+ context = chatbot
+
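+    # Hand the generation request to the model-worker process; partial
+    # responses stream back through response_queue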
+ request_queue.put((
+ context,
+ max_new_tokens,
+ token_per_step,
+ temperature,
+ end_string,
+ input_dataset,
+ remove_image_flag
+ ))
+
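+    # Poll for streamed partial responses; each item is a
+    # (response_text, flag_break) tuple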
+ while True:
+ if not response_queue.empty():
+ response, flag_break = response_queue.get()
+
+ # Prints characters in the buffer
+ new_print_index = print_index
+ for char in response[print_index:]:
+ if end_string is not None and char == end_string[0]:
+ if new_print_index + len(end_string) >= len(response):
+ break
+
+ new_print_index += 1
+ chatbot[-1][1] += char
+ chat_state += char
+ time.sleep(0.03)
+ yield chatbot, chat_state, image_list
+
+ print_index = new_print_index
+
+ if flag_break:
+ break
char = "\n"
chatbot[-1][1] += char
@@ -271,54 +230,162 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
yield chatbot, chat_state, image_list
-with gr.Blocks() as demo:
- gr.Markdown(title)
- gr.Markdown(description)
- chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
-
- with gr.Row():
- chat_state = gr.State()
- image_list = gr.State()
-
- with gr.Column(scale=0.1, min_width=0):
- clear = gr.Button("Restart")
-
- with gr.Column(scale=0.8):
- text_input = gr.Textbox(
- show_label=False,
- placeholder="Enter text and press enter, or upload an image",
- ).style(container=False)
-
- with gr.Column(scale=0.1, min_width=0):
- upload_button = gr.UploadButton("📁", file_types=["image"])
-
- txt_msg = text_input.submit(
- fn=gradio_ask,
- inputs=[text_input, chatbot, chat_state],
- outputs=[text_input, chatbot, chat_state],
- queue=False,
- ).then(
- fn=gradio_answer,
- inputs=[chatbot, chat_state, image_list],
- outputs=[chatbot, chat_state, image_list],
- )
- txt_msg.then(
- lambda: gr.update(interactive=True), None, [text_input], queue=False
+def start_inferencer(
+ request_queue,
+ response_queue,
+ model_args,
+ pipeline_name,
+ pipeline_args,
+ data_args,
+ dataset,
+ chatbot_args,
+):
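+    """Model-worker process: build the model and inferencer once, then loop
+    serving generation requests from request_queue, streaming partial
+    (response_text, flag_break) results back through response_queue.
+    """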
+ with open(pipeline_args.deepspeed, "r") as f:
+ ds_config = json.load(f)
+
+ model = AutoModel.get_model(
+ model_args,
+ tune_strategy='none',
+ ds_config=ds_config,
+ device=pipeline_args.device,
+ custom_model=model_args.custom_model,
)
- file_msg = upload_button.upload(
- fn=upload_image,
- inputs=[upload_button, chatbot, text_input, chat_state, image_list],
- outputs=[text_input, chatbot, chat_state, image_list],
- queue=False,
+ inferencer = AutoPipeline.get_pipeline(
+ pipeline_name=pipeline_name,
+ model_args=model_args,
+ data_args=data_args,
+ pipeline_args=pipeline_args,
)
- clear.click(
- fn=gradio_reset,
- inputs=[chat_state, image_list],
- outputs=[chatbot, text_input, upload_button, chat_state, image_list],
- queue=False,
- )
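+    # Serve requests indefinitely; each request carries the context and
+    # generation parameters for one chat turn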
+ while True:
+ if not request_queue.empty():
+ request = request_queue.get()
+
+            (context, max_new_tokens, token_per_step, temperature,
+             end_string, input_dataset, remove_image_flag) = request
+
+ break_in_the_middle = False
+ for response_text, flag_break in inferencer.stream_inference(
+ context=context,
+ model=model,
+ max_new_tokens=max_new_tokens,
+ token_per_step=token_per_step,
+ temperature=temperature,
+ end_string=end_string,
+ input_dataset=input_dataset,
+ remove_image_flag=remove_image_flag,
+ ):
+ response_queue.put((response_text, flag_break))
+ if flag_break:
+ break_in_the_middle = True
+ break
+ if not break_in_the_middle:
+ response_text = ''
+ flag_break = True
+ response_queue.put((response_text, flag_break))
+
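+        # Brief sleep so the idle polling loop does not busy-wait at full speed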
+ time.sleep(0.001)
+
+
+if __name__ == "__main__":
+ pipeline_name = "inferencer"
+ PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+
+ parser = HfArgumentParser((
+ VisModelArguments,
+ PipelineArguments,
+ ChatbotArguments,
+ ))
+ model_args, pipeline_args, chatbot_args = (
+ parser.parse_args_into_dataclasses()
+ )
+ data_args = DatasetArguments(dataset_path=None)
+ dataset = Dataset(data_args, backend="dict")
+
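+    # The Gradio UI and the model run in separate processes and communicate
+    # through these two queues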
+ request_queue = torch.multiprocessing.Queue()
+ response_queue = torch.multiprocessing.Queue()
+ inferencer_process = torch.multiprocessing.Process(
+ target=start_inferencer,
+ args=(
+ request_queue,
+ response_queue,
+ model_args,
+ pipeline_name,
+ pipeline_args,
+ data_args,
+ dataset,
+ chatbot_args,
+ ),
+ )
+ inferencer_process.start()
+
+ # Chats
+ model_name = model_args.model_name_or_path
+ if model_args.lora_model_path is not None:
+ model_name += f" + {model_args.lora_model_path}"
+
+ end_string = chatbot_args.end_string
+ prompt_structure = chatbot_args.prompt_structure
+
+ title = """Demo of Multi-modality chatbot from LMFlow
"""
+ description = """This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!
"""
+
+ with gr.Blocks() as demo:
+ gr.Markdown(title)
+ gr.Markdown(description)
+ chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
+
+ with gr.Row():
+ chat_state = gr.State()
+ image_list = gr.State()
+
+ with gr.Column(scale=0.1, min_width=0):
+ clear = gr.Button("Restart")
+
+ with gr.Column(scale=0.8):
+ text_input = gr.Textbox(
+ show_label=False,
+ placeholder="Enter text and press enter, or upload an image",
+ ).style(container=False)
+
+ with gr.Column(scale=0.1, min_width=0):
+ upload_button = gr.UploadButton("📁", file_types=["image"])
+
+ txt_msg = text_input.submit(
+ fn=gradio_ask,
+ inputs=[text_input, chatbot, chat_state],
+ outputs=[text_input, chatbot, chat_state],
+ queue=False,
+ ).then(
+ fn=gradio_answer,
+ inputs=[chatbot, chat_state, image_list],
+ outputs=[chatbot, chat_state, image_list],
+ )
+ txt_msg.then(
+ lambda: gr.update(interactive=True), None, [text_input], queue=False
+ )
+
+ file_msg = upload_button.upload(
+ fn=upload_image,
+ inputs=[upload_button, chatbot, text_input, chat_state, image_list],
+ outputs=[text_input, chatbot, chat_state, image_list],
+ queue=False,
+ )
+
+ clear.click(
+ fn=gradio_reset,
+ inputs=[chat_state, image_list],
+ outputs=[chatbot, text_input, upload_button, chat_state, image_list],
+ queue=False,
+ )
+
+ demo.launch(share=True, enable_queue=True)
+ inferencer_process.join()
-demo.launch(share=True, enable_queue=True)
diff --git a/src/lmflow/models/hf_encoder_decoder_model.py b/src/lmflow/models/hf_encoder_decoder_model.py
index 40ac3e6b0..e63ed2d43 100644
--- a/src/lmflow/models/hf_encoder_decoder_model.py
+++ b/src/lmflow/models/hf_encoder_decoder_model.py
@@ -270,11 +270,7 @@ def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> Union[List[i
if isinstance(input, dict):
# TODO refactor the input type to make it elegant.
kwargs.update(input)
- import time
- start = time.time()
tokens = self.tokenizer(*args, **kwargs)
- end = time.time()
- print('#######', end - start)
return tokens
elif isinstance(input, list):
 return self.tokenizer(text=input, *args, **kwargs)  # batch encode; will automatically do left padding
diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py
index d034fbeb0..02e230b41 100644
--- a/src/lmflow/pipeline/inferencer.py
+++ b/src/lmflow/pipeline/inferencer.py
@@ -223,11 +223,19 @@ def inference(
# print(f"{current_time}: inferencer.inference: model.inference end", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)
- text_out = model.decode(outputs[0], skip_special_tokens=True)
 # only return the generation, truncating the input
if self.model_args.arch_type != "vision_encoder_decoder":
+ text_out = model.decode(outputs[0], skip_special_tokens=True)
prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,))
text_out = text_out[prompt_length:]
+ else:
+            # To avoid a redundant or missing leading space, decode the last
+            # input token together with the output, then strip the decoded
+            # token from the front of the result.
+ input_text = inputs['input_ids'][0][-1:]
+ text_out = model.decode(torch.cat([input_text, outputs[0]]), skip_special_tokens=True)
+ prompt_length = len(model.decode(input_text, skip_special_tokens=True,))
+ text_out = text_out[prompt_length:]
+
output_dict["instances"].append({ "text": text_out })
# current_time = time.strftime("%H:%M:%S", time.localtime())