From a6b30f96f1583011bf1f38631a61368d10944731 Mon Sep 17 00:00:00 2001
From: rpan
Date: Sat, 22 Jul 2023 17:24:17 +0800
Subject: [PATCH] Support vision chatbot 13b

- Remove unnecessary comments
---
 examples/vis_chatbot_gradio.py        |  2 +-
 output_models/download.sh             | 12 ++++++
 src/lmflow/models/vision2seq_model.py | 61 +--------------------------
 src/lmflow/pipeline/inferencer.py     | 44 -------------------
 4 files changed, 14 insertions(+), 105 deletions(-)

diff --git a/examples/vis_chatbot_gradio.py b/examples/vis_chatbot_gradio.py
index f01083448..9e5e052de 100644
--- a/examples/vis_chatbot_gradio.py
+++ b/examples/vis_chatbot_gradio.py
@@ -216,7 +216,7 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
             new_print_index += 1
             chatbot[-1][1] += char
             chat_state += char
-            time.sleep(0.03)
+            time.sleep(0.05)
             yield chatbot, chat_state, image_list
 
         print_index = new_print_index
diff --git a/output_models/download.sh b/output_models/download.sh
index 153d570ea..b6ee77c8d 100755
--- a/output_models/download.sh
+++ b/output_models/download.sh
@@ -103,6 +103,18 @@ function main() {
         tar zxvf ${filename}
         rm ${filename}
     fi
+
+    if [ "$1" = "minigpt4_7b" -o "$1" = "all" ]; then
+        echo "downloading minigpt4_7b"
+        filename='pretrained_minigpt4_7b.pth'
+        wget ${public_server}/${filename}
+    fi
+
+    if [ "$1" = "minigpt4_13b" -o "$1" = "all" ]; then
+        echo "downloading minigpt4_13b"
+        filename='pretrained_minigpt4_13b.pth'
+        wget ${public_server}/${filename}
+    fi
 }
 
 main "$@"
diff --git a/src/lmflow/models/vision2seq_model.py b/src/lmflow/models/vision2seq_model.py
index 1f5e36c5a..dc67a581a 100644
--- a/src/lmflow/models/vision2seq_model.py
+++ b/src/lmflow/models/vision2seq_model.py
@@ -58,6 +58,7 @@ def language_model_from_pretrained(self,
         self.language_projection = nn.Linear(in_channels,
                                              self.config.text_config.hidden_size,
                                              bias=True)
+
     def register_prompt_cache(self, prompt_ids, prompt_keys_values):
         """
         Udpate the prompt id and embedding for reuse in the future
@@ -133,9 +134,6 @@ def generate(
         Returns:
             captions (list): A list of strings of length batch_size * num_captions.
""" - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: start", flush=True) - if hasattr(self, "hf_device_map"): # preprocess for `accelerate` self._preprocess_accelerate() @@ -144,35 +142,9 @@ def generate( else: batch_size = 1 - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: _preprocess_accelerate end", flush=True) - - # image_id = pixel_values.cpu().numpy().tobytes() - # if image_id in self.cache_dict: - # language_model_inputs = self.cache_dict[image_id] - # else: - # # print("========", pixel_values) - # image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state - # image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - # query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - # query_outputs = self.qformer( - # query_embeds=query_tokens, - # encoder_hidden_states=image_embeds, - # encoder_attention_mask=image_attention_mask, - # return_dict=True, - # ) - # query_output = query_outputs.last_hidden_state - - # language_model_inputs = self.language_projection(query_output) - # self.cache_dict[image_id] = language_model_inputs - image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: image_embeds end", flush=True) - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs = self.qformer( query_embeds=query_tokens, @@ -181,15 +153,8 @@ def generate( return_dict=True, ) query_output = query_outputs.last_hidden_state - - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: query_outputs end", flush=True) - language_model_inputs = self.language_projection(query_output) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: language_model_inputs end", flush=True) - language_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -203,9 +168,6 @@ def generate( attention_mask = torch.ones_like(input_ids) attention_mask = attention_mask.to(language_attention_mask.device) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: attention_mask end", flush=True) - # concatenate query embeddings with prompt embeddings inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = inputs_embeds.to(language_model_inputs.device) @@ -217,9 +179,6 @@ def generate( assert len(image_token_indexes) == pixel_values.shape[0] # token format: (# text, # image)xN, # text - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: input_embeds end", flush=True) - for idx, image_token_index in enumerate(image_token_indexes): end_index += image_token_index inputs_embeds_with_images.append( @@ -230,9 +189,6 @@ def generate( attention_mask_with_images.append(language_attention_mask[idx][None]) start_index = end_index - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: model.generate: xxx_with_images end", flush=True) - inputs_embeds_with_images.append(inputs_embeds[:, image_token_indexes[-1]:]) inputs_embeds = 
         attention_mask_with_images.append(attention_mask[:, image_token_indexes[-1]:])
