Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix chunked prefill #766

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 39 additions & 60 deletions lightllm/models/internvl/img_process.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,56 @@
import torch
import torch.nn.functional as F
from PIL import Image
import math
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode


def find_closest_aspect_ratio(width, height, min_num=1, max_num=6, image_size=448):
    """Choose the tiling grid ``[cols, rows]`` that best matches the image's aspect ratio.

    The image is meant to be split into ``cols * rows`` square tiles of side
    ``image_size``.  The candidate tile counts are taken around
    ``ceil(area / image_size**2)`` (clamped to ``max_num``), and among all
    factorizations of those counts the grid whose log-aspect-ratio is closest
    to the image's is returned.

    :param width: original image width in pixels
    :param height: original image height in pixels
    :param min_num: minimum number of tiles; only 1 is supported (asserted)
    :param max_num: maximum number of tiles
    :param image_size: side length of one square tile
    :return: ``[cols, rows]`` tiling grid
    """
    assert min_num == 1
    log_ratio = math.log(width / height)
    # Estimate how many image_size x image_size tiles are needed to cover the area.
    ratio = width * height / (image_size * image_size)
    multiple = min(math.ceil(ratio), max_num)
    if multiple <= 1:
        return [1, 1]

    # Consider tile counts adjacent to the estimate, discarding any above max_num.
    candidate_split_grids_nums = []
    for i in [multiple - 1, multiple, multiple + 1]:
        if i > max_num:
            continue
        candidate_split_grids_nums.append(i)

    # Enumerate every factorization m x (n // m) of each candidate tile count.
    candidate_grids = []
    for split_grids_nums in candidate_split_grids_nums:
        m = 1
        while m <= split_grids_nums:
            if split_grids_nums % m == 0:
                candidate_grids.append([m, split_grids_nums // m])
            m += 1

    # Pick the grid whose aspect ratio is closest to the image's, compared in
    # log space so that e.g. 2:1 and 1:2 are penalized symmetrically.
    best_grid = [1, 1]
    min_error = float("inf")
    for grid in candidate_grids:
        error = abs(log_ratio - math.log(grid[0] / grid[1]))
        if error < min_error:
            best_grid = grid
            min_error = error

    return best_grid


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
    """
    Preprocess the image dynamically by finding the closest aspect ratio,
    resizing the image, and splitting it into smaller blocks.
    Optionally add a thumbnail version of the image.

    :param image: PIL-style image exposing ``size``, ``resize`` and ``crop``
    :param min_num: minimum number of tiles (forwarded to the grid search)
    :param max_num: maximum number of tiles (forwarded to the grid search)
    :param image_size: side length of each square tile
    :param use_thumbnail: when True and more than one tile was produced,
        append a full-image ``image_size`` x ``image_size`` thumbnail
    :return: list of tile images (plus optional trailing thumbnail)
    """
    original_width, original_height = image.size
    target_aspect_ratio = find_closest_aspect_ratio(original_width, original_height, min_num, max_num, image_size)
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        # Tile i maps to column (i % cols) and row (i // cols) of the grid.
        # NOTE(review): the first two box coordinates were collapsed in the
        # diff view; reconstructed from the visible (+1)-offset pattern below
        # — verify against the repository file.
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
"""
Calculate the number of image patches based on the closest aspect ratio
and the given width and height of the original image.
"""
aspect_ratio = orign_width / orign_height

# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orign_width, orign_height, image_size)

target_aspect_ratio = find_closest_aspect_ratio(orign_width, orign_height, min_num, max_num, image_size)
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and blocks != 1:
blocks += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@


class ChunkedPrefillBackend(ModeBackend):
def __init__(self) -> None:
def __init__(self, is_multimodal) -> None:
super().__init__()
self.is_multimodal = is_multimodal
self.forward_step = 0
args = get_env_start_args()
self.max_wait_step = args.router_max_wait_tokens
Expand All @@ -31,7 +32,7 @@ def decode(self):
self.forward_batch(kwargs, run_reqs)
if len(run_reqs) == 0 or self.forward_step % self.max_wait_step == 0:
# run prefill
kwargs, run_reqs = prepare_prefill_inputs(g_infer_context.infer_req_ids)
kwargs, run_reqs = prepare_prefill_inputs(g_infer_context.infer_req_ids, self.is_multimodal)
self.forward_batch(kwargs, run_reqs)
self.forward_step += 1
return
Expand Down
3 changes: 2 additions & 1 deletion lightllm/server/router/model_infer/model_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,14 @@ def init_model(self, kvargs):
is_xgrammar_constraint_mode = False
is_prefill_node = False
is_decode_node = False
is_multimodal = kvargs.get("enable_multimodal", False)
# use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
if is_prefill_node:
self.backend = ContinuesBatchBackendForPrefillNode(self.info_queue, self.mem_queue)
elif is_decode_node:
self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue)
elif enable_chunked_prefill:
self.backend = ChunkedPrefillBackend()
self.backend = ChunkedPrefillBackend(is_multimodal)
elif use_reward_model:
self.backend = RewardModelBackend()
elif return_all_prompt_logprobs:
Expand Down