Merge branch 'rpan-vision-encoder' into lianqing/vision_encoder
research4pan committed Jul 22, 2023
2 parents 925aac5 + 3968053 commit b1475ae
Showing 3 changed files with 198 additions and 127 deletions.
311 changes: 189 additions & 122 deletions examples/vis_chatbot_gradio.py
@@ -13,7 +13,6 @@
import numpy as np
import os
import sys
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import torch
import warnings
import gradio as gr
@@ -31,6 +30,7 @@

logging.disable(logging.ERROR)
warnings.filterwarnings("ignore")
torch.multiprocessing.set_start_method('spawn', force=True)

title = """
<h1 align="center">LMFlow-CHAT</h1>
@@ -114,52 +114,6 @@ class ChatbotArguments:
}
)

pipeline_name = "inferencer"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

parser = HfArgumentParser((
VisModelArguments,
PipelineArguments,
ChatbotArguments,
))
model_args, pipeline_args, chatbot_args = (
parser.parse_args_into_dataclasses()
)

with open(pipeline_args.deepspeed, "r") as f:
ds_config = json.load(f)

model = AutoModel.get_model(
model_args,
tune_strategy='none',
ds_config=ds_config,
device=pipeline_args.device,
custom_model=model_args.custom_model,
)

data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(data_args, backend="dict")

inferencer = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
model_args=model_args,
data_args=data_args,
pipeline_args=pipeline_args,
)

# Chats
model_name = model_args.model_name_or_path
if model_args.lora_model_path is not None:
model_name += f" + {model_args.lora_model_path}"


end_string = chatbot_args.end_string
prompt_structure = chatbot_args.prompt_structure

title = """<h1 align="center">Demo of Multi-modality chatbot from LMFlow</h1>"""
description = """<h3>This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!</h3>"""
# article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
# """

def gradio_reset(chat_state, img_list):
if chat_state is not None:
@@ -233,92 +187,205 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
chatbot[-1][1] = ''

print_index = 0
token_per_step = 20 # 48
max_new_tokens = 1024
token_per_step = 4 # 48
max_new_tokens = -1
temperature = 0.7

for response, flag_break in inferencer.stream_inference(
context=chatbot,
model=model,
max_new_tokens=max_new_tokens,
token_per_step=token_per_step,
temperature=temperature,
end_string=end_string,
input_dataset=input_dataset,
remove_image_flag=remove_image_flag,
):
# Prints characters in the buffer
new_print_index = print_index
for char in response[print_index:]:
if end_string is not None and char == end_string[0]:
if new_print_index + len(end_string) >= len(response):
break

new_print_index += 1
chatbot[-1][1] += char
chat_state += char
yield chatbot, chat_state, image_list
# await asyncio.sleep(1)

print_index = new_print_index

if flag_break:
break
context = chatbot

request_queue.put((
context,
max_new_tokens,
token_per_step,
temperature,
end_string,
input_dataset,
remove_image_flag
))

while True:
if not response_queue.empty():
response, flag_break = response_queue.get()

# Prints characters in the buffer
new_print_index = print_index
for char in response[print_index:]:
if end_string is not None and char == end_string[0]:
if new_print_index + len(end_string) >= len(response):
break

new_print_index += 1
chatbot[-1][1] += char
chat_state += char
time.sleep(0.03)
yield chatbot, chat_state, image_list

print_index = new_print_index

if flag_break:
break

char = "\n"
chatbot[-1][1] += char
chat_state += char
yield chatbot, chat_state, image_list


with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)

with gr.Row():
chat_state = gr.State()
image_list = gr.State()

with gr.Column(scale=0.1, min_width=0):
clear = gr.Button("Restart")

with gr.Column(scale=0.8):
text_input = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter, or upload an image",
).style(container=False)

with gr.Column(scale=0.1, min_width=0):
upload_button = gr.UploadButton("📁", file_types=["image"])

txt_msg = text_input.submit(
fn=gradio_ask,
inputs=[text_input, chatbot, chat_state],
outputs=[text_input, chatbot, chat_state],
queue=False,
).then(
fn=gradio_answer,
inputs=[chatbot, chat_state, image_list],
outputs=[chatbot, chat_state, image_list],
)
txt_msg.then(
lambda: gr.update(interactive=True), None, [text_input], queue=False
def start_inferencer(
request_queue,
response_queue,
model_args,
pipeline_name,
pipeline_args,
data_args,
dataset,
chatbot_args,
):
with open(pipeline_args.deepspeed, "r") as f:
ds_config = json.load(f)

model = AutoModel.get_model(
model_args,
tune_strategy='none',
ds_config=ds_config,
device=pipeline_args.device,
custom_model=model_args.custom_model,
)

file_msg = upload_button.upload(
fn=upload_image,
inputs=[upload_button, chatbot, text_input, chat_state, image_list],
outputs=[text_input, chatbot, chat_state, image_list],
queue=False,
inferencer = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
model_args=model_args,
data_args=data_args,
pipeline_args=pipeline_args,
)

clear.click(
fn=gradio_reset,
inputs=[chat_state, image_list],
outputs=[chatbot, text_input, upload_button, chat_state, image_list],
queue=False,
)
while True:
if not request_queue.empty():
request = request_queue.get()

context = request[0]
max_new_tokens = request[1]
token_per_step = request[2]
temperature = request[3]
end_string = request[4]
input_dataset = request[5]
remove_image_flag = request[6]

break_in_the_middle = False
for response_text, flag_break in inferencer.stream_inference(
context=context,
model=model,
max_new_tokens=max_new_tokens,
token_per_step=token_per_step,
temperature=temperature,
end_string=end_string,
input_dataset=input_dataset,
remove_image_flag=remove_image_flag,
):
response_queue.put((response_text, flag_break))
if flag_break:
break_in_the_middle = True
break

if not break_in_the_middle:
response_text = ''
flag_break = True
response_queue.put((response_text, flag_break))

time.sleep(0.001)


if __name__ == "__main__":
pipeline_name = "inferencer"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

parser = HfArgumentParser((
VisModelArguments,
PipelineArguments,
ChatbotArguments,
))
model_args, pipeline_args, chatbot_args = (
parser.parse_args_into_dataclasses()
)
data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(data_args, backend="dict")

request_queue = torch.multiprocessing.Queue()
response_queue = torch.multiprocessing.Queue()
inferencer_process = torch.multiprocessing.Process(
target=start_inferencer,
args=(
request_queue,
response_queue,
model_args,
pipeline_name,
pipeline_args,
data_args,
dataset,
chatbot_args,
),
)
inferencer_process.start()

# Chats
model_name = model_args.model_name_or_path
if model_args.lora_model_path is not None:
model_name += f" + {model_args.lora_model_path}"

end_string = chatbot_args.end_string
prompt_structure = chatbot_args.prompt_structure

title = """<h1 align="center">Demo of Multi-modality chatbot from LMFlow</h1>"""
description = """<h3>This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!</h3>"""

with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)

with gr.Row():
chat_state = gr.State()
image_list = gr.State()

with gr.Column(scale=0.1, min_width=0):
clear = gr.Button("Restart")

with gr.Column(scale=0.8):
text_input = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter, or upload an image",
).style(container=False)

with gr.Column(scale=0.1, min_width=0):
upload_button = gr.UploadButton("📁", file_types=["image"])

txt_msg = text_input.submit(
fn=gradio_ask,
inputs=[text_input, chatbot, chat_state],
outputs=[text_input, chatbot, chat_state],
queue=False,
).then(
fn=gradio_answer,
inputs=[chatbot, chat_state, image_list],
outputs=[chatbot, chat_state, image_list],
)
txt_msg.then(
lambda: gr.update(interactive=True), None, [text_input], queue=False
)

file_msg = upload_button.upload(
fn=upload_image,
inputs=[upload_button, chatbot, text_input, chat_state, image_list],
outputs=[text_input, chatbot, chat_state, image_list],
queue=False,
)

clear.click(
fn=gradio_reset,
inputs=[chat_state, image_list],
outputs=[chatbot, text_input, upload_button, chat_state, image_list],
queue=False,
)

demo.launch(share=True, enable_queue=True)
inferencer_process.join()

demo.launch(share=True, enable_queue=True)
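
Note: the largest change above restructures examples/vis_chatbot_gradio.py so that model loading and streaming inference run in a dedicated worker process, started with the 'spawn' method and fed through a pair of torch.multiprocessing queues; the Gradio callback enqueues a request tuple and then drains a response queue character by character. A minimal, hedged sketch of that queue/worker pattern in isolation follows. The worker body, the dummy generate_stream generator, and the delta-style streaming are illustrative stand-ins (the real handler streams cumulative text and tracks a print_index); only the process and queue wiring mirrors the diff.

import time
import torch.multiprocessing as mp


def worker(request_queue, response_queue):
    """Stand-in for start_inferencer: set up once, then serve requests forever."""

    def generate_stream(prompt):
        # Illustrative generator; the real code calls inferencer.stream_inference.
        for word in prompt.split():
            yield word + " ", False
        yield "", True  # final chunk signals completion

    while True:
        if not request_queue.empty():
            prompt = request_queue.get()
            for text, done in generate_stream(prompt):
                response_queue.put((text, done))
        time.sleep(0.001)  # avoid a busy spin on an empty queue


if __name__ == "__main__":
    # 'spawn' is required so CUDA can be initialized safely in the child process.
    mp.set_start_method("spawn", force=True)

    request_queue, response_queue = mp.Queue(), mp.Queue()
    proc = mp.Process(target=worker, args=(request_queue, response_queue), daemon=True)
    proc.start()

    # Producer side: roughly what the Gradio callback does with its request.
    request_queue.put("hello queued inference")
    while True:
        text, done = response_queue.get()
        print(text, end="", flush=True)
        if done:
            break
    print()
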
4 changes: 0 additions & 4 deletions src/lmflow/models/hf_encoder_decoder_model.py
@@ -270,11 +270,7 @@ def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> Union[List[i
if isinstance(input, dict):
# TODO refactor the input type to make it elegant.
kwargs.update(input)
import time
start = time.time()
tokens = self.tokenizer(*args, **kwargs)
end = time.time()
print('#######', end - start)
return tokens
elif isinstance(input, list):
return self.tokenizer(text=input, *args, **kwargs)  # batch encode, will automatically do left padding
10 changes: 9 additions & 1 deletion src/lmflow/pipeline/inferencer.py
@@ -223,11 +223,19 @@ def inference(
# print(f"{current_time}: inferencer.inference: model.inference end", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

text_out = model.decode(outputs[0], skip_special_tokens=True)
# only return the generation, truncating the input
if self.model_args.arch_type != "vision_encoder_decoder":
text_out = model.decode(outputs[0], skip_special_tokens=True)
prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,))
text_out = text_out[prompt_length:]
else:
# to avoid a redundant/missing leading space at the prompt boundary, we
# decode together with the last token of the input text
input_text = inputs['input_ids'][0][-1:]
text_out = model.decode(torch.cat([input_text, outputs[0]]), skip_special_tokens=True)
prompt_length = len(model.decode(input_text, skip_special_tokens=True,))
text_out = text_out[prompt_length:]

output_dict["instances"].append({ "text": text_out })

# current_time = time.strftime("%H:%M:%S", time.localtime())
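
Note: the inferencer.py hunk changes how the prompt is stripped from the decoded output for vision encoder-decoder models. Rather than decoding the generated ids alone (which can drop or duplicate a space at the prompt/generation boundary), it decodes the last input token together with the generation and then slices off that decoded prefix. A small sketch of the same idea follows, assuming a Hugging Face-style tokenizer; the "gpt2" tokenizer and the hand-built generated_ids are illustrative only, standing in for the model's real inputs and outputs.

import torch
from transformers import AutoTokenizer

# Illustrative tokenizer; the real pipeline uses the model's own tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Describe the image:", return_tensors="pt").input_ids[0]
# Pretend these ids came out of model.generate().
generated_ids = tokenizer(" A cat sitting on a mat.", return_tensors="pt").input_ids[0]

# Decode the last input token together with the generation, then cut that
# prefix off, so the generation's leading space survives exactly once.
anchor = input_ids[-1:]
decoded = tokenizer.decode(torch.cat([anchor, generated_ids]), skip_special_tokens=True)
prompt_length = len(tokenizer.decode(anchor, skip_special_tokens=True))
text_out = decoded[prompt_length:]
print(repr(text_out))  # expected: ' A cat sitting on a mat.'
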
