diff --git a/examples/vis_chatbot.py b/examples/vis_chatbot.py index e556fcf42..9eaf16fbc 100644 --- a/examples/vis_chatbot.py +++ b/examples/vis_chatbot.py @@ -3,6 +3,7 @@ # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. """A simple shell to inference the input data. """ +from cmath import e from dataclasses import dataclass, field import logging import json @@ -55,14 +56,44 @@ class ChatbotArguments: "help": "task for reasoning", } ) + prompt_format: Optional[str] = field( + default="None", + metadata={ + "help": "prompt format" + } + ) +@dataclass +class VisModelArguments(ModelArguments): + low_resource: Optional[bool] = field( + default=False, + metadata={ + "help": "Use 8 bit and float16 when loading llm" + } + ) + custom_model: bool = field( + default=False, + metadata={"help": "flag for the model from huggingface or not"} + ) + checkpoint_path: str = field( + default=None, + metadata={"help": "path for model checkpoint"} + ) + llm_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "llm model in multi-modality model" + ) + }, + ) def main(): pipeline_name = "inferencer" PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) parser = HfArgumentParser(( - ModelArguments, + VisModelArguments, PipelineArguments, ChatbotArguments, )) @@ -73,12 +104,12 @@ def main(): inferencer_args = pipeline_args with open (pipeline_args.deepspeed, "r") as f: ds_config = json.load(f) - model = AutoModel.get_model( model_args, tune_strategy='none', ds_config=ds_config, device=pipeline_args.device, + custom_model=model_args.custom_model, ) data_args = DatasetArguments(dataset_path=None) @@ -100,6 +131,7 @@ def main(): "\n" f"#############################################################################\n" f"## A {model_name} chatbot is now chatting with you!\n" + f"## The command for loading a new image: ###Load image:" f"#############################################################################\n" "\n" ) @@ -109,54 +141,90 @@ def main(): # "You are a helpful assistant who follows the given instructions" # " unconditionally." # ) - context = "" + + sep = "###" end_string = chatbot_args.end_string + if chatbot_args.prompt_format == "mini_gpt": + context = "Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions." + else: + context = "" prompt_structure = chatbot_args.prompt_structure - # Load image and input text for reasoning + image_list = [] if chatbot_args.image_path is not None: raw_image = Image.open(chatbot_args.image_path) else: - img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' + img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB') + image_list.append(raw_image) input_text = chatbot_args.input_text if chatbot_args.task == "image_caption" and len(input_text) == 0: input_text = "a photography of" + if chatbot_args.prompt_format == "mini_gpt": + context += sep + "Human: " + " " - + # this flag is for determining if we need to add the ###Human: prompt + # if text after loading image, we add it when loading image + # else, we add it when read the text. 
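# Aside: an illustrative sketch (assuming prompt_structure is "{input_text}###Assistant:",
# as in scripts/run_vis_chatbot_minigpt4.sh) of how one chat round is serialized when
# prompt_format == "mini_gpt"; the question string is a placeholder:
#
#     sep = "###"
#     context = ("Give the following image: ImageContent. "
#                "You will be able to see the image once I provide it to you. "
#                "Please answer my questions.")
#     context += sep + "Human: " + " "        # opened when an image is loaded
#     context += "{input_text}###Assistant:".format(input_text="What is in the picture?")
#     # context is now "...###Human:  What is in the picture?###Assistant:"
#
# The bot's reply is then appended to context before the next "###Human: " turn.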
+ text_after_loading_image = True if chatbot_args.task == "image_caption": # single round reasoning input_dataset = dataset.from_dict({ "type": "image_text", - "instances": [{"images": raw_image, + "instances": [{"images": image_list, "text": input_text,}] }) output = inferencer.inference(model, input_dataset) print(output.backend_dataset['text']) else: - # multi rounds reasoning - # TODO support streaming reasoning. while True: input_text = input("User >>> ") if input_text == "exit": print("exit...") break + elif input_text.startswith("###Load image:"): + image_path = input_text[14:] + try: + raw_image = Image.open(image_path) + image_list.append(raw_image) + context += sep + "Human: " + " " + text_after_loading_image = True + continue + except FileNotFoundError: + print("Loading image failed") elif input_text == "reset": context = "" print("Chat history cleared") continue + + if text_after_loading_image is False: + if chatbot_args.prompt_format == "mini_gpt": + context += sep + "Human: " + else: + text_after_loading_image = False + if not input_text: input_text = " " - context += prompt_structure.format(input_text=input_text) + context += prompt_structure.format(input_text=input_text) + # TODO handle when model doesn't have the get_max_length context = context[-model.get_max_length():] # Memory of the bot input_dataset = dataset.from_dict({ "type": "image_text", - "instances": [{"images": raw_image, + "instances": [{"images": image_list, "text": context,}] }) + remove_image_flag = chatbot_args.prompt_format=="mini_gpt" + # output_dataset = inferencer.inference( + # model, + # input_dataset, + # remove_image_flag=remove_image_flag) + # response = output_dataset.backend_dataset['text'] + # print(response[0]) + # print("\n", end="") + # context += response[0] print("Bot: ", end="") print_index = 0 @@ -190,6 +258,5 @@ def main(): context += response + "\n" - if __name__ == "__main__": main() diff --git a/examples/vis_chatbot_gradio.py b/examples/vis_chatbot_gradio.py index cc2e56f8e..d454afd72 100644 --- a/examples/vis_chatbot_gradio.py +++ b/examples/vis_chatbot_gradio.py @@ -174,12 +174,15 @@ def upload_image(gr_image, text_input, chat_state): if gr_image is None: return None, None, gr.update(interactive=True), chat_state, None image_list = [] - if chatbot_args.prompt_format == "mini_gpt": - chat_state = "Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions." - else: - chat_state = '' + if chat_state is None: + if chatbot_args.prompt_format == "mini_gpt": + chat_state = "Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions." 
+ else: + chat_state = '' image = read_img(gr_image) image_list.append(image) + if chatbot_args.prompt_format == "mini_gpt": + chat_state = "Human: " + "" return gr.update(interactive=False), \ gr.update(interactive=True, placeholder='Type and press Enter'), \ gr.update(value="Start Chatting", interactive=False), \ @@ -198,9 +201,7 @@ def read_img(image): def gradio_ask(user_message, chatbot, chat_state): if len(user_message) == 0: return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state - user_message = prompt_structure.format(input_text=user_message) - chat_state = chat_state + user_message - + chat_state = chat_state + prompt_structure.format(input_text=user_message) chatbot = chatbot + [[user_message, None]] return '', chatbot, chat_state diff --git a/examples/debug.py b/examples/vis_debug.py similarity index 51% rename from examples/debug.py rename to examples/vis_debug.py index 141a1aec7..8d931c55f 100644 --- a/examples/debug.py +++ b/examples/vis_debug.py @@ -63,13 +63,37 @@ class ChatbotArguments: } ) +@dataclass +class VisModelArguments(ModelArguments): + low_resource: Optional[bool] = field( + default=False, + metadata={ + "help": "Use 8 bit and float16 when loading llm" + } + ) + custom_model: bool = field( + default=False, + metadata={"help": "flag for the model from huggingface or not"} + ) + checkpoint_path: str = field( + default=None, + metadata={"help": "path for model checkpoint"} + ) + llm_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "llm model in multi-modality model" + ) + }, + ) def main(): pipeline_name = "inferencer" PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) parser = HfArgumentParser(( - ModelArguments, + VisModelArguments, PipelineArguments, ChatbotArguments, )) @@ -107,6 +131,7 @@ def main(): "\n" f"#############################################################################\n" f"## A {model_name} chatbot is now chatting with you!\n" + f"## The command for loading a new image: ###Load image:" f"#############################################################################\n" "\n" ) @@ -127,23 +152,28 @@ def main(): prompt_structure = chatbot_args.prompt_structure # Load image and input text for reasoning + image_list = [] if chatbot_args.image_path is not None: raw_image = Image.open(chatbot_args.image_path) else: img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB') + image_list.append(raw_image) input_text = chatbot_args.input_text if chatbot_args.task == "image_caption" and len(input_text) == 0: input_text = "a photography of" if chatbot_args.prompt_format == "mini_gpt": - context += sep + "Human: " + "" - + context += sep + "Human: " + " " + # this flag is for determining if we need to add the ###Human: prompt + # if text after loading image, we add it when loading image + # else, we add it when read the text. + text_after_loading_image = True if chatbot_args.task == "image_caption": # single round reasoning input_dataset = dataset.from_dict({ "type": "image_text", - "instances": [{"images": raw_image, + "instances": [{"images": image_list, "text": input_text,}] }) output = inferencer.inference(model, input_dataset) @@ -151,25 +181,75 @@ def main(): else: # multi rounds reasoning # TODO support streaming reasoning. 
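# Aside: a minimal sketch of the image-loading command handled by the interactive loop in
# examples/vis_chatbot.py (LOAD_IMAGE_CMD is an illustrative name; the patch slices with
# input_text[14:], which works because len("###Load image:") == 14):
#
#     LOAD_IMAGE_CMD = "###Load image:"
#     if input_text.startswith(LOAD_IMAGE_CMD):
#         image_path = input_text[len(LOAD_IMAGE_CMD):]
#         image_list.append(Image.open(image_path))
#         context += sep + "Human: " + " "    # open a new Human turn for the new image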
- while True: - input_text = input("User >>> ") - if input_text == "exit": - print("exit...") - break - elif input_text == "reset": - context = "" - print("Chat history cleared") - continue + # while True: + # input_text = input("User >>> ") + # if input_text == "exit": + # print("exit...") + # break + # elif input_text == "reset": + # context = "" + # print("Chat history cleared") + # continue + # if not input_text: + # input_text = " " + # context += prompt_structure.format(input_text=input_text) + # # TODO handle when model doesn't have the get_max_length + # context = context[-model.get_max_length():] # Memory of the bot + # input_dataset = dataset.from_dict({ + # "type": "image_text", + # "instances": [{"images": raw_image, + # "text": context,}] + # }) + # remove_image_flag = chatbot_args.prompt_format=="mini_gpt" + # output_dataset = inferencer.inference( + # model, + # input_dataset, + # remove_image_flag=remove_image_flag) + # response = output_dataset.backend_dataset['text'] + # print(response[0]) + # print("\n", end="") + # context += response[0] + # while True: + # input_text = input("User >>> ") + # if input_text == "exit": + # print("exit...") + # break + # elif input_text.startswith("###Load image:"): + # image_path = input_text[14:] + # try: + # raw_image = Image.open(image_path) + # image_list.append(raw_image) + # context += sep + "Human: " + " " + # text_after_loading_image = True + # continue + # except FileNotFoundError: + # print("Loading image failed") + # elif input_text == "reset": + # context = "" + # print("Chat history cleared") + # continue + + input_text = "describe the image" + + # if text_after_loading_image is False: + # if chatbot_args.prompt_format == "mini_gpt": + # context += sep + "Human: " + # else: + # text_after_loading_image = False + if not input_text: input_text = " " - context += prompt_structure.format(input_text=input_text) + context += prompt_structure.format(input_text=input_text) + # TODO handle when model doesn't have the get_max_length context = context[-model.get_max_length():] # Memory of the bot + import pdb; pdb.set_trace() input_dataset = dataset.from_dict({ "type": "image_text", - "instances": [{"images": raw_image, + "instances": [{"images": image_list, "text": context,}] }) + import pdb; pdb.set_trace() remove_image_flag = chatbot_args.prompt_format=="mini_gpt" output_dataset = inferencer.inference( model, @@ -180,6 +260,36 @@ def main(): print("\n", end="") context += response[0] + image_list.append(raw_image) + context += sep + "Human: " + " " + input_text = "describe the image again" + + # if text_after_loading_image is False: + # if chatbot_args.prompt_format == "mini_gpt": + # context += sep + "Human: " + # else: + # text_after_loading_image = False + + if not input_text: + input_text = " " + context += prompt_structure.format(input_text=input_text) + + # TODO handle when model doesn't have the get_max_length + context = context[-model.get_max_length():] # Memory of the bot + input_dataset = dataset.from_dict({ + "type": "image_text", + "instances": [{"images": image_list, + "text": context,}] + }) + remove_image_flag = chatbot_args.prompt_format=="mini_gpt" + output_dataset = inferencer.inference( + model, + input_dataset, + remove_image_flag=remove_image_flag) + response = output_dataset.backend_dataset['text'] + print(response[0]) + print("\n", end="") + context += response[0] if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt index 1fa7c7e31..a16a3c78f 100644 --- a/requirements.txt +++ 
b/requirements.txt
@@ -1,7 +1,6 @@
 numpy==1.24.2
 datasets==2.10.1
 peft @ git+https://github.com/huggingface/peft.git@deff03f2c251534fffd2511fc2d440e84cc54b1b
-torch==2.0.0
 wandb==0.14.0
 deepspeed==0.8.3
 trl @ git+https://github.com/lvwerra/trl.git#egg=trl-0.4.1
diff --git a/scripts/run_vis_chatbot_debug.sh b/scripts/run_vis_chatbot_debug.sh
new file mode 100644
index 000000000..61fc124f8
--- /dev/null
+++ b/scripts/run_vis_chatbot_debug.sh
@@ -0,0 +1,11 @@
+model=Salesforce/blip2-flan-t5-xxl
+checkpoint_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/minigpt4/prerained_minigpt4_7b_converted.pth
+llm_model_name_or_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/
+deepspeed examples/vis_debug.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model \
+    --prompt_format mini_gpt \
+    --prompt_structure "{input_text}###Assistant:" \
+    --checkpoint_path ${checkpoint_path} \
+    --llm_model_name_or_path ${llm_model_name_or_path} \
+    --image_path "/home/qlianab/base.jpg" \
+    --low_resource True
+
diff --git a/scripts/run_vis_chatbot_gradio_minigpt4.sh b/scripts/run_vis_chatbot_gradio_minigpt4.sh
index 46ec4d56c..b03fff19e 100644
--- a/scripts/run_vis_chatbot_gradio_minigpt4.sh
+++ b/scripts/run_vis_chatbot_gradio_minigpt4.sh
@@ -1,2 +1,10 @@
 model=Salesforce/blip2-flan-t5-xxl
-deepspeed examples/vis_chatbot_gradio.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model --prompt_format mini_gpt --prompt_structure "{input_text}###Assistant:"
+checkpoint_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/minigpt4/prerained_minigpt4_7b_converted.pth
+llm_model_name_or_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/
+deepspeed examples/vis_chatbot_gradio.py --model_name_or_path ${model} \
+    --deepspeed configs/ds_config_multimodal.json \
+    --arch_type vision_encoder_decoder \
+    --task vqa --custom_model --prompt_format mini_gpt --prompt_structure "{input_text}###Assistant:" \
+    --checkpoint_path ${checkpoint_path} \
+    --llm_model_name_or_path ${llm_model_name_or_path} \
+    --low_resource True
diff --git a/scripts/run_vis_chatbot_minigpt4.sh b/scripts/run_vis_chatbot_minigpt4.sh
index a79b9cc67..4d55b3457 100644
--- a/scripts/run_vis_chatbot_minigpt4.sh
+++ b/scripts/run_vis_chatbot_minigpt4.sh
@@ -1,9 +1,11 @@
 model=Salesforce/blip2-flan-t5-xxl
 checkpoint_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/minigpt4/prerained_minigpt4_7b_converted.pth
-llm_model_name_or_path="/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/"
-deepspeed examples/debug.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model \
+llm_model_name_or_path=/scratch/PI/tongzhang/qinglian/checkpoints/pretrained_weights/vicuna-7b/
+deepspeed examples/vis_chatbot.py --model_name_or_path ${model} --deepspeed configs/ds_config_multimodal.json --arch_type vision_encoder_decoder --task vqa --custom_model \
     --prompt_format mini_gpt \
     --prompt_structure "{input_text}###Assistant:" \
-    --checkpoint_path {checkpoint_path} \
-    --llm_model_name_or_path {llm_model_name_or_path}
+    --checkpoint_path ${checkpoint_path} \
+    --llm_model_name_or_path ${llm_model_name_or_path} \
+    --image_path "/home/qlianab/base.jpg" \
+    --low_resource True
diff --git a/src/lmflow/args.py
b/src/lmflow/args.py index cfd998619..29b48a9d3 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -198,26 +198,7 @@ class ModelArguments: ) } ) - use_int8: bool = field( - default=False, - metadata={"help": "whether to load int8 quantization for inference"} - ) - custom_model: bool = field( - default=False, - metadata={"help": "flag for the model from huggingface or not"} - ) - checkpoint_path: str = field( - default=None, - metadata={"help": "path for model checkpoint"} - ) - llm_model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "llm model in multi-modality model" - ) - }, - ) + def __post_init__(self): if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): diff --git a/src/lmflow/models/hf_encoder_decoder_model.py b/src/lmflow/models/hf_encoder_decoder_model.py index 18fa10d64..b28c337a3 100644 --- a/src/lmflow/models/hf_encoder_decoder_model.py +++ b/src/lmflow/models/hf_encoder_decoder_model.py @@ -46,6 +46,7 @@ AutoModelForVision2Seq, AutoModel, AutoProcessor, + LlamaTokenizer ) from transformers import (Blip2VisionConfig, @@ -181,19 +182,16 @@ def __init__( # self.backend_model = model_register.from_pretrained( # model_args.model_name_or_path) else: - # model = CustomAutoVision2SeqModel.from_pretrained( - # model_args.model_name_or_path, - # ) - vision_config = Blip2VisionConfig.from_pretrained(model_args.model_name_or_path) - qformer_config = Blip2QFormerConfig.from_pretrained(model_args.model_name_or_path) - text_config = LlamaConfig.from_pretrained(model_args.llm_model_name_or_path) - config = Blip2Config.from_vision_qformer_text_configs(vision_config, qformer_config, text_config) - model = CustomAutoVision2SeqModel(config) - model.vision_model_from_pretrained(model_args.model_name_or_path) - model.qformer_from_pretrained(model_args.model_name_or_path) - model.language_model_from_pretrained(model_args.llm_model_name_or_path) + model = CustomAutoVision2SeqModel.from_pretrained(model_args.model_name_or_path) + if model_args.llm_model_name_or_path is not None: + text_config = LlamaConfig.from_pretrained(model_args.llm_model_name_or_path) + model.config.text_config = text_config + model.language_model_from_pretrained(model_args.llm_model_name_or_path, + low_resource=model_args.low_resource) state_dict = torch.load(model_args.checkpoint_path, map_location="cpu") model.load_state_dict(state_dict, strict=False) + # model = CustomAutoVision2SeqModel.from_pretrained( + # "/home/qlianab/checkpoints/pretrained_weights/minigpt4-lmflow-vicuna-7b-low_resource/") self.backend_model = model if self.arch_type == "encoder_decoder": @@ -202,8 +200,9 @@ def __init__( tokenizer_register = AutoProcessor else: raise NotImplementedError - self.tokenizer = tokenizer_register.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + if model_args.llm_model_name_or_path is not None: + self.tokenizer.tokenizer = LlamaTokenizer.from_pretrained(model_args.llm_model_name_or_path) self.backend_model_full = self.backend_model if peft_model_id is not None: self.backend_model = PeftModel.from_pretrained( @@ -267,6 +266,7 @@ def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> Union[List[i outputs : The tokenized inputs. """ + import pdb; pdb.set_trace() if isinstance(input, dict): # TODO refactor the input type to make it elegant. 
kwargs.update(input) @@ -329,6 +329,7 @@ def inference(self, inputs, *args, **kwargs): The generated sequence output """ # TODO need to discuss how to handle pad_token_id + import pdb; pdb.set_trace() if self.arch_type == "encoder_decoder": kwargs.update(pad_token_id=self.tokenizer.pad_token_id) elif self.arch_type == "vision_encoder_decoder": diff --git a/src/lmflow/models/vision2seq_model.py b/src/lmflow/models/vision2seq_model.py index 54ec8db5e..244ac875f 100644 --- a/src/lmflow/models/vision2seq_model.py +++ b/src/lmflow/models/vision2seq_model.py @@ -9,7 +9,8 @@ from transformers import ( Blip2ForConditionalGeneration, - Blip2Config + Blip2Config, + AutoModelForCausalLM ) from .base_model import BaseModel @@ -23,22 +24,26 @@ def vision_model_from_pretrained(self, pretrained_path): self.vision_model = self.vision_model.from_pretrained( pretrained_path, config=self.config.vision_config) - def qformer_from_pretrained(self, pretrained_path): - print(self.qformer.encoder.layer[11].output_query.dense.weight.mean()) self.qformer = self.qformer.from_pretrained( pretrained_path, config=self.config.qformer_config) print(self.qformer.encoder.layer[11].output_query.dense.weight.mean()) - def language_model_from_pretrained(self, pretrained_path): + def language_model_from_pretrained(self, pretrained_path, low_resource=False): # TODO remove the low resource related loading in the future - self.language_model = self.language_model.from_pretrained( + if low_resource: + kwargs = dict( + torch_dtype=torch.float16, + load_in_8bit=True, + device_map="auto" + ) + else: + kwargs = {} + self.language_model = AutoModelForCausalLM.from_pretrained( pretrained_path, config=self.config.text_config, - torch_dtype=torch.float16, - device_map="auto") - + **kwargs) @torch.no_grad() def generate( @@ -47,6 +52,7 @@ def generate( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, image_token_indexes: Optional[List] = [0], + one_sample_multiple_images: Optional[bool] = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -59,6 +65,10 @@ def generate( The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices + image_token_indexes (bool, *optional*): + The index for inserting the image tokens. + one_sample_multiple_images: (bool, *optional*): + The flag for inference that the input batch size is 1 and contain multiple images. Returns: captions (list): A list of strings of length batch_size * num_captions. 
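# Aside: a hedged sketch of the calling convention for image_token_indexes, based on the
# inferencer changes at the end of this patch (variable names here are illustrative). The
# prompt is split at the image placeholder, the text before the image is tokenized with
# special tokens, and the length of that prefix is the position where the Q-Former outputs
# get spliced into the text embeddings:
#
#     prefix = tokenizer(text_before_image, return_tensors="pt")
#     suffix = tokenizer(text_after_image, return_tensors="pt", add_special_tokens=False)
#     image_token_indexes = [prefix["input_ids"].shape[1]]
#     input_ids = torch.cat([prefix["input_ids"], suffix["input_ids"]], dim=1)
#     outputs = model.generate(pixel_values=pixel_values,
#                              input_ids=input_ids,
#                              image_token_indexes=image_token_indexes)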
@@ -66,8 +76,10 @@ def generate( if hasattr(self, "hf_device_map"): # preprocess for `accelerate` self._preprocess_accelerate() - - batch_size = pixel_values.shape[0] + if not one_sample_multiple_images: + batch_size = pixel_values.shape[0] + else: + batch_size = 1 image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) @@ -98,27 +110,27 @@ def generate( # concatenate query embeddings with prompt embeddings inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = inputs_embeds.to(language_model_inputs.device) - # concatenate the text embeddings with image embeddings inputs_embeds_with_images = [] attention_mask_with_images = [] # currently we only support with one image assert len(image_token_indexes) == 1 - for image_token_index in image_token_indexes: + for idx, image_token_index in enumerate(image_token_indexes): inputs_embeds_with_images.append(inputs_embeds[:, :image_token_index]) - inputs_embeds_with_images.append(language_model_inputs) + inputs_embeds_with_images.append(language_model_inputs[idx][None]) attention_mask_with_images.append( attention_mask[:, :image_token_index]) - attention_mask_with_images.append(language_attention_mask) + attention_mask_with_images.append(language_attention_mask[idx][None]) inputs_embeds_with_images.append(inputs_embeds[:, image_token_indexes[-1]:]) inputs_embeds = torch.cat(inputs_embeds_with_images, dim=1) attention_mask_with_images.append(attention_mask[:, image_token_indexes[-1]:]) attention_mask = torch.cat(attention_mask_with_images, dim=1) + inputs_embeds = inputs_embeds.to(self.language_model.lm_head.weight.dtype) + attention_mask = attention_mask.to(self.language_model.lm_head.weight.dtype) outputs = self.language_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generate_kwargs, ) - return outputs diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py index f0e6a1a64..df93c4b23 100644 --- a/src/lmflow/pipeline/inferencer.py +++ b/src/lmflow/pipeline/inferencer.py @@ -154,14 +154,18 @@ def inference( else: input = current_batch['input'] input['text'] = prompt_structure.format(input=input['text']) - if remove_image_flag: + # remove the image flag in tokenization; input['text'] = input['text'].split("") new_input = copy.deepcopy(input) new_input['text'] = new_input['text'][-1] input['text'] = input['text'][0] - inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank) - new_inputs = model.encode(new_input, return_tensors="pt").to(device=self.local_rank) + inputs = model.encode(input, + return_tensors="pt", + add_special_tokens=True).to(device=self.local_rank) + new_inputs = model.encode(new_input, + return_tensors="pt", + add_special_tokens=False).to(device=self.local_rank) image_token_indexes = [inputs["input_ids"].shape[1]] inputs["input_ids"] = torch.cat([inputs["input_ids"], new_inputs["input_ids"]], dim=1)
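# Aside: a minimal usage sketch mirroring the multi-round branch of examples/vis_chatbot.py
# above (context and image_list are assumed to have been built as shown there):
#
#     input_dataset = dataset.from_dict({
#         "type": "image_text",
#         "instances": [{"images": image_list, "text": context}],
#     })
#     remove_image_flag = chatbot_args.prompt_format == "mini_gpt"
#     output_dataset = inferencer.inference(model, input_dataset,
#                                           remove_image_flag=remove_image_flag)
#     print(output_dataset.backend_dataset["text"][0])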