diff --git a/examples/vis_chatbot.py b/examples/vis_chatbot.py
index e556fcf42..9eaf16fbc 100644
--- a/examples/vis_chatbot.py
+++ b/examples/vis_chatbot.py
@@ -3,6 +3,7 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple shell to inference the input data.
"""
from dataclasses import dataclass, field
import logging
import json
@@ -55,14 +56,44 @@ class ChatbotArguments:
"help": "task for reasoning",
}
)
+ prompt_format: Optional[str] = field(
+ default="None",
+ metadata={
+ "help": "prompt format"
+ }
+ )
+@dataclass
+class VisModelArguments(ModelArguments):
+ low_resource: Optional[bool] = field(
+ default=False,
+ metadata={
+            "help": "Use 8-bit quantization and float16 when loading the LLM"
+ }
+ )
+ custom_model: bool = field(
+ default=False,
+        metadata={"help": "Whether the model is a customized model rather than one loaded directly from HuggingFace"}
+ )
+ checkpoint_path: str = field(
+ default=None,
+ metadata={"help": "path for model checkpoint"}
+ )
+ llm_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+                "Name or path of the language model used inside the multi-modal model"
+ )
+ },
+ )
def main():
pipeline_name = "inferencer"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
parser = HfArgumentParser((
- ModelArguments,
+ VisModelArguments,
PipelineArguments,
ChatbotArguments,
))
@@ -73,12 +104,12 @@ def main():
inferencer_args = pipeline_args
with open (pipeline_args.deepspeed, "r") as f:
ds_config = json.load(f)
-
model = AutoModel.get_model(
model_args,
tune_strategy='none',
ds_config=ds_config,
device=pipeline_args.device,
+ custom_model=model_args.custom_model,
)
data_args = DatasetArguments(dataset_path=None)
@@ -100,6 +131,7 @@ def main():
"\n"
f"#############################################################################\n"
f"## A {model_name} chatbot is now chatting with you!\n"
+        f"## The command for loading a new image: ###Load image:\n"
f"#############################################################################\n"
"\n"
)
@@ -109,54 +141,90 @@ def main():
# "You are a helpful assistant who follows the given instructions"
# " unconditionally."
# )
- context = ""
+
+ sep = "###"
end_string = chatbot_args.end_string
+ if chatbot_args.prompt_format == "mini_gpt":
+        context = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
+ else:
+ context = ""
prompt_structure = chatbot_args.prompt_structure
-
# Load image and input text for reasoning
+ image_list = []
if chatbot_args.image_path is not None:
raw_image = Image.open(chatbot_args.image_path)
else:
- img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
+ img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
+ image_list.append(raw_image)
input_text = chatbot_args.input_text
if chatbot_args.task == "image_caption" and len(input_text) == 0:
input_text = "a photography of"
+ if chatbot_args.prompt_format == "mini_gpt":
+        context += sep + "Human: " + "<Img><ImageHere></Img> "
-
+    # This flag tracks whether the "###Human: " prompt still needs to be added:
+    # if the text comes right after loading an image, the prompt was already added
+    # when the image was loaded; otherwise we add it when reading the text input.
+ text_after_loading_image = True
if chatbot_args.task == "image_caption":
# single round reasoning
input_dataset = dataset.from_dict({
"type": "image_text",
- "instances": [{"images": raw_image,
+ "instances": [{"images": image_list,
"text": input_text,}]
})
output = inferencer.inference(model, input_dataset)
print(output.backend_dataset['text'])
else:
- # multi rounds reasoning
- # TODO support streaming reasoning.
while True:
input_text = input("User >>> ")
if input_text == "exit":
print("exit...")
break
+ elif input_text.startswith("###Load image:"):
+ image_path = input_text[14:]
+ try:
+ raw_image = Image.open(image_path)
+ image_list.append(raw_image)
+                    context += sep + "Human: " + "<Img><ImageHere></Img> "
+ text_after_loading_image = True
+ continue
+ except FileNotFoundError:
+ print("Loading image failed")
elif input_text == "reset":
context = ""
print("Chat history cleared")
continue
+
+ if text_after_loading_image is False:
+ if chatbot_args.prompt_format == "mini_gpt":
+ context += sep + "Human: "
+ else:
+ text_after_loading_image = False
+
if not input_text:
input_text = " "
- context += prompt_structure.format(input_text=input_text)
+ context += prompt_structure.format(input_text=input_text)
+
# TODO handle when model doesn't have the get_max_length
context = context[-model.get_max_length():] # Memory of the bot
input_dataset = dataset.from_dict({
"type": "image_text",
- "instances": [{"images": raw_image,
+ "instances": [{"images": image_list,
"text": context,}]
})
+ remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
+ # output_dataset = inferencer.inference(
+ # model,
+ # input_dataset,
+ # remove_image_flag=remove_image_flag)
+ # response = output_dataset.backend_dataset['text']
+ # print(response[0])
+ # print("\n", end="")
+ # context += response[0]
print("Bot: ", end="")
print_index = 0
@@ -190,6 +258,5 @@ def main():
context += response + "\n"
-
if __name__ == "__main__":
main()
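For reference, here is a minimal, self-contained sketch of how the mini_gpt context string grows across image loads and chat turns in the interactive loop above. The separator, system prompt, and prompt structure mirror the values used in the script; `add_image` and `add_user_turn` are hypothetical helpers for illustration only, not functions from the codebase.

```python
# Sketch of the mini_gpt prompt assembly used by vis_chatbot.py (illustrative only).
SEP = "###"
SYSTEM = ("Give the following image: <Img>ImageContent</Img>. "
          "You will be able to see the image once I provide it to you. "
          "Please answer my questions.")
PROMPT_STRUCTURE = "{input_text}###Assistant:"

def add_image(context: str) -> str:
    # Each loaded image contributes one "###Human: <Img><ImageHere></Img>" marker;
    # the inferencer later splits on "<ImageHere>" to find where to splice image embeddings.
    return context + SEP + "Human: " + "<Img><ImageHere></Img> "

def add_user_turn(context: str, text: str, after_image: bool) -> str:
    # "###Human: " is already appended when an image is loaded, so it is only
    # prepended here when the turn does not directly follow an image load.
    if not after_image:
        context += SEP + "Human: "
    return context + PROMPT_STRUCTURE.format(input_text=text)

if __name__ == "__main__":
    ctx = SYSTEM
    ctx = add_image(ctx)
    ctx = add_user_turn(ctx, "What is in the picture?", after_image=True)
    print(ctx)
```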
diff --git a/examples/vis_chatbot_gradio.py b/examples/vis_chatbot_gradio.py
index cc2e56f8e..d454afd72 100644
--- a/examples/vis_chatbot_gradio.py
+++ b/examples/vis_chatbot_gradio.py
@@ -174,12 +174,15 @@ def upload_image(gr_image, text_input, chat_state):
if gr_image is None:
return None, None, gr.update(interactive=True), chat_state, None
image_list = []
- if chatbot_args.prompt_format == "mini_gpt":
-        chat_state = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
- else:
- chat_state = ''
+ if chat_state is None:
+ if chatbot_args.prompt_format == "mini_gpt":
+            chat_state = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
+ else:
+ chat_state = ''
image = read_img(gr_image)
image_list.append(image)
+ if chatbot_args.prompt_format == "mini_gpt":
+        chat_state += "###Human: " + "<Img><ImageHere></Img>"
return gr.update(interactive=False), \
gr.update(interactive=True, placeholder='Type and press Enter'), \
gr.update(value="Start Chatting", interactive=False), \
@@ -198,9 +201,7 @@ def read_img(image):
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
- user_message = prompt_structure.format(input_text=user_message)
- chat_state = chat_state + user_message
-
+ chat_state = chat_state + prompt_structure.format(input_text=user_message)
chatbot = chatbot + [[user_message, None]]
return '', chatbot, chat_state
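The intent of this change is that re-uploading an image no longer wipes the running conversation: the system prompt is written only when `chat_state` is still `None`, and each upload just appends a new image marker. A compact sketch of that behaviour, assuming `chat_state` is the accumulated prompt string (the function name and argument order here are illustrative, not the actual Gradio callback signature):

```python
def on_upload(chat_state, image_list, image, mini_gpt=True):
    # Initialize the system prompt only on the very first upload.
    if chat_state is None:
        chat_state = ("Give the following image: <Img>ImageContent</Img>. "
                      "You will be able to see the image once I provide it to you. "
                      "Please answer my questions.") if mini_gpt else ""
    image_list.append(image)
    # Every upload appends one image marker; earlier turns are preserved.
    if mini_gpt:
        chat_state += "###Human: " + "<Img><ImageHere></Img>"
    return chat_state, image_list
```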
diff --git a/examples/debug.py b/examples/vis_debug.py
similarity index 51%
rename from examples/debug.py
rename to examples/vis_debug.py
index 141a1aec7..8d931c55f 100644
--- a/examples/debug.py
+++ b/examples/vis_debug.py
@@ -63,13 +63,37 @@ class ChatbotArguments:
}
)
+@dataclass
+class VisModelArguments(ModelArguments):
+ low_resource: Optional[bool] = field(
+ default=False,
+ metadata={
+            "help": "Use 8-bit quantization and float16 when loading the LLM"
+ }
+ )
+ custom_model: bool = field(
+ default=False,
+        metadata={"help": "Whether the model is a customized model rather than one loaded directly from HuggingFace"}
+ )
+ checkpoint_path: str = field(
+ default=None,
+ metadata={"help": "path for model checkpoint"}
+ )
+ llm_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+                "Name or path of the language model used inside the multi-modal model"
+ )
+ },
+ )
def main():
pipeline_name = "inferencer"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
parser = HfArgumentParser((
- ModelArguments,
+ VisModelArguments,
PipelineArguments,
ChatbotArguments,
))
@@ -107,6 +131,7 @@ def main():
"\n"
f"#############################################################################\n"
f"## A {model_name} chatbot is now chatting with you!\n"
+        f"## The command for loading a new image: ###Load image:\n"
f"#############################################################################\n"
"\n"
)
@@ -127,23 +152,28 @@ def main():
prompt_structure = chatbot_args.prompt_structure
# Load image and input text for reasoning
+ image_list = []
if chatbot_args.image_path is not None:
raw_image = Image.open(chatbot_args.image_path)
else:
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
+ image_list.append(raw_image)
input_text = chatbot_args.input_text
if chatbot_args.task == "image_caption" and len(input_text) == 0:
input_text = "a photography of"
if chatbot_args.prompt_format == "mini_gpt":
-        context += sep + "Human: " + "<Img><ImageHere></Img>"
-
+        context += sep + "Human: " + "<Img><ImageHere></Img> "
+    # This flag tracks whether the "###Human: " prompt still needs to be added:
+    # if the text comes right after loading an image, the prompt was already added
+    # when the image was loaded; otherwise we add it when reading the text input.
+ text_after_loading_image = True
if chatbot_args.task == "image_caption":
# single round reasoning
input_dataset = dataset.from_dict({
"type": "image_text",
- "instances": [{"images": raw_image,
+ "instances": [{"images": image_list,
"text": input_text,}]
})
output = inferencer.inference(model, input_dataset)
@@ -151,25 +181,75 @@ def main():
else:
# multi rounds reasoning
# TODO support streaming reasoning.
- while True:
- input_text = input("User >>> ")
- if input_text == "exit":
- print("exit...")
- break
- elif input_text == "reset":
- context = ""
- print("Chat history cleared")
- continue
+ # while True:
+ # input_text = input("User >>> ")
+ # if input_text == "exit":
+ # print("exit...")
+ # break
+ # elif input_text == "reset":
+ # context = ""
+ # print("Chat history cleared")
+ # continue
+ # if not input_text:
+ # input_text = " "
+ # context += prompt_structure.format(input_text=input_text)
+ # # TODO handle when model doesn't have the get_max_length
+ # context = context[-model.get_max_length():] # Memory of the bot
+ # input_dataset = dataset.from_dict({
+ # "type": "image_text",
+ # "instances": [{"images": raw_image,
+ # "text": context,}]
+ # })
+ # remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
+ # output_dataset = inferencer.inference(
+ # model,
+ # input_dataset,
+ # remove_image_flag=remove_image_flag)
+ # response = output_dataset.backend_dataset['text']
+ # print(response[0])
+ # print("\n", end="")
+ # context += response[0]
+ # while True:
+ # input_text = input("User >>> ")
+ # if input_text == "exit":
+ # print("exit...")
+ # break
+ # elif input_text.startswith("###Load image:"):
+ # image_path = input_text[14:]
+ # try:
+ # raw_image = Image.open(image_path)
+ # image_list.append(raw_image)
+ # context += sep + "Human: " + " "
+ # text_after_loading_image = True
+ # continue
+ # except FileNotFoundError:
+ # print("Loading image failed")
+ # elif input_text == "reset":
+ # context = ""
+ # print("Chat history cleared")
+ # continue
+
+ input_text = "describe the image"
+
+ # if text_after_loading_image is False:
+ # if chatbot_args.prompt_format == "mini_gpt":
+ # context += sep + "Human: "
+ # else:
+ # text_after_loading_image = False
+
if not input_text:
input_text = " "
- context += prompt_structure.format(input_text=input_text)
+ context += prompt_structure.format(input_text=input_text)
+
# TODO handle when model doesn't have the get_max_length
context = context[-model.get_max_length():] # Memory of the bot
+ import pdb; pdb.set_trace()
input_dataset = dataset.from_dict({
"type": "image_text",
- "instances": [{"images": raw_image,
+ "instances": [{"images": image_list,
"text": context,}]
})
+ import pdb; pdb.set_trace()
remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
output_dataset = inferencer.inference(
model,
@@ -180,6 +260,36 @@ def main():
print("\n", end="")
context += response[0]
+ image_list.append(raw_image)
+    context += sep + "Human: " + "<Img><ImageHere></Img> "
+ input_text = "describe the image again"
+
+ # if text_after_loading_image is False:
+ # if chatbot_args.prompt_format == "mini_gpt":
+ # context += sep + "Human: "
+ # else:
+ # text_after_loading_image = False
+
+ if not input_text:
+ input_text = " "
+ context += prompt_structure.format(input_text=input_text)
+
+ # TODO handle when model doesn't have the get_max_length
+ context = context[-model.get_max_length():] # Memory of the bot
+ input_dataset = dataset.from_dict({
+ "type": "image_text",
+ "instances": [{"images": image_list,
+ "text": context,}]
+ })
+ remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
+ output_dataset = inferencer.inference(
+ model,
+ input_dataset,
+ remove_image_flag=remove_image_flag)
+ response = output_dataset.backend_dataset['text']
+ print(response[0])
+ print("\n", end="")
+ context += response[0]
if __name__ == "__main__":
main()
diff --git a/requirements.txt b/requirements.txt
index 1fa7c7e31..a16a3c78f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
numpy==1.24.2
datasets==2.10.1
peft @ git+https://github.com/huggingface/peft.git@deff03f2c251534fffd2511fc2d440e84cc54b1b
-torch==2.0.0
wandb==0.14.0
deepspeed==0.8.3
trl @ git+https://github.com/lvwerra/trl.git#egg=trl-0.4.1
diff --git a/scripts/run_vis_chatbot_debug.sh b/scripts/run_vis_chatbot_debug.sh
new file mode 100644
index 000000000..61fc124f8
--- /dev/null
+++ b/scripts/run_vis_chatbot_debug.sh
@@ -0,0 +1,11 @@
+model=Salesforce/blip2-flan-t5-xxl
+checkpoint_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/minigpt4/prerained_minigpt4_7b_converted.pth
+llm_model_name_or_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/
+deepspeed examples/vis_debug.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model \
+ --prompt_format mini_gpt \
+ --prompt_structure "{input_text}###Assistant:" \
+ --checkpoint_path ${checkpoint_path} \
+ --llm_model_name_or_path ${llm_model_name_or_path} \
+ --image_path "/home/qlianab/base.jpg" \
+ --low_resource True
+
diff --git a/scripts/run_vis_chatbot_gradio_minigpt4.sh b/scripts/run_vis_chatbot_gradio_minigpt4.sh
index 46ec4d56c..b03fff19e 100644
--- a/scripts/run_vis_chatbot_gradio_minigpt4.sh
+++ b/scripts/run_vis_chatbot_gradio_minigpt4.sh
@@ -1,2 +1,10 @@
model=Salesforce/blip2-flan-t5-xxl
-deepspeed examples/vis_chatbot_gradio.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model --prompt_format mini_gpt --prompt_structure "{input_text}###Assistant:"
+checkpoint_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/minigpt4/prerained_minigpt4_7b_converted.pth
+llm_model_name_or_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/
+deepspeed examples/vis_chatbot_gradio.py --model_name_or_path ${model} \
+ --deepspeed configs/ds_config_multimodal.json \
+ --arch_type vision_encoder_decoder \
+ --task vqa --custom_model --prompt_format mini_gpt --prompt_structure "{input_text}###Assistant:" \
+ --checkpoint_path ${checkpoint_path} \
+    --llm_model_name_or_path ${llm_model_name_or_path} \
+ --low_resource True
diff --git a/scripts/run_vis_chatbot_minigpt4.sh b/scripts/run_vis_chatbot_minigpt4.sh
index a79b9cc67..4d55b3457 100644
--- a/scripts/run_vis_chatbot_minigpt4.sh
+++ b/scripts/run_vis_chatbot_minigpt4.sh
@@ -1,9 +1,11 @@
model=Salesforce/blip2-flan-t5-xxl
checkpoint_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/minigpt4/prerained_minigpt4_7b_converted.pth
-llm_model_name_or_path="/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/"
-deepspeed examples/debug.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model \
+llm_model_name_or_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/
+deepspeed examples/vis_chatbot.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model \
--prompt_format mini_gpt \
--prompt_structure "{input_text}###Assistant:" \
- --checkpoint_path {checkpoint_path} \
- --llm_model_name_or_path {llm_model_name_or_path}
+ --checkpoint_path ${checkpoint_path} \
+ --llm_model_name_or_path ${llm_model_name_or_path} \
+    --image_path "/home/qlianab/base.jpg" \
+ --low_resource True
diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index cfd998619..29b48a9d3 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -198,26 +198,7 @@ class ModelArguments:
)
}
)
- use_int8: bool = field(
- default=False,
- metadata={"help": "whether to load int8 quantization for inference"}
- )
- custom_model: bool = field(
- default=False,
- metadata={"help": "flag for the model from huggingface or not"}
- )
- checkpoint_path: str = field(
- default=None,
- metadata={"help": "path for model checkpoint"}
- )
- llm_model_name_or_path: Optional[str] = field(
- default=None,
- metadata={
- "help": (
- "llm model in multi-modality model"
- )
- },
- )
+
def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
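The fields removed here now live in the `VisModelArguments` subclass defined in the example scripts above, keeping the shared `ModelArguments` free of vision-specific options. A minimal sketch of that pattern, assuming a stripped-down `ModelArguments` (the field set is abbreviated for illustration); `HfArgumentParser` collects fields from the whole dataclass hierarchy, so the subclass can be passed anywhere the base class was:

```python
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)

@dataclass
class VisModelArguments(ModelArguments):
    # Vision / multi-modal options live only in the subclass.
    custom_model: bool = field(default=False)
    checkpoint_path: Optional[str] = field(default=None)
    llm_model_name_or_path: Optional[str] = field(default=None)

parser = HfArgumentParser(VisModelArguments)
# Both the inherited and the new flags are accepted on the command line.
(args,) = parser.parse_args_into_dataclasses([
    "--model_name_or_path", "Salesforce/blip2-flan-t5-xxl",
    "--custom_model", "True",
    "--checkpoint_path", "minigpt4_checkpoint.pth",
])
print(args)
```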
diff --git a/src/lmflow/models/hf_encoder_decoder_model.py b/src/lmflow/models/hf_encoder_decoder_model.py
index 18fa10d64..b28c337a3 100644
--- a/src/lmflow/models/hf_encoder_decoder_model.py
+++ b/src/lmflow/models/hf_encoder_decoder_model.py
@@ -46,6 +46,7 @@
AutoModelForVision2Seq,
AutoModel,
AutoProcessor,
+ LlamaTokenizer
)
from transformers import (Blip2VisionConfig,
@@ -181,19 +182,16 @@ def __init__(
# self.backend_model = model_register.from_pretrained(
# model_args.model_name_or_path)
else:
- # model = CustomAutoVision2SeqModel.from_pretrained(
- # model_args.model_name_or_path,
- # )
- vision_config = Blip2VisionConfig.from_pretrained(model_args.model_name_or_path)
- qformer_config = Blip2QFormerConfig.from_pretrained(model_args.model_name_or_path)
- text_config = LlamaConfig.from_pretrained(model_args.llm_model_name_or_path)
- config = Blip2Config.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
- model = CustomAutoVision2SeqModel(config)
- model.vision_model_from_pretrained(model_args.model_name_or_path)
- model.qformer_from_pretrained(model_args.model_name_or_path)
- model.language_model_from_pretrained(model_args.llm_model_name_or_path)
+ model = CustomAutoVision2SeqModel.from_pretrained(model_args.model_name_or_path)
+ if model_args.llm_model_name_or_path is not None:
+ text_config = LlamaConfig.from_pretrained(model_args.llm_model_name_or_path)
+ model.config.text_config = text_config
+ model.language_model_from_pretrained(model_args.llm_model_name_or_path,
+ low_resource=model_args.low_resource)
state_dict = torch.load(model_args.checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
+ # model = CustomAutoVision2SeqModel.from_pretrained(
+ # "/home/qlianab/checkpoints/pretrained_weights/minigpt4-lmflow-vicuna-7b-low_resource/")
self.backend_model = model
if self.arch_type == "encoder_decoder":
@@ -202,8 +200,9 @@ def __init__(
tokenizer_register = AutoProcessor
else:
raise NotImplementedError
-
self.tokenizer = tokenizer_register.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+ if model_args.llm_model_name_or_path is not None:
+ self.tokenizer.tokenizer = LlamaTokenizer.from_pretrained(model_args.llm_model_name_or_path)
self.backend_model_full = self.backend_model
if peft_model_id is not None:
self.backend_model = PeftModel.from_pretrained(
@@ -267,6 +266,7 @@ def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> Union[List[i
outputs :
The tokenized inputs.
"""
+        # import pdb; pdb.set_trace()  # debugging breakpoint; keep disabled in library code
if isinstance(input, dict):
# TODO refactor the input type to make it elegant.
kwargs.update(input)
@@ -329,6 +329,7 @@ def inference(self, inputs, *args, **kwargs):
The generated sequence output
"""
# TODO need to discuss how to handle pad_token_id
+        # import pdb; pdb.set_trace()  # debugging breakpoint; keep disabled in library code
if self.arch_type == "encoder_decoder":
kwargs.update(pad_token_id=self.tokenizer.pad_token_id)
elif self.arch_type == "vision_encoder_decoder":
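Taken together, the new loading path starts from the pretrained BLIP-2 weights, swaps in a LLaMA/Vicuna language model, and then overlays the MiniGPT-4 checkpoint with `strict=False` (on the assumption that the checkpoint mainly carries the Q-Former-to-LLM projection weights). A condensed sketch of that flow; `model_cls` stands in for `CustomAutoVision2SeqModel` and all paths are placeholders:

```python
import torch
from transformers import LlamaConfig, LlamaTokenizer

def load_minigpt4(model_cls, blip2_name, llm_path, checkpoint_path, low_resource=False):
    # Start from the pretrained BLIP-2 vision encoder + Q-Former.
    model = model_cls.from_pretrained(blip2_name)
    if llm_path is not None:
        # Replace the text backbone with a LLaMA/Vicuna model.
        model.config.text_config = LlamaConfig.from_pretrained(llm_path)
        model.language_model_from_pretrained(llm_path, low_resource=low_resource)
    # Overlay the converted MiniGPT-4 weights on top of the assembled model.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    # The text tokenizer must match the swapped-in language model.
    tokenizer = LlamaTokenizer.from_pretrained(llm_path)
    return model, tokenizer
```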
diff --git a/src/lmflow/models/vision2seq_model.py b/src/lmflow/models/vision2seq_model.py
index 54ec8db5e..244ac875f 100644
--- a/src/lmflow/models/vision2seq_model.py
+++ b/src/lmflow/models/vision2seq_model.py
@@ -9,7 +9,8 @@
from transformers import (
Blip2ForConditionalGeneration,
- Blip2Config
+ Blip2Config,
+ AutoModelForCausalLM
)
from .base_model import BaseModel
@@ -23,22 +24,26 @@ def vision_model_from_pretrained(self, pretrained_path):
self.vision_model = self.vision_model.from_pretrained(
pretrained_path,
config=self.config.vision_config)
-
def qformer_from_pretrained(self, pretrained_path):
- print(self.qformer.encoder.layer[11].output_query.dense.weight.mean())
self.qformer = self.qformer.from_pretrained(
pretrained_path,
config=self.config.qformer_config)
print(self.qformer.encoder.layer[11].output_query.dense.weight.mean())
- def language_model_from_pretrained(self, pretrained_path):
+ def language_model_from_pretrained(self, pretrained_path, low_resource=False):
# TODO remove the low resource related loading in the future
- self.language_model = self.language_model.from_pretrained(
+ if low_resource:
+ kwargs = dict(
+ torch_dtype=torch.float16,
+ load_in_8bit=True,
+ device_map="auto"
+ )
+ else:
+ kwargs = {}
+ self.language_model = AutoModelForCausalLM.from_pretrained(
pretrained_path,
config=self.config.text_config,
- torch_dtype=torch.float16,
- device_map="auto")
-
+ **kwargs)
@torch.no_grad()
def generate(
@@ -47,6 +52,7 @@ def generate(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
image_token_indexes: Optional[List] = [0],
+ one_sample_multiple_images: Optional[bool] = False,
**generate_kwargs,
) -> torch.LongTensor:
"""
@@ -59,6 +65,10 @@ def generate(
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices
+            image_token_indexes (List[int], *optional*):
+                The indexes at which the image embeddings are inserted into the text sequence.
+            one_sample_multiple_images (bool, *optional*):
+                Whether the batch is a single sample that contains multiple images.
Returns:
captions (list): A list of strings of length batch_size * num_captions.
@@ -66,8 +76,10 @@ def generate(
if hasattr(self, "hf_device_map"):
# preprocess for `accelerate`
self._preprocess_accelerate()
-
- batch_size = pixel_values.shape[0]
+ if not one_sample_multiple_images:
+ batch_size = pixel_values.shape[0]
+ else:
+ batch_size = 1
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
@@ -98,27 +110,27 @@ def generate(
# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(language_model_inputs.device)
-
# concatenate the text embeddings with image embeddings
inputs_embeds_with_images = []
attention_mask_with_images = []
# currently we only support with one image
assert len(image_token_indexes) == 1
- for image_token_index in image_token_indexes:
+ for idx, image_token_index in enumerate(image_token_indexes):
inputs_embeds_with_images.append(inputs_embeds[:, :image_token_index])
- inputs_embeds_with_images.append(language_model_inputs)
+ inputs_embeds_with_images.append(language_model_inputs[idx][None])
attention_mask_with_images.append(
attention_mask[:, :image_token_index])
- attention_mask_with_images.append(language_attention_mask)
+ attention_mask_with_images.append(language_attention_mask[idx][None])
inputs_embeds_with_images.append(inputs_embeds[:, image_token_indexes[-1]:])
inputs_embeds = torch.cat(inputs_embeds_with_images, dim=1)
attention_mask_with_images.append(attention_mask[:, image_token_indexes[-1]:])
attention_mask = torch.cat(attention_mask_with_images, dim=1)
+ inputs_embeds = inputs_embeds.to(self.language_model.lm_head.weight.dtype)
+ attention_mask = attention_mask.to(self.language_model.lm_head.weight.dtype)
outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_kwargs,
)
-
return outputs
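The indexing change above selects one image's query embeddings per marker (`language_model_inputs[idx][None]`) instead of broadcasting the whole batch. The splice itself can be shown with plain tensors; a toy sketch with made-up shapes, independent of the model weights:

```python
import torch

def splice_image_embeds(text_embeds, text_mask, image_embeds, image_mask, index):
    # text_embeds: (1, seq_len, dim); image_embeds: (1, num_query_tokens, dim).
    # Insert the image embeddings at `index` and extend the attention mask to match.
    spliced = torch.cat([text_embeds[:, :index], image_embeds, text_embeds[:, index:]], dim=1)
    mask = torch.cat([text_mask[:, :index], image_mask, text_mask[:, index:]], dim=1)
    return spliced, mask

# Toy shapes: 5 text tokens, 2 query tokens, hidden size 4, image inserted at position 3.
text = torch.randn(1, 5, 4)
query = torch.randn(1, 2, 4)
embeds, mask = splice_image_embeds(text, torch.ones(1, 5, dtype=torch.long),
                                   query, torch.ones(1, 2, dtype=torch.long), index=3)
print(embeds.shape, mask.shape)  # torch.Size([1, 7, 4]) torch.Size([1, 7])
```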
diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py
index f0e6a1a64..df93c4b23 100644
--- a/src/lmflow/pipeline/inferencer.py
+++ b/src/lmflow/pipeline/inferencer.py
@@ -154,14 +154,18 @@ def inference(
else:
input = current_batch['input']
input['text'] = prompt_structure.format(input=input['text'])
-
if remove_image_flag:
+                # remove the image flag "<ImageHere>" from the text before tokenization
                input['text'] = input['text'].split("<ImageHere>")
new_input = copy.deepcopy(input)
new_input['text'] = new_input['text'][-1]
input['text'] = input['text'][0]
- inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
- new_inputs = model.encode(new_input, return_tensors="pt").to(device=self.local_rank)
+ inputs = model.encode(input,
+ return_tensors="pt",
+ add_special_tokens=True).to(device=self.local_rank)
+ new_inputs = model.encode(new_input,
+ return_tensors="pt",
+ add_special_tokens=False).to(device=self.local_rank)
image_token_indexes = [inputs["input_ids"].shape[1]]
inputs["input_ids"] = torch.cat([inputs["input_ids"],
new_inputs["input_ids"]], dim=1)
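The change above tokenizes the text before the image flag with special tokens (so it keeps its BOS) and the text after it without, then records where the image embeddings must be spliced in. A rough standalone sketch of the same idea, assuming the prompt contains exactly one `<ImageHere>` marker and that `tokenizer` is any HuggingFace tokenizer (batching and device placement omitted):

```python
import torch

def split_on_image_flag(tokenizer, prompt, flag="<ImageHere>"):
    # Split the prompt around the image flag so the flag itself is never tokenized.
    before, after = prompt.split(flag, maxsplit=1)
    ids_before = tokenizer(before, return_tensors="pt", add_special_tokens=True)
    ids_after = tokenizer(after, return_tensors="pt", add_special_tokens=False)
    # The image embeddings go right after the tokens of the first segment.
    image_token_indexes = [ids_before["input_ids"].shape[1]]
    input_ids = torch.cat([ids_before["input_ids"], ids_after["input_ids"]], dim=1)
    return input_ids, image_token_indexes
```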