
Commit

fix minigpt4 inference issues
lianqing01 committed Jul 11, 2023
1 parent 3927111 commit 10fa218
Showing 11 changed files with 286 additions and 90 deletions.
91 changes: 79 additions & 12 deletions examples/vis_chatbot.py
@@ -3,6 +3,7 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple shell to inference the input data.
"""
from dataclasses import dataclass, field
import logging
import json
@@ -55,14 +56,44 @@ class ChatbotArguments:
"help": "task for reasoning",
}
)
prompt_format: Optional[str] = field(
default="None",
metadata={
"help": "prompt format"
}
)

@dataclass
class VisModelArguments(ModelArguments):
    low_resource: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Load the LLM with 8-bit quantization and float16"
        }
    )
    custom_model: bool = field(
        default=False,
        metadata={"help": "whether to load a custom model instead of a Hugging Face one"}
    )
    checkpoint_path: str = field(
        default=None,
        metadata={"help": "path to the model checkpoint"}
    )
    llm_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "name or path of the LLM backbone inside the multi-modality model"
            )
        },
    )
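As a quick sanity check, the new flags parse like any other HfArgumentParser dataclass. A minimal standalone sketch, with a trimmed stand-in for VisModelArguments and a hypothetical checkpoint path:

from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class DemoVisArgs:                      # trimmed stand-in for VisModelArguments
    custom_model: bool = field(default=False)
    checkpoint_path: str = field(default=None)
    llm_model_name_or_path: Optional[str] = field(default=None)

parser = HfArgumentParser(DemoVisArgs)
(args,) = parser.parse_args_into_dataclasses(
    args=["--custom_model", "True",
          "--checkpoint_path", "output_models/minigpt4.pth"]   # hypothetical path
)
print(args.custom_model, args.checkpoint_path)   # True output_models/minigpt4.pth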

def main():
    pipeline_name = "inferencer"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        VisModelArguments,
        PipelineArguments,
        ChatbotArguments,
    ))
@@ -73,12 +104,12 @@ def main():
    inferencer_args = pipeline_args
    with open(pipeline_args.deepspeed, "r") as f:
        ds_config = json.load(f)

    model = AutoModel.get_model(
        model_args,
        tune_strategy='none',
        ds_config=ds_config,
        device=pipeline_args.device,
        custom_model=model_args.custom_model,
    )

    data_args = DatasetArguments(dataset_path=None)
@@ -100,6 +131,7 @@
"\n"
f"#############################################################################\n"
f"## A {model_name} chatbot is now chatting with you!\n"
f"## The command for loading a new image: ###Load image:"
f"#############################################################################\n"
"\n"
)
@@ -109,54 +141,90 @@ def main():
# "You are a helpful assistant who follows the given instructions"
# " unconditionally."
# )
context = ""

sep = "###"

end_string = chatbot_args.end_string
if chatbot_args.prompt_format == "mini_gpt":
context = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
else:
context = ""
prompt_structure = chatbot_args.prompt_structure


    # Load image and input text for reasoning
    image_list = []
    if chatbot_args.image_path is not None:
        raw_image = Image.open(chatbot_args.image_path)
    else:
        img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
    image_list.append(raw_image)
    input_text = chatbot_args.input_text
    if chatbot_args.task == "image_caption" and len(input_text) == 0:
        input_text = "a photography of"
    if chatbot_args.prompt_format == "mini_gpt":
        context += sep + "Human: " + "<Img><ImageHere></Img> "
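For reference, a standalone sketch (not part of the script) of what context holds under --prompt_format mini_gpt once the first image tag has been appended, built only from the strings above:

sep = "###"
context = ("Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. "
           "Please answer my questions.")
context += sep + "Human: " + "<Img><ImageHere></Img> "
print(context)
# Give the following image: <Img>ImageContent</Img>. You will be able to see the
# image once I provide it to you. Please answer my questions.###Human: <Img><ImageHere></Img>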


    # This flag determines whether we still need to add the "###Human:" prompt.
    # If the text comes right after loading an image, the prompt was already
    # added while loading the image; otherwise we add it when reading the text.
    text_after_loading_image = True
    if chatbot_args.task == "image_caption":
        # single round reasoning
        input_dataset = dataset.from_dict({
            "type": "image_text",
            "instances": [{"images": image_list,
                           "text": input_text,}]
        })
        output = inferencer.inference(model, input_dataset)
        print(output.backend_dataset['text'])
    else:
        # multi rounds reasoning
        # TODO support streaming reasoning.
        while True:
            input_text = input("User >>> ")
            if input_text == "exit":
                print("exit...")
                break
            elif input_text.startswith("###Load image:"):
                # len("###Load image:") == 14, so this slice is the image path
                image_path = input_text[14:]
                try:
                    raw_image = Image.open(image_path)
                    image_list.append(raw_image)
                    context += sep + "Human: " + "<Img><ImageHere></Img> "
                    text_after_loading_image = True
                    continue
                except FileNotFoundError:
                    print("Loading image failed")
                    continue
            elif input_text == "reset":
                context = ""
                print("Chat history cleared")
                continue

            if text_after_loading_image is False:
                if chatbot_args.prompt_format == "mini_gpt":
                    context += sep + "Human: "
            else:
                text_after_loading_image = False

            if not input_text:
                input_text = " "
            context += prompt_structure.format(input_text=input_text)

            # TODO handle when model doesn't have the get_max_length
            context = context[-model.get_max_length():]    # Memory of the bot
            input_dataset = dataset.from_dict({
                "type": "image_text",
                "instances": [{"images": image_list,
                               "text": context,}]
            })
            remove_image_flag = chatbot_args.prompt_format == "mini_gpt"
            # output_dataset = inferencer.inference(
            #     model,
            #     input_dataset,
            #     remove_image_flag=remove_image_flag)
            # response = output_dataset.backend_dataset['text']
            # print(response[0])
            # print("\n", end="")
            # context += response[0]

print("Bot: ", end="")
print_index = 0
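One caveat worth noting on the truncation above: get_max_length() presumably counts tokens, while the negative slice keeps that many trailing characters, so this is a rough memory bound rather than an exact token budget. A tiny illustration:

max_length = 8                 # illustrative; the real value comes from the model
history = "0123456789abcdef"
print(history[-max_length:])   # 89abcdef -- only the tail survives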
@@ -190,6 +258,5 @@ def main():

            context += response + "\n"


if __name__ == "__main__":
    main()
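The recurring change in this file is passing image_list rather than a bare raw_image in the "images" field, which suggests the inference backend expects a list of PIL images per instance. A sketch of the expected payload shape (field names taken from the diff; the demo image is synthesized):

from PIL import Image

raw_image = Image.new("RGB", (224, 224), color="gray")   # stand-in for a real photo
input_payload = {
    "type": "image_text",
    "instances": [{"images": [raw_image],                 # a list, even with one image
                   "text": "a photography of"}],
}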
15 changes: 8 additions & 7 deletions examples/vis_chatbot_gradio.py
@@ -174,12 +174,15 @@ def upload_image(gr_image, text_input, chat_state):
    if gr_image is None:
        return None, None, gr.update(interactive=True), chat_state, None
    image_list = []
    if chat_state is None:
        if chatbot_args.prompt_format == "mini_gpt":
            chat_state = ("Give the following image: <Img>ImageContent</Img>. "
                          "You will be able to see the image once I provide it to you. "
                          "Please answer my questions.")
        else:
            chat_state = ''
    image = read_img(gr_image)
    image_list.append(image)
    if chatbot_args.prompt_format == "mini_gpt":
        chat_state += "Human: " + "<Img><ImageHere></Img>"
    return gr.update(interactive=False), \
           gr.update(interactive=True, placeholder='Type and press Enter'), \
           gr.update(value="Start Chatting", interactive=False), \
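The if chat_state is None: guard above looks like the core Gradio fix: previously every upload re-initialized chat_state and wiped the dialogue, whereas now the preamble is only added on the first upload. A minimal before/after sketch (function names and values illustrative):

def upload_old(chat_state):
    chat_state = ''                      # unconditional reset -- history lost
    return chat_state + "Human: <Img><ImageHere></Img>"

def upload_new(chat_state):
    if chat_state is None:               # initialize only on the first upload
        chat_state = ''
    return chat_state + "Human: <Img><ImageHere></Img>"

print(upload_old("earlier dialogue "))   # Human: <Img><ImageHere></Img>
print(upload_new("earlier dialogue "))   # earlier dialogue Human: <Img><ImageHere></Img>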
@@ -198,9 +201,7 @@ def read_img(image):
def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat_state = chat_state + prompt_structure.format(input_text=user_message)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state
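And a sketch of how chat_state accumulates per turn in gradio_ask; the prompt_structure value here is hypothetical, the real one comes from ChatbotArguments:

prompt_structure = "###Human: {input_text}###Assistant:"   # hypothetical template
chat_state = "Human: <Img><ImageHere></Img>"
chat_state = chat_state + prompt_structure.format(input_text="What is in the image?")
print(chat_state)
# Human: <Img><ImageHere></Img>###Human: What is in the image?###Assistant: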

(diffs for the remaining 9 changed files not rendered)
