make mii benchmark support multiple models, some refactoring
mrwyattii committed Jan 13, 2024
1 parent 6c31d8d commit ef8a360
Showing 9 changed files with 472 additions and 254 deletions.
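The main change in the diff below is that the benchmark client no longer hardcodes meta-llama/Llama-2-7b-hf and no longer defines its own argument parser; it now imports parse_args and output_summary from a shared utils module. For orientation, here is a rough sketch of what that shared parser might look like, inferred only from the args.<name> attributes consumed later in this file and the defaults in the removed local parse_args; the actual flag names, defaults, and the client_args switch in utils.py are assumptions, not the committed code.

# Hypothetical sketch of the shared CLI helper this commit factors out into
# utils.py. Flag names mirror the args.<name> attributes used later in this
# file; defaults are carried over from the removed local parse_args and are
# assumptions about the real utils.py.
import argparse
from pathlib import Path

def parse_args(client_args=False):
    parser = argparse.ArgumentParser(description="Benchmark MII services")
    if client_args:
        parser.add_argument("--model", type=str, default="meta-llama/Llama-2-7b-hf")
        parser.add_argument("--deployment_name", type=str, default="benchmark_deployment")
        parser.add_argument("--num_clients", type=int, default=2)
        parser.add_argument("--num_requests", type=int, default=10)
        parser.add_argument("--warmup", type=int, default=1)
        parser.add_argument("--mean_prompt_length", type=int, default=2600)
        parser.add_argument("--mean_max_new_tokens", type=int, default=60)
        parser.add_argument("--max_prompt_length", type=int, default=4000)
        parser.add_argument("--prompt_length_var", type=float, default=0.3)
        parser.add_argument("--max_new_tokens_var", type=float, default=0.3)
        parser.add_argument("--stream", action="store_true", default=True)
        parser.add_argument("--vllm", action="store_true", default=False)
        parser.add_argument("--use_thread", action="store_true", default=False)
        parser.add_argument("-o", "--out_json_path", type=Path, default=None)
    return parser.parse_args()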
Empty file.
@@ -1,15 +1,10 @@
import os
import time
import random
import argparse
import queue
import multiprocessing
import threading
from statistics import mean
from dataclasses import dataclass, asdict
from typing import List, Iterable
from pathlib import Path
from datetime import datetime
import numpy as np

from transformers import AutoTokenizer
@@ -20,51 +15,9 @@
import asyncio
import requests

from postprocess_results import get_summary, ResponseDetails

MAX_PROMPT_LENGTH = 4000
PROMPT_LENGTH_VAR = 0.3
MAX_NEW_TOKENS_VAR = 0.3

def parse_args():
parser = argparse.ArgumentParser(description="Benchmark MII services")
parser.add_argument("-k",
"--max_new_tokens",
type=int,
default=60,
help="min and max num tokens argument for huggingface")
parser.add_argument("-d",
"--deployment_name",
type=str,
default="benchmark_deployment")
parser.add_argument("-n",
"--num_queries",
type=int,
help="number of queries to run",
default=10)
parser.add_argument("-w",
"--warmup",
type=int,
help="number of queries for warming up",
default=1)
parser.add_argument("-c",
"--client_num",
type=int,
help="number of parallel client processes",
default=2)
parser.add_argument("-l",
"--prompt_length",
type=int,
default=2600)
parser.add_argument('--use_thread', action='store_true',
help='use thread to run parallel clients, otherwise use multiprocessing',
default=False)
parser.add_argument('--stream', action='store_true', default=True)
parser.add_argument('--vllm', action='store_true', default=False)
parser.add_argument('-o', '--out_json_path', type=Path, default=None)

args = parser.parse_args()
return args
from postprocess_results import ResponseDetails

from utils import parse_args, output_summary


def call_mii(client, input_tokens, max_new_tokens, stream):
@@ -85,11 +38,10 @@ def callback(response):
if stream:
output_tokens = []
client.generate(
input_tokens, max_new_tokens=max_new_tokens,
streaming_fn=callback)
input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback
)
else:
result = client.generate(
input_tokens, max_new_tokens=max_new_tokens)
result = client.generate(input_tokens, max_new_tokens=max_new_tokens)
output_tokens = result[0].generated_text

return ResponseDetails(
@@ -98,7 +50,8 @@ def callback(response):
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time)
token_gen_time=token_gen_time,
)


def call_vllm(input_tokens, max_new_tokens, stream=True):
@@ -114,15 +67,19 @@ def call_vllm(input_tokens, max_new_tokens, stream=True):
"ignore_eos": False,
"stream": stream,
}

def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K'
LINE_UP = "\033[1A"
LINE_CLEAR = "\x1b[2K"
for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)

def get_streaming_response(response: requests.Response, time_last_token) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
delimiter=b"\0"):
def get_streaming_response(
response: requests.Response, time_last_token
) -> Iterable[List[str]]:
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
@@ -149,20 +106,31 @@ def get_response(response: requests.Response) -> List[str]:
start_time=start_time,
end_time=time.time(),
model_time=0,
token_gen_time=token_gen_time)
token_gen_time=token_gen_time,
)
else:
output = get_response(response)
raise NotImplementedError("Not implemented for non-streaming")


def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, client_num, stream, vllm):
def _run_parallel(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
):
pid = os.getpid()
session_id = f"test_session_p{pid}_t{threading.get_ident()}"

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
if not vllm:
import mii

client = mii.client(deployment_name)

barrier.wait()
@@ -178,7 +146,7 @@ def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, c

barrier.wait()

time.sleep(random.uniform(0, client_num) * 0.01)
time.sleep(random.uniform(0, num_clients) * 0.01)
try:
while not query_queue.empty():
print(f"queue size: {query_queue.qsize()} ({pid})", flush=True)
@@ -197,16 +165,30 @@ def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_q
print(f"Worker ({pid}) finished. session_id: {session_id}")


def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_queries, warmup, stream, vllm, use_thread=False):
def run_client(
num_clients,
model,
deployment_name,
mean_prompt_length,
mean_max_new_tokens,
num_requests,
warmup,
max_prompt_length,
prompt_length_var,
max_new_tokens_var,
stream,
vllm,
use_thread,
):
"""
Run MII client for benchmarking. The scenario is a bit complicated:
1. The main process puts `num_queries` queries into the input queue
1. The main process puts `num_requests` queries into the input queue
2. Each client runs `warmup` iterations, taking the queries from the input queue
3. --- barrier ---
4. The main process marks the start time
5a. All clients send `num_queries' query in total and put the results into the result queue
5a. All clients send `num_requests` queries in total and put the results into the result queue
5b. The main process takes the results from the result queue (in parallel with 5a)
6. The main process marks the end time after receiving `num_queries' results
6. The main process marks the end time after receiving `num_requests` results
"""

if use_thread:
@@ -218,23 +200,44 @@ def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_q
barrier_cls = multiprocessing.Barrier
queue_cls = multiprocessing.Queue

barrier = barrier_cls(client_num + 1)
barrier = barrier_cls(num_clients + 1)
query_queue = queue_cls()
result_queue = queue_cls()

processes = [runnable_cls(target=_run_parallel,
args=(deployment_name, warmup, barrier, query_queue, result_queue, client_num, stream, vllm))
for i in range(client_num)]
processes = [
runnable_cls(
target=_run_parallel,
args=(
deployment_name,
warmup,
barrier,
query_queue,
result_queue,
num_clients,
stream,
vllm,
),
)
for i in range(num_clients)
]
for p in processes:
p.start()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained(model)
query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42)
MAX_PROMPT_LENGTH = 4000
request_text = query_generator.get_random_request_text(prompt_length, prompt_length*PROMPT_LENGTH_VAR, MAX_PROMPT_LENGTH, num_queries + warmup*client_num)
request_text = query_generator.get_random_request_text(
mean_prompt_length,
mean_prompt_length * prompt_length_var,
max_prompt_length,
num_requests + warmup * num_clients,
)

for t in request_text:
req_max_new_tokens = int(np.random.normal(max_new_tokens, MAX_NEW_TOKENS_VAR*max_new_tokens))
req_max_new_tokens = int(
np.random.normal(
mean_max_new_tokens, max_new_tokens_var * mean_max_new_tokens
)
)
query_queue.put((t, req_max_new_tokens))

# Tokenizers must be initialized after fork.
@@ -245,41 +248,37 @@ def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_q
barrier.wait()

response_details = []
while len(response_details) < num_queries:
while len(response_details) < num_requests:
res = result_queue.get()
# vLLM returns concatenated tokens
if vllm:
all_tokens = tokenizer.tokenize(res.generated_tokens)
res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)):]
res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :]
response_details.append(res)

return response_details


if __name__ == "__main__":
args = parse_args()
print(args)
args = parse_args(client_args=True)

if args.out_json_path is not None and not args.out_json_path.parent.exists():
raise ValueError(f"Parent directory of {args.out_json_path} does not exist")

response_details = run_client(args.client_num, args.deployment_name,
args.prompt_length,
args.max_new_tokens, args.num_queries, args.warmup,
args.stream, args.vllm, args.use_thread)

args_dict = vars(args)
ps = get_summary(args_dict, response_details)
print(f"Deployment: {args.deployment_name} Clients: {args.client_num}, "
+ f"Prompt (mean): {args.prompt_length} tokens, "
+ f"Generation (mean): {args.max_new_tokens} tokens, "
+ f"Query throughput: {ps.throughput:.3f} queries/s, "
+ f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, "
+ f"Query latency: {ps.latency:.3f} s, "
+ f"Token generation latency: {ps.token_gen_latency:.3f} s/token, "
+ f"First token received: {ps.first_token_latency:.3f} s")

if args.out_json_path is not None:
with open(args.out_json_path, "w") as f:
args_dict["out_json_path"] = str(args.out_json_path) # Path is not JSON serializable
data = {"args": args_dict, "time": str(datetime.now()), "response_details": [asdict(r) for r in response_details]}
json.dump(data, f, indent=2)
response_details = run_client(
num_clients=args.num_clients,
model=args.model,
deployment_name=args.deployment_name,
mean_prompt_length=args.mean_prompt_length,
mean_max_new_tokens=args.mean_max_new_tokens,
num_requests=args.num_requests,
warmup=args.warmup,
max_prompt_length=args.max_prompt_length,
prompt_length_var=args.prompt_length_var,
max_new_tokens_var=args.max_new_tokens_var,
stream=args.stream,
vllm=args.vllm,
use_thread=args.use_thread,
)

output_summary(args, response_details)
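The per-run summary printing and JSON dump that used to sit at the bottom of this script (visible in the removed lines above) are now delegated to utils.output_summary. A rough reconstruction of what that helper presumably does, pieced together from the removed code; the actual implementation in utils.py may differ.

# Hypothetical reconstruction of utils.output_summary based on the removed
# summary/JSON code above; details of the real helper may differ.
import json
from dataclasses import asdict
from datetime import datetime

from postprocess_results import get_summary

def output_summary(args, response_details):
    args_dict = vars(args)
    ps = get_summary(args_dict, response_details)
    print(
        f"Deployment: {args.deployment_name} Clients: {args.num_clients}, "
        f"Prompt (mean): {args.mean_prompt_length} tokens, "
        f"Generation (mean): {args.mean_max_new_tokens} tokens, "
        f"Query throughput: {ps.throughput:.3f} queries/s, "
        f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, "
        f"Query latency: {ps.latency:.3f} s, "
        f"Token generation latency: {ps.token_gen_latency:.3f} s/token, "
        f"First token received: {ps.first_token_latency:.3f} s"
    )
    if args.out_json_path is not None:
        args_dict["out_json_path"] = str(args.out_json_path)  # Path is not JSON serializable
        data = {
            "args": args_dict,
            "time": str(datetime.now()),
            "response_details": [asdict(r) for r in response_details],
        }
        with open(args.out_json_path, "w") as f:
            json.dump(data, f, indent=2)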