Commit 7a80c6c
Add multiprocessing support to speed up vis chatbot
Now the vision chatbot (Gradio) can think and talk simultaneously, using two processes in
producer-consumer fashion: one process handles thinking (model inference) and the other
handles talking (streaming the response to the UI).
research4pan committed Jul 19, 2023
1 parent 284080f commit 7a80c6c
Showing 1 changed file with 188 additions and 122 deletions.
310 changes: 188 additions & 122 deletions examples/vis_chatbot_gradio.py
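
Before the diff itself, here is a minimal, self-contained sketch of the producer-consumer pattern this commit introduces: two torch.multiprocessing processes joined by a request queue and a response queue. This is illustrative only; fake_stream_inference below is a hypothetical stand-in for the real model and inferencer.stream_inference objects that appear in the diff, and the prompt string is made up.

import time
import torch


def fake_stream_inference(prompt, token_per_step=4):
    # Hypothetical stand-in for inferencer.stream_inference: yields
    # (partial_response, flag_break) pairs, growing a few words at a time.
    words = ("this is a streamed answer to: " + prompt).split()
    text = ""
    for i in range(0, len(words), token_per_step):
        text += " ".join(words[i:i + token_per_step]) + " "
        yield text, i + token_per_step >= len(words)


def start_inferencer(request_queue, response_queue):
    # Producer ("thinking") process: pulls requests and streams partial
    # responses back. Mirrors the polling style used in the diff below.
    while True:
        if not request_queue.empty():
            prompt = request_queue.get()
            for response, flag_break in fake_stream_inference(prompt):
                response_queue.put((response, flag_break))
                if flag_break:
                    break
        time.sleep(0.001)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn", force=True)
    request_queue = torch.multiprocessing.Queue()
    response_queue = torch.multiprocessing.Queue()
    worker = torch.multiprocessing.Process(
        target=start_inferencer,
        args=(request_queue, response_queue),
        daemon=True,
    )
    worker.start()

    # Consumer ("talking") side: in the Gradio app this loop yields the
    # partial responses to the UI; here we simply print them.
    request_queue.put("what is in this image?")
    while True:
        if not response_queue.empty():
            response, flag_break = response_queue.get()
            print(response)
            if flag_break:
                break
        time.sleep(0.03)

    worker.terminate()

In the actual diff, model and pipeline construction are moved inside start_inferencer, so the heavyweight model lives only in the producer process; the 'spawn' start method is forced presumably because forked subprocesses cannot safely re-initialize CUDA.
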
@@ -14,7 +14,6 @@
import numpy as np
import os
import sys
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import torch
import warnings
import gradio as gr
@@ -32,6 +31,7 @@

logging.disable(logging.ERROR)
warnings.filterwarnings("ignore")
torch.multiprocessing.set_start_method('spawn', force=True)

title = """
<h1 align="center">LMFlow-CHAT</h1>
@@ -115,52 +115,6 @@ class ChatbotArguments:
}
)

pipeline_name = "inferencer"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

parser = HfArgumentParser((
VisModelArguments,
PipelineArguments,
ChatbotArguments,
))
model_args, pipeline_args, chatbot_args = (
parser.parse_args_into_dataclasses()
)

with open(pipeline_args.deepspeed, "r") as f:
ds_config = json.load(f)

model = AutoModel.get_model(
model_args,
tune_strategy='none',
ds_config=ds_config,
device=pipeline_args.device,
custom_model=model_args.custom_model,
)

data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(data_args, backend="dict")

inferencer = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
model_args=model_args,
data_args=data_args,
pipeline_args=pipeline_args,
)

# Chats
model_name = model_args.model_name_or_path
if model_args.lora_model_path is not None:
model_name += f" + {model_args.lora_model_path}"


end_string = chatbot_args.end_string
prompt_structure = chatbot_args.prompt_structure

title = """<h1 align="center">Demo of Multi-modality chatbot from LMFlow</h1>"""
description = """<h3>This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!</h3>"""
# article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
# """

def gradio_reset(chat_state, img_list):
if chat_state is not None:
@@ -223,7 +177,7 @@ def gradio_ask(user_message, chatbot, chat_state):
return '', chatbot, chat_state


async def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0):
def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0):
input_dataset = dataset.from_dict({
"type": "image_text",
"instances": [{"images": np.stack([np.array(i) for i in image_list]),
@@ -237,90 +191,202 @@ async def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperatur
token_per_step = 20 # 48
max_new_tokens = 1024
temperature = 0.7

for response, flag_break in inferencer.stream_inference(
context=chatbot,
model=model,
max_new_tokens=max_new_tokens,
token_per_step=token_per_step,
temperature=temperature,
end_string=end_string,
input_dataset=input_dataset,
remove_image_flag=remove_image_flag,
):
# Prints characters in the buffer
new_print_index = print_index
for char in response[print_index:]:
if end_string is not None and char == end_string[0]:
if new_print_index + len(end_string) >= len(response):
break

new_print_index += 1
chatbot[-1][1] += char
chat_state += char
await asyncio.sleep(0.1)
yield chatbot, chat_state, image_list
# await asyncio.sleep(1)

print_index = new_print_index

if flag_break:
break
context = chatbot

request_queue.put((
context,
max_new_tokens,
token_per_step,
temperature,
end_string,
input_dataset,
remove_image_flag
))

while True:
if not response_queue.empty():
response, flag_break = response_queue.get()

# Prints characters in the buffer
new_print_index = print_index
for char in response[print_index:]:
if end_string is not None and char == end_string[0]:
if new_print_index + len(end_string) >= len(response):
break

new_print_index += 1
chatbot[-1][1] += char
chat_state += char
time.sleep(0.03)
yield chatbot, chat_state, image_list

print_index = new_print_index

if flag_break:
break

char = "\n"
chatbot[-1][1] += char
chat_state += char
yield chatbot, chat_state, image_list


with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)

with gr.Row():
chat_state = gr.State()
image_list = gr.State()

with gr.Column(scale=0.1, min_width=0):
clear = gr.Button("Restart")

with gr.Column(scale=0.8):
text_input = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter, or upload an image",
).style(container=False)

with gr.Column(scale=0.1, min_width=0):
upload_button = gr.UploadButton("📁", file_types=["image"])

txt_msg = text_input.submit(
fn=gradio_ask,
inputs=[text_input, chatbot, chat_state],
outputs=[text_input, chatbot, chat_state],
queue=False,
).then(
fn=gradio_answer,
inputs=[chatbot, chat_state, image_list],
outputs=[chatbot, chat_state, image_list],
)
txt_msg.then(
lambda: gr.update(interactive=True), None, [text_input], queue=False
def start_inferencer(
request_queue,
response_queue,
model_args,
pipeline_name,
pipeline_args,
data_args,
dataset,
chatbot_args,
):
with open(pipeline_args.deepspeed, "r") as f:
ds_config = json.load(f)

model = AutoModel.get_model(
model_args,
tune_strategy='none',
ds_config=ds_config,
device=pipeline_args.device,
custom_model=model_args.custom_model,
)

file_msg = upload_button.upload(
fn=upload_image,
inputs=[upload_button, chatbot, text_input, chat_state, image_list],
outputs=[text_input, chatbot, chat_state, image_list],
queue=False,
inferencer = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
model_args=model_args,
data_args=data_args,
pipeline_args=pipeline_args,
)

clear.click(
fn=gradio_reset,
inputs=[chat_state, image_list],
outputs=[chatbot, text_input, upload_button, chat_state, image_list],
queue=False,
)
while True:
if not request_queue.empty():
request = request_queue.get()

context = request[0]
max_new_tokens = request[1]
token_per_step = request[2]
temperature = request[3]
end_string = request[4]
input_dataset = request[5]
remove_image_flag = request[6]

break_in_the_middle = False
for response_text, flag_break in inferencer.stream_inference(
context=context,
model=model,
max_new_tokens=max_new_tokens,
token_per_step=token_per_step,
temperature=temperature,
end_string=end_string,
input_dataset=input_dataset,
remove_image_flag=remove_image_flag,
):
response_queue.put((response_text, flag_break))
if flag_break:
break_in_the_middle = True
break

if not break_in_the_middle:
response_text = ''
flag_break = True
response_queue.put((response_text, flag_break))

time.sleep(0.001)


if __name__ == "__main__":
pipeline_name = "inferencer"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

parser = HfArgumentParser((
VisModelArguments,
PipelineArguments,
ChatbotArguments,
))
model_args, pipeline_args, chatbot_args = (
parser.parse_args_into_dataclasses()
)
data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(data_args, backend="dict")

request_queue = torch.multiprocessing.Queue()
response_queue = torch.multiprocessing.Queue()
inferencer_process = torch.multiprocessing.Process(
target=start_inferencer,
args=(
request_queue,
response_queue,
model_args,
pipeline_name,
pipeline_args,
data_args,
dataset,
chatbot_args,
),
)
inferencer_process.start()

# Chats
model_name = model_args.model_name_or_path
if model_args.lora_model_path is not None:
model_name += f" + {model_args.lora_model_path}"

end_string = chatbot_args.end_string
prompt_structure = chatbot_args.prompt_structure

title = """<h1 align="center">Demo of Multi-modality chatbot from LMFlow</h1>"""
description = """<h3>This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!</h3>"""

with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)

with gr.Row():
chat_state = gr.State()
image_list = gr.State()

with gr.Column(scale=0.1, min_width=0):
clear = gr.Button("Restart")

with gr.Column(scale=0.8):
text_input = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter, or upload an image",
).style(container=False)

with gr.Column(scale=0.1, min_width=0):
upload_button = gr.UploadButton("📁", file_types=["image"])

txt_msg = text_input.submit(
fn=gradio_ask,
inputs=[text_input, chatbot, chat_state],
outputs=[text_input, chatbot, chat_state],
queue=False,
).then(
fn=gradio_answer,
inputs=[chatbot, chat_state, image_list],
outputs=[chatbot, chat_state, image_list],
)
txt_msg.then(
lambda: gr.update(interactive=True), None, [text_input], queue=False
)

file_msg = upload_button.upload(
fn=upload_image,
inputs=[upload_button, chatbot, text_input, chat_state, image_list],
outputs=[text_input, chatbot, chat_state, image_list],
queue=False,
)

clear.click(
fn=gradio_reset,
inputs=[chat_state, image_list],
outputs=[chatbot, text_input, upload_button, chat_state, image_list],
queue=False,
)

demo.launch(share=True, enable_queue=True)
inferencer_process.join()

demo.launch(share=True, enable_queue=True)
