-
Notifications
You must be signed in to change notification settings - Fork 6
/
baseline.py
117 lines (100 loc) · 4.41 KB
/
baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm
def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    trust_remote_code: bool,
) -> Tuple[float, List[int], List[int]]:
    """Run greedy (beam-1, no-sample) generation for each request on one GPU.

    Args:
        requests: List of ``(prompt, prompt_len, output_len)`` tuples;
            ``prompt_len`` is unused here (token counts are re-measured).
        model: HF model name or path, loaded in float16.
        tokenizer: Tokenizer used for both encoding and decoding.
        trust_remote_code: Passed through to ``from_pretrained``.

    Returns:
        A tuple ``(elapsed_seconds, input_num_tokens, output_num_tokens)``
        where the token lists are per-request counts (outputs include the
        prompt tokens, since ``generate`` returns prompt + completion).
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    input_num_tokens: List[int] = []
    output_num_tokens: List[int] = []
    start = time.perf_counter()
    for prompt, _prompt_len, output_len in tqdm(requests):
        input_ids = tokenizer(prompt, return_tensors="pt",
                              padding=True).input_ids
        # Inference only: disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=False,
                num_return_sequences=1,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=output_len,
            )
        # Include the decoding time in the measured interval.
        tokenizer.decode(llm_outputs[0], skip_special_tokens=True)
        input_num_tokens.append(len(input_ids[0]))
        output_num_tokens.append(len(llm_outputs[0]))
    end = time.perf_counter()
    return end - start, input_num_tokens, output_num_tokens
def main(args: argparse.Namespace):
    """Build the request set (synthetic or from a JSON dataset), run the
    HF benchmark, and print throughput statistics.

    Expects ``args`` to carry: seed, tokenizer, trust_remote_code, dataset,
    input_len, output_len, num_samples, model.
    """
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        # NOTE(review): assumes "hi" repeated (args.input_len - 1) times
        # tokenizes to roughly args.input_len tokens — tokenizer-dependent.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_samples)]
    else:
        # Dataset is a JSON list of [prompt, prompt_len, output_len] entries.
        with open(args.dataset) as f:
            requests = json.load(f)
        if args.num_samples is not None:
            requests = requests[:args.num_samples]

    elapsed_time, input_num_tokens, output_num_tokens = run_hf(
        requests, args.model, tokenizer, args.trust_remote_code)
    # Token counts are re-measured by run_hf; output counts include prompts.
    prompt_num_tokens = sum(input_num_tokens)
    total_num_tokens = sum(output_num_tokens)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s \n"
          f"Tokens/s: {total_num_tokens / elapsed_time:.2f} tokens/s \n"
          f"Prompt_num_tokens:{prompt_num_tokens:.2f} tokens \n"
          f"Total_num_tokens:{total_num_tokens:.2f} tokens \n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--dataset", type=str, default=None,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="meta/llama2-70b")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--input-len", type=int, default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len", type=int, default=None,
                        help="Output length for each request")
    parser.add_argument("--num-samples", type=int, default=None,
                        help="Number of first few samples used for inference test")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    # NOTE(review): --dtype is parsed but never consumed in this file;
    # run_hf hard-codes torch.float16. Wire it through or drop the flag.
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate the two mutually-exclusive input modes with real CLI errors:
    # `assert` is stripped under `python -O`, and the original let a missing
    # --num-samples crash later as `range(None)` inside main().
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required "
                         "when --dataset is not given")
        if args.num_samples is None:
            parser.error("--num-samples is required when --dataset "
                         "is not given")
    else:
        if args.input_len is not None or args.output_len is not None:
            parser.error("--input-len/--output-len must not be set "
                         "when --dataset is given")
    main(args)