Improve inference speed of vis_chatbot_gradio
- Support backend "dict" in datasets
research4pan committed Jul 17, 2023
1 parent 6877f18 commit 284080f
Showing 5 changed files with 136 additions and 29 deletions.
19 changes: 8 additions & 11 deletions examples/vis_chatbot_gradio.py
@@ -3,10 +3,11 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple Multimodal chatbot implemented with lmflow APIs.
"""
from dataclasses import dataclass, field
import asyncio
import logging
import json
import time

from PIL import Image
from lmflow.pipeline.inferencer import Inferencer

@@ -138,7 +139,7 @@ class ChatbotArguments:
)

data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(data_args)
dataset = Dataset(data_args, backend="dict")

inferencer = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
@@ -156,11 +157,6 @@ class ChatbotArguments:
end_string = chatbot_args.end_string
prompt_structure = chatbot_args.prompt_structure


token_per_step = 4



title = """<h1 align="center">Demo of Multi-modality chatbot from LMFlow</h1>"""
description = """<h3>This is the demo of Multi-modality chatbot from LMFlow. Upload your images and start chatting!</h3>"""
# article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
@@ -227,7 +223,7 @@ def gradio_ask(user_message, chatbot, chat_state):
return '', chatbot, chat_state


def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0):
async def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0):
input_dataset = dataset.from_dict({
"type": "image_text",
"instances": [{"images": np.stack([np.array(i) for i in image_list]),
@@ -238,8 +234,8 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
chatbot[-1][1] = ''

print_index = 0
token_per_step = 48
max_new_tokens = 512
token_per_step = 20 # 48
max_new_tokens = 1024
temperature = 0.7

for response, flag_break in inferencer.stream_inference(
@@ -262,8 +258,9 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
new_print_index += 1
chatbot[-1][1] += char
chat_state += char
time.sleep(0.1)
await asyncio.sleep(0.1)
yield chatbot, chat_state, image_list
# await asyncio.sleep(1)

print_index = new_print_index

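The change above turns `gradio_answer` into an async generator: it yields the partially built reply after each chunk and waits with `await asyncio.sleep` instead of `time.sleep`, so the pause between updates no longer blocks the event loop. Below is a minimal sketch of the same pattern, independent of LMFlow, with illustrative names, assuming a Gradio version that accepts async generator callbacks:

```python
import asyncio
import gradio as gr

async def stream_answer(user_message, history):
    # Append the user turn with an empty bot reply, then stream characters into it.
    history = history + [[user_message, ""]]
    reply = "This is a streamed reply."  # stand-in for model output
    for char in reply:
        history[-1][1] += char
        await asyncio.sleep(0.05)        # non-blocking pause between updates
        yield history                    # Gradio redraws the chatbot after each yield

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    textbox.submit(stream_answer, [textbox, chatbot], [chatbot])

# demo.launch()
```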
27 changes: 26 additions & 1 deletion src/lmflow/datasets/dataset.py
@@ -10,8 +10,10 @@


# Importing necessary libraries and modules
from cmath import e
import copy
import json

from cmath import e
from pathlib import Path
from typing import Optional

@@ -229,6 +231,10 @@ def from_dict(self, dict_obj: dict, *args, **kwargs):
)
self._check_data_format()

return self
elif self.backend == "dict":
self.backend_dataset = dict_obj
self.type = dict_obj[KEY_TYPE]
return self
else:
raise NotImplementedError(
@@ -294,13 +300,32 @@ def to_dict(self):
for i in range(num_instances)
]

return dict_obj
elif self.backend == "dict":
dict_obj = self.backend_dataset
return dict_obj
else:
raise NotImplementedError(
f'Current .to_dict is not supported for backend "{backend}"'
)


def to_list(self):
"""Returns a list of instances."""
if self.backend == "huggingface":
instance_list = [self.backend_dataset.__getitem__(idx)
for idx in range(len(self.backend_dataset))]
return instance_list
elif self.backend == "dict":
instance_list = copy.deepcopy(self.backend_dataset[KEY_INSTANCES])
# TODO: should be a list of instances, instance should be huggingface datasets row format
return instance_list
else:
raise NotImplementedError(
f'Current .to_list is not supported for backend "{backend}"'
)


def map(self, *args, **kwargs):
r"""
Parameters
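With the new "dict" backend, `from_dict` keeps the raw dictionary as the backing store, `to_dict` hands it back, and `to_list` returns a copy of its instances, so no Hugging Face dataset has to be built for a single in-memory batch. A rough usage sketch based only on the snippets above (the `DatasetArguments` import path and the `text_only` type string are assumptions):

```python
from lmflow.args import DatasetArguments          # import path assumed
from lmflow.datasets.dataset import Dataset

data_args = DatasetArguments(dataset_path=None)   # nothing loaded from disk
dataset = Dataset(data_args, backend="dict")

dataset = dataset.from_dict({
    "type": "text_only",                          # dataset type string assumed
    "instances": [{"text": "hello"}, {"text": "world"}],
})

print(dataset.to_dict()["type"])   # -> "text_only"
print(dataset.to_list())           # -> [{'text': 'hello'}, {'text': 'world'}]
```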
11 changes: 11 additions & 0 deletions src/lmflow/models/hf_encoder_decoder_model.py
@@ -20,6 +20,7 @@

import copy
import logging
import time
from typing import List, Union

import deepspeed
@@ -332,6 +333,9 @@ def inference(self, inputs, *args, **kwargs):
outputs :
The generated sequence output
"""
# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.inference: start", flush=True)

# TODO need to discuss how to handle pad_token_id
if self.arch_type == "encoder_decoder":
kwargs.update(pad_token_id=self.tokenizer.pad_token_id)
@@ -342,6 +346,9 @@
kwargs.update(**inputs)
inputs = input_ids

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.inference: kwargs update end", flush=True)

with torch.no_grad():
if self.device == "gpu":
outputs = self.ds_engine.module.generate(
@@ -361,6 +368,10 @@
raise NotImplementedError(
f"device \"{self.device}\" is not supported"
)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.inference: end", flush=True)

return outputs


40 changes: 40 additions & 0 deletions src/lmflow/models/vision2seq_model.py
@@ -4,6 +4,7 @@

import copy
import logging
import time
import torch
from typing import List, Optional, Union

@@ -125,6 +126,9 @@ def generate(
Returns:
captions (list): A list of strings of length batch_size * num_captions.
"""
# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: start", flush=True)

if hasattr(self, "hf_device_map"):
# preprocess for `accelerate`
self._preprocess_accelerate()
@@ -133,6 +137,9 @@
else:
batch_size = 1

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: _preprocess_accelerate end", flush=True)

# image_id = pixel_values.cpu().numpy().tobytes()
# if image_id in self.cache_dict:
# language_model_inputs = self.cache_dict[image_id]
@@ -156,6 +163,9 @@
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: image_embeds end", flush=True)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
@@ -165,8 +175,14 @@
)
query_output = query_outputs.last_hidden_state

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: query_outputs end", flush=True)

language_model_inputs = self.language_projection(query_output)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: language_model_inputs end", flush=True)

language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -180,6 +196,9 @@
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(language_attention_mask.device)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: attention_mask end", flush=True)

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(language_model_inputs.device)
@@ -191,6 +210,9 @@
assert len(image_token_indexes) == pixel_values.shape[0]
# token format: (# text, # image)xN, # text

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: input_embeds end", flush=True)

for idx, image_token_index in enumerate(image_token_indexes):
end_index += image_token_index
inputs_embeds_with_images.append(
@@ -201,6 +223,9 @@
attention_mask_with_images.append(language_attention_mask[idx][None])
start_index = end_index

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: xxx_with_images end", flush=True)

inputs_embeds_with_images.append(inputs_embeds[:, image_token_indexes[-1]:])
inputs_embeds = torch.cat(inputs_embeds_with_images, dim=1)
attention_mask_with_images.append(attention_mask[:, image_token_indexes[-1]:])
@@ -209,6 +234,9 @@
inputs_embeds = inputs_embeds.to(self.language_model.lm_head.weight.dtype)
attention_mask = attention_mask.to(self.language_model.lm_head.weight.dtype)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: llm generate start", flush=True)

if not self.use_prompt_cache or batch_size != 1:
outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
@@ -230,13 +258,19 @@
past_key_values = outputs["past_key_values"]
self.register_prompt_cache(prompt_ids, past_key_values)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: first llm generate end", flush=True)

prompt_length = self.prompt_id.shape[1]
if torch.all(input_ids[:, :prompt_length] == self.prompt_id):
past_key_values = self.prompt_key_values
else:
past_key_values = None
generate_kwargs["past_key_values"] = past_key_values

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: second llm generate start", flush=True)

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds[:, prompt_length:],
attention_mask=attention_mask[:, prompt_length:],
Expand All @@ -245,4 +279,10 @@ def generate(
)
outputs = outputs.logits

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: second llm generate end", flush=True)

# current_time = time.strftime("%H:%M:%S", time.localtime())
# print(f"{current_time}: model.generate: end", flush=True)

return outputs
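The commented-out timing prints in the two model files above all follow the same pattern: a timestamped, flushed print at the start and end of each stage. If that instrumentation is ever re-enabled, a small helper along these lines (a sketch only, not part of this commit) keeps each call site to one line:

```python
import time
from contextlib import contextmanager

@contextmanager
def log_stage(name):
    """Print timestamped start/end markers plus the elapsed time for a stage."""
    start = time.time()
    print(f'{time.strftime("%H:%M:%S")}: {name}: start', flush=True)
    try:
        yield
    finally:
        elapsed = time.time() - start
        print(f'{time.strftime("%H:%M:%S")}: {name}: end ({elapsed:.2f}s)', flush=True)

# Example:
# with log_stage("model.generate"):
#     outputs = self.language_model.generate(...)
```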