Support vision chatbot 13b
- Remove unnecessary comments
research4pan committed Jul 22, 2023
1 parent b1475ae commit a6b30f9
Showing 4 changed files with 14 additions and 105 deletions.
2 changes: 1 addition & 1 deletion examples/vis_chatbot_gradio.py
@@ -216,7 +216,7 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
new_print_index += 1
chatbot[-1][1] += char
chat_state += char
time.sleep(0.03)
time.sleep(0.05)
yield chatbot, chat_state, image_list

print_index = new_print_index
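The touched loop implements a typewriter-style streaming effect: each newly generated character is appended to the visible chat message and to the raw chat state, the loop sleeps briefly to pace the animation, and every yield lets Gradio re-render the chatbot. A minimal self-contained sketch of the same pattern (the function and variable names below are illustrative, not taken from the repository):

import time

def stream_reply(chatbot, chat_state, reply, image_list, delay=0.05):
    # Append the reply one character at a time; each yield triggers a Gradio UI refresh.
    for char in reply:
        chatbot[-1][1] += char         # extend the bot message shown in the UI
        chat_state += char             # keep the raw conversation text in sync
        time.sleep(delay)              # per-character pacing; the hunk above adjusts this delay
        yield chatbot, chat_state, image_list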
12 changes: 12 additions & 0 deletions output_models/download.sh
@@ -103,6 +103,18 @@ function main() {
tar zxvf ${filename}
rm ${filename}
fi

if [ "$1" = "minigpt4_7b" -o "$1" = "all" ]; then
echo "downloading minigpt4_7b"
filename='pretrained_minigpt4_7b.pth'
wget ${public_server}/${filename}
fi

if [ "$1" = "minigpt4_13b" -o "$1" = "all" ]; then
echo "downloading minigpt4_13b"
filename='pretrained_minigpt4_13b.pth'
wget ${public_server}/${filename}
fi
}

main "$@"
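Assuming the dispatch in main stays as written, the new checkpoints would be fetched by passing the matching key as the first argument, e.g. bash output_models/download.sh minigpt4_7b, bash output_models/download.sh minigpt4_13b, or bash output_models/download.sh all, each of which downloads the corresponding pretrained_minigpt4_*.pth file from ${public_server} via wget.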
61 changes: 1 addition & 60 deletions src/lmflow/models/vision2seq_model.py
@@ -58,6 +58,7 @@ def language_model_from_pretrained(self,
self.language_projection = nn.Linear(in_channels,
self.config.text_config.hidden_size,
bias=True)

def register_prompt_cache(self, prompt_ids, prompt_keys_values):
"""
Update the prompt id and embedding for reuse in the future
@@ -133,9 +134,6 @@ def generate(
Returns:
captions (list): A list of strings of length batch_size * num_captions.
"""
# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: start", flush=True)

if hasattr(self, "hf_device_map"):
# preprocess for `accelerate`
self._preprocess_accelerate()
@@ -144,35 +142,9 @@
else:
batch_size = 1

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: _preprocess_accelerate end", flush=True)

# image_id = pixel_values.cpu().numpy().tobytes()
# if image_id in self.cache_dict:
# language_model_inputs = self.cache_dict[image_id]
# else:
# # print("========", pixel_values)
# image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
# image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
# query_outputs = self.qformer(
# query_embeds=query_tokens,
# encoder_hidden_states=image_embeds,
# encoder_attention_mask=image_attention_mask,
# return_dict=True,
# )
# query_output = query_outputs.last_hidden_state

# language_model_inputs = self.language_projection(query_output)
# self.cache_dict[image_id] = language_model_inputs

image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: image_embeds end", flush=True)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
@@ -181,15 +153,8 @@
return_dict=True,
)
query_output = query_outputs.last_hidden_state

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: query_outputs end", flush=True)

language_model_inputs = self.language_projection(query_output)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: language_model_inputs end", flush=True)

language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -203,9 +168,6 @@
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(language_attention_mask.device)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: attention_mask end", flush=True)

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(language_model_inputs.device)
@@ -217,9 +179,6 @@
assert len(image_token_indexes) == pixel_values.shape[0]
# token format: (# text, # image)xN, # text

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: input_embeds end", flush=True)

for idx, image_token_index in enumerate(image_token_indexes):
end_index += image_token_index
inputs_embeds_with_images.append(
@@ -230,9 +189,6 @@
attention_mask_with_images.append(language_attention_mask[idx][None])
start_index = end_index

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: xxx_with_images end", flush=True)

inputs_embeds_with_images.append(inputs_embeds[:, image_token_indexes[-1]:])
inputs_embeds = torch.cat(inputs_embeds_with_images, dim=1)
attention_mask_with_images.append(attention_mask[:, image_token_indexes[-1]:])
@@ -241,9 +197,6 @@
inputs_embeds = inputs_embeds.to(self.language_model.lm_head.weight.dtype)
attention_mask = attention_mask.to(self.language_model.lm_head.weight.dtype)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: llm generate start", flush=True)

if not self.use_prompt_cache or batch_size != 1:
outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
@@ -265,19 +218,13 @@
past_key_values = outputs["past_key_values"]
self.register_prompt_cache(prompt_ids, past_key_values)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: first llm generate end", flush=True)

prompt_length = self.prompt_id.shape[1]
if torch.all(input_ids[:, :prompt_length] == self.prompt_id):
past_key_values = self.prompt_key_values
else:
past_key_values = None
generate_kwargs["past_key_values"] = past_key_values

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: second llm generate start", flush=True)

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds[:, prompt_length:],
attention_mask=attention_mask[:, prompt_length:],
Expand All @@ -286,10 +233,4 @@ def generate(
)
outputs = outputs.logits

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: second llm generate end", flush=True)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: end", flush=True)

return outputs
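With the timing prints removed, what remains of generate is the BLIP-2-style pipeline visible in the retained lines: the vision tower encodes the image, the Q-Former compresses it through learned query tokens, a linear projection maps the result into the language model's embedding space, the projected image embeddings are spliced into the text embeddings at each image-token position, and the language model generates from inputs_embeds (optionally reusing cached prompt key/values when batch_size is 1). A condensed sketch of that flow, with attention masks, dtype casts, and the index bookkeeping simplified, and no claim to match the repository exactly:

import torch

def vision_to_text_generate(model, pixel_values, input_ids, image_token_indexes, **generate_kwargs):
    # 1. Encode the image with the vision tower.
    image_embeds = model.vision_model(pixel_values, return_dict=True).last_hidden_state
    image_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

    # 2. Compress the image features through the Q-Former's learned query tokens.
    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_output = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_mask,
        return_dict=True,
    ).last_hidden_state

    # 3. Project the query output into the language model's embedding space.
    language_model_inputs = model.language_projection(query_output)

    # 4. Splice the projected image embeddings into the token embeddings at each
    #    image-token position (positions treated here as absolute offsets; a single
    #    text sample with one or more images is assumed).
    inputs_embeds = model.get_input_embeddings()(input_ids)
    pieces, start = [], 0
    for idx, position in enumerate(image_token_indexes):
        pieces.append(inputs_embeds[:, start:position])
        pieces.append(language_model_inputs[idx][None])
        start = position
    pieces.append(inputs_embeds[:, start:])

    # 5. Generate directly from embeddings (attention masks omitted for brevity).
    return model.language_model.generate(inputs_embeds=torch.cat(pieces, dim=1), **generate_kwargs)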
44 changes: 0 additions & 44 deletions src/lmflow/pipeline/inferencer.py
@@ -136,17 +136,8 @@ def inference(
raise NotImplementedError(
'input dataset should have type {}'.format(
supported_dataset_type))

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: start", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

dataloader, data_size = self.create_dataloader(dataset)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: create_dataloader end", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

# The output dataset
output_dict = {
"type": "text_only",
@@ -189,9 +180,6 @@ def inference(
inputs["input_ids"] = torch.cat(input_ids, dim=1)
inputs["attention_mask"] = torch.cat(attention_mask, dim=1)
else:
# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: model.encode start", flush=True)

if self.inferencer_args.device == "gpu":
inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
elif self.inferencer_args.device == "cpu":
@@ -201,27 +189,17 @@
f"device \"{self.inferencer_args.device}\" is not supported"
)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: model.encode end", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

if remove_image_flag:
inputs["image_token_indexes"] = image_token_indexes
inputs["one_sample_multiple_images"] = True

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: model.inference start", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)
outputs = model.inference(
inputs,
max_new_tokens=max_new_tokens,
temperature=self.inferencer_args.temperature,
repetition_penalty=self.inferencer_args.repetition_penalty,
do_sample=self.inferencer_args.do_sample,
)
# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: model.inference end", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

# only return the generation, truncating the input
if self.model_args.arch_type != "vision_encoder_decoder":
@@ -238,17 +216,9 @@ def inference(

output_dict["instances"].append({ "text": text_out })

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: output dataset prepare start", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

output_dataset = Dataset(DatasetArguments(dataset_path = None))
output_dataset = output_dataset.from_dict(output_dict)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.inference: output dataset prepare end", flush=True)
# print(f'type of context text: {type(dataset.to_dict()["instances"][0]["text"])}', flush=True)

return output_dataset
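Stripped of its commented-out timing prints, each pass through inference reduces to: encode the prompt on the selected device, attach the image-token positions when an image placeholder was removed from the text, call the model, then decode and truncate away the prompt before appending the result to the text_only output dataset. A rough sketch of the per-prompt step under those assumptions (the helper name and argument plumbing are illustrative, not the repository code):

def run_one_prompt(model, prompt_text, inferencer_args, local_rank, max_new_tokens,
                   image_token_indexes=None):
    # Tokenize on the requested device; only "gpu" and "cpu" are handled.
    if inferencer_args.device == "gpu":
        inputs = model.encode(prompt_text, return_tensors="pt").to(device=local_rank)
    elif inferencer_args.device == "cpu":
        inputs = model.encode(prompt_text, return_tensors="pt").to(device="cpu")
    else:
        raise NotImplementedError(f'device "{inferencer_args.device}" is not supported')

    # When image placeholders were stripped from the text, pass their positions along
    # so the vision-to-text model can splice image embeddings back in.
    if image_token_indexes is not None:
        inputs["image_token_indexes"] = image_token_indexes
        inputs["one_sample_multiple_images"] = True

    # Generate; the caller then decodes the output and truncates the echoed prompt.
    return model.inference(
        inputs,
        max_new_tokens=max_new_tokens,
        temperature=inferencer_args.temperature,
        repetition_penalty=inferencer_args.repetition_penalty,
        do_sample=inferencer_args.do_sample,
    )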

def stream_inference(
@@ -269,9 +239,6 @@ def stream_inference(
response = rstrip_partial_utf8(response)
yield response, False
else:
# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.stream_inference: start", flush=True)

for _ in range(0, self.inferencer_args.max_new_tokens // token_per_step):
output_dataset = self.inference(
model=model,
@@ -285,18 +252,10 @@
new_append_text = rstrip_partial_utf8(new_append_text)
response += new_append_text

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.stream_inference: partial inference end", flush=True)
# print(f'type of context text: {type(input_dataset.to_dict()["instances"][0]["text"])}', flush=True)

input_dict = input_dataset.to_dict()
input_dict["instances"][0]["text"] += new_append_text
input_dataset = input_dataset.from_dict(input_dict)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.stream_inference: output encapsulation end", flush=True)
# print(f'type of context text: {type(input_dataset.to_dict()["instances"][0]["text"])}', flush=True)

flag_break = False
try:
index = response.index(end_string)
@@ -307,7 +266,4 @@

response = response[:index]

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: inferencer.stream_inference: {_} end", flush=True)
# print(f'type of context text: {type(input_dataset.to_dict()["instances"][0]["text"])}', flush=True)
yield response, flag_break
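stream_inference emulates token streaming by calling inference repeatedly with a small budget: each pass generates roughly token_per_step tokens, the new text (after stripping any trailing partial UTF-8 bytes) is appended both to the running response and to the prompt dataset for the next pass, and the loop reports completion once the configured end string appears. A compact sketch of that loop; the exact keyword arguments passed to inference are partly elided in the hunk above, so the call below is an assumption:

def stream_generate(inferencer, model, dataset, max_new_tokens, token_per_step, end_string):
    response = ""
    for _ in range(max_new_tokens // token_per_step):
        # Generate a small chunk of tokens from the current conversation state.
        chunk = inferencer.inference(model=model, dataset=dataset,
                                     max_new_tokens=token_per_step)
        new_text = chunk.to_dict()["instances"][0]["text"]
        # (the repository additionally strips trailing partial UTF-8 bytes here)
        response += new_text

        # Feed the extended conversation back in as the prompt for the next chunk.
        d = dataset.to_dict()
        d["instances"][0]["text"] += new_text
        dataset = dataset.from_dict(d)

        # Stop at the end-of-turn marker, yielding only the text that precedes it.
        try:
            cut = response.index(end_string)
            finished = True
        except ValueError:
            cut = len(response)
            finished = False
        yield response[:cut], finished
        if finished:
            break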