@@ -241,9 +197,6 @@ def generate(
         inputs_embeds = inputs_embeds.to(self.language_model.lm_head.weight.dtype)
         attention_mask = attention_mask.to(self.language_model.lm_head.weight.dtype)
 
-        # current_time = time.strftime("%H:%M:%S", time.localtime())
-        # print(f"{current_time}: model.generate: llm generate start", flush=True)
-
         if not self.use_prompt_cache or batch_size != 1:
             outputs = self.language_model.generate(
                 inputs_embeds=inputs_embeds,
@@ -265,9 +218,6 @@ def generate(
                 past_key_values = outputs["past_key_values"]
                 self.register_prompt_cache(prompt_ids, past_key_values)
 
-            # current_time = time.strftime("%H:%M:%S", time.localtime())
-            # print(f"{current_time}: model.generate: first llm generate end", flush=True)
-
             prompt_length = self.prompt_id.shape[1]
             if torch.all(input_ids[:, :prompt_length] == self.prompt_id):
                 past_key_values = self.prompt_key_values
@@ -275,9 +225,6 @@ def generate(
                 past_key_values = None
             generate_kwargs["past_key_values"] = past_key_values
 
-            # current_time = time.strftime("%H:%M:%S", time.localtime())
-            # print(f"{current_time}: model.generate: second llm generate start", flush=True)
-
             outputs = self.language_model.generate(
                 inputs_embeds=inputs_embeds[:, prompt_length:],
                 attention_mask=attention_mask[:, prompt_length:],
@@ -286,10 +233,4 @@ def generate(
             )
             outputs = outputs.logits
 
-            # current_time = time.strftime("%H:%M:%S", time.localtime())
-            # print(f"{current_time}: model.generate: second llm generate end", flush=True)
-
-        # current_time = time.strftime("%H:%M:%S", time.localtime())
-        # print(f"{current_time}: model.generate: end", flush=True)
-
         return outputs
diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py
index 02e230b41..4ccc0518f 100644
--- a/src/lmflow/pipeline/inferencer.py
+++ b/src/lmflow/pipeline/inferencer.py
@@ -136,17 +136,8 @@ def inference(
            raise NotImplementedError(
                'input dataset should have type {}'.format(
                    supported_dataset_type))
-
-        # current_time = time.strftime("%H:%M:%S", time.localtime())
-        # print(f"{current_time}: inferencer.inference: start", flush=True)
-        # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)
-
        dataloader, data_size = self.create_dataloader(dataset)
 
-        # current_time = time.strftime("%H:%M:%S", time.localtime())
-        # print(f"{current_time}: inferencer.inference: create_dataloader end", flush=True)
-        # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)
-
        # The output dataset
        output_dict = {
            "type": "text_only",
@@ -189,9 +180,6 @@ def inference(
                    inputs["input_ids"] = torch.cat(input_ids, dim=1)
                    inputs["attention_mask"] = torch.cat(attention_mask, dim=1)
                else:
-                    # current_time = time.strftime("%H:%M:%S", time.localtime())
-                    # print(f"{current_time}: inferencer.inference: model.encode start", flush=True)
-
                    if self.inferencer_args.device == "gpu":
                        inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
                    elif self.inferencer_args.device == "cpu":
@@ -201,17 +189,10 @@ def inference(
                            f"device \"{self.inferencer_args.device}\" is not supported"
                        )
 
-            # current_time = time.strftime("%H:%M:%S", time.localtime())
-            # print(f"{current_time}: inferencer.inference: model.encode end", flush=True)
-            # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)
-
            if remove_image_flag:
                inputs["image_token_indexes"] = image_token_indexes
inputs["one_sample_multiple_images"] = True - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.inference: model.inference start", flush=True) - # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True) outputs = model.inference( inputs, max_new_tokens=max_new_tokens, @@ -219,9 +200,6 @@ def inference( repetition_penalty=self.inferencer_args.repetition_penalty, do_sample=self.inferencer_args.do_sample, ) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.inference: model.inference end", flush=True) - # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True) # only return the generation, trucating the input if self.model_args.arch_type != "vision_encoder_decoder": @@ -238,17 +216,9 @@ def inference( output_dict["instances"].append({ "text": text_out }) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.inference: output dataset prepare start", flush=True) - # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True) - output_dataset = Dataset(DatasetArguments(dataset_path = None)) output_dataset = output_dataset.from_dict(output_dict) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.inference: output dataset prepare end", flush=True) - # print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True) - return output_dataset def stream_inference( @@ -269,9 +239,6 @@ def stream_inference( response = rstrip_partial_utf8(response) yield response, False else: - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.stream_inference: start", flush=True) - for _ in range(0, self.inferencer_args.max_new_tokens // token_per_step): output_dataset = self.inference( model=model, @@ -285,18 +252,10 @@ def stream_inference( new_append_text = rstrip_partial_utf8(new_append_text) response += new_append_text - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.stream_inference: partial inference end", flush=True) - # print(f'type of context text: {type(input_dataset.to_dict()["instances"][0]["text"])}', flush=True) - input_dict = input_dataset.to_dict() input_dict["instances"][0]["text"] += new_append_text input_dataset = input_dataset.from_dict(input_dict) - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.stream_inference: output encapsulation end", flush=True) - # print(f'type of context text: {type(input_dataset.to_dict()["instances"][0]["text"])}', flush=True) - flag_break = False try: index = response.index(end_string) @@ -307,7 +266,4 @@ def stream_inference( response = response[:index] - # current_time = time.strftime("%H:%M:%S", time.localtime()) - # print(f"{current_time}: inferencer.stream_inference: {_} end", flush=True) - # print(f'type of context text: {type(input_dataset.to_dict()["instances"][0]["text"])}', flush=True) yield response, flag_break