diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index 04f178504..8cdf5644d 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -258,8 +258,17 @@ def main(): # the LN that precedes it. force_optimize_params = [] if "bigscience/bloom-" in args.model_name_or_path: - torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) - torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) + zero_init_enabled = (args.zero_stage == 3) + params = [ + rm_model.rwtranrsformer.ln_f.weight, + rm_model.rwtranrsformer.ln_f.bias + ] + with deepspeed.zero.GatheredParameters(params, + modifier_rank=0, + enabled=zero_init_enabled): + if deepspeed.comm.get_rank() == 0 or not zero_init_enabled: + torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) + torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) force_optimize_params.extend( ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']) diff --git a/benchmarks/inference/mii/README.md b/benchmarks/inference/mii/README.md index d9e475cdb..092ac4867 100644 --- a/benchmarks/inference/mii/README.md +++ b/benchmarks/inference/mii/README.md @@ -2,38 +2,59 @@ ## Run the Benchmark -The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. -You can start the server with the command below: +The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. You can +run the benchmark using `run_benchmark.py`. This script will run several +combinations of inference servers and clients with different tensor parallel +size, number of model replicas (MII only), number of clients, prompt length, and +max new tokens values. By default, the benchmark will run with the `meta-llama/Llama-2-7b-hf` model. ```bash -python server.py [options] start +python run_benchmark.py ``` -Use the -h option to view all available options. To stop the server, use this command: +Use the -h option to view all available options. Several models have pre-defined +default values, including `meta-llama/Llama-2-{7|13|70}b-hf`, +`tiiuae/falcon-{40|180}B`, `microsoft/phi-2`, and `mistralai/Mixtral-8x7B-v0.1`. +These defaults can be overridden if provided to the `run_benchmark.py` script. +For example, to run `meta-llama/Llama-13b-hf` with a tensor parallel size of `1` +and `2` (instead of the default `1`, `2`, and `4`): -```bash -python server.py stop +```bash +python run_benchmark.py --tp_size 1 2 ``` -Once the server is up and running, initiate the client using the command below. The -h option will display all the possible options. +By default the benchmark runs with DeepSpeed-MII as the backend inference +server. To change the backend to vLLM, provide the `--vllm` flag: ```bash -python run_benchmark_client.py [options] +python run_benchmark.py --vllm ``` -The run_all.sh script performs benchmarks across various model sizes and client numbers. For VLLM benchmarks, use the run_all_vllm.sh script. Results are logged in a directory named logs.[BENCHMARK_PARAMETERS]. +The run_all.sh script performs benchmarks across various models, client numbers, +tensor parallel sizes, etc. This script is intended to be run on a system with +8xA100 (80GB) GPUs available. It will run all the benchmarks (including vLLM) +and collect the data used in our [DeepSpeed-Fastgen +blogs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen). +Results are collected in `./results/`. ## Analyze the Benchmark Results -The scripts mentioned below were used for generating the plots featured in our blog. Specify the root directory for log files using --log_dir. +The scripts mentioned below were used for generating the plots featured in our +blog. Specify the root directory for log files using `--log_dir`. The generated +figures will be saved to `./plots/` -- `plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts. -- `plot_effective_throughput.py`: Use this to chart effective throughput. -- `plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies. +- `src/plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts. +- `src/plot_effective_throughput.py`: Use this to chart effective throughput. +- `src/plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies. ## Running an End-to-End Example -To quickly experience the end-to-end process of running our benchmark and getting results, you can use the `run_example.sh`. This script is designed to execute the benchmark with a specific configuration. The plots below will be generated in the charts directory. These plots show the performance as depicted in figure 8 of our blog [post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms) +To quickly experience the end-to-end process of running our benchmark and +getting results, you can use the `run_example.sh`. This script is designed to +execute the benchmark with a specific configuration. The plots below will be +generated in the `./plots/` directory. These plots show the performance as +depicted in figure 8 of our blog +[post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms) ```bash bash run_example.sh diff --git a/benchmarks/inference/mii/plot_th_lat.py b/benchmarks/inference/mii/plot_th_lat.py deleted file mode 100644 index e99dc5a3e..000000000 --- a/benchmarks/inference/mii/plot_th_lat.py +++ /dev/null @@ -1,116 +0,0 @@ -import glob -import matplotlib.pyplot as plt -import argparse -from pathlib import Path -import numpy as np -import pdb -from postprocess_results import read_json, get_summary - -bs = 768 - -tp_sizes_test = { - "7b": [1] -} - -tp_sizes_all = { - "7b": [1], - "70b": [4, 8], -} - -prompt_gen_pairs_test = [ - (2600, 60) -] - -prompt_gen_pairs_all = [ - (1200, 60), - (1200, 128), - (2600, 60), - (2600, 128), -] - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--test", action="store_true") - parser.add_argument("--no_vllm", action="store_true") - parser.add_argument("--log_dir", type=Path, default=".") - parser.add_argument("--out_dir", type=Path, default="charts/throughput_latency") - args = parser.parse_args() - return args - - -def extract_values(file_pattern): - files = glob.glob(file_pattern) - - print(f"Found {len(files)}") - print('\n'.join(files)) - - clients = [] - throughputs = [] - latencies = [] - for f in files: - prof_args, response_details = read_json(f) - summary = get_summary(prof_args, response_details) - clients.append(prof_args["client_num"]) - throughputs.append(summary.throughput) - latencies.append(summary.latency) - - return clients, throughputs, latencies - - -def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): - if not log_dir.exists(): - print(f"Log directory {log_dir} does not exist") - return - - if not out_dir.exists(): - out_dir.mkdir(parents=True, exist_ok=True) - - mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}/llama2-{model_size}-tp{tp}-b{bs}_c*_p{prompt}_g{gen}.json" - if not args.no_vllm: - vllm_file_pattern = f"{log_dir}/logs.vllm-llama2-{model_size}-tp{tp}/vllm-llama2-{model_size}-tp{tp}_c*_p{prompt}_g{gen}.json" - - _, mii_throughputs, mii_latencies = extract_values(mii_file_pattern) - if not args.no_vllm: - _, vllm_throughputs, vllm_latencies = extract_values(vllm_file_pattern) - - # Plotting the scatter plot - plt.figure(figsize=(6, 4)) - - if not args.no_vllm: - plt.scatter(vllm_throughputs, vllm_latencies, label=f"vLLM", marker="x", color="orange") - fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01) - vllm_vllm_model = np.polyfit(vllm_throughputs, vllm_latencies, 3) - vllm_model_fn = np.poly1d(vllm_vllm_model) - plt.plot(fit_vllm_x_list, vllm_model_fn(fit_vllm_x_list), color="orange", alpha=0.5, linestyle="--") - - plt.scatter(mii_throughputs, mii_latencies, label=f"DeepSpeed FastGen", marker="o", color="blue") - fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01) - mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3) - mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_mii_x_list, mii_model_fn(fit_mii_x_list), color="blue", alpha=0.5, linestyle="--") - - plt.title(f'Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tp}') - plt.xlabel('Throughput (queries/s)', fontsize=14) - plt.ylabel('Latency', fontsize=14) - plt.legend() - plt.grid(True) - plt.tight_layout() - out_file = out_dir / f"th_lat_curve_llama{model_size}_tp{tp}_p{prompt}g{gen}.png" - print(f"Saving {out_file}") - plt.savefig(out_file) - - -if __name__ == "__main__": - args = get_args() - if args.test: - tp_sizes = tp_sizes_test - prompt_gen_pairs = prompt_gen_pairs_test - else: - tp_sizes = tp_sizes_all - prompt_gen_pairs = prompt_gen_pairs_test_all - - for model_size, tps in tp_sizes.items(): - for tp in tps: - for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/requirements.txt b/benchmarks/inference/mii/requirements.txt new file mode 100644 index 000000000..7ac014ef8 --- /dev/null +++ b/benchmarks/inference/mii/requirements.txt @@ -0,0 +1,5 @@ +transformers +matplotlib +deepspeed-mii>=0.2.0 +vllm>=0.2.7 +numpy \ No newline at end of file diff --git a/benchmarks/inference/mii/run_all.sh b/benchmarks/inference/mii/run_all.sh index ca504a6c9..095b3ae12 100644 --- a/benchmarks/inference/mii/run_all.sh +++ b/benchmarks/inference/mii/run_all.sh @@ -1,25 +1,15 @@ -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b 13b 70b) +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 -declare -A TP_SIZES -TP_SIZES["7b"]="1" -TP_SIZES["13b"]="1:2:4" -TP_SIZES["70b"]="4:8" +# DeepSpeed Team -for PARAM_SIZE in ${PARAM_SIZES[@]}; do - - IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]} - for TP in ${TP_VALUES[@]}; do - DEPLOYMENT_NAME=llama2-${PARAM_SIZE}-tp${TP}-b${RAGGED_BATCH_SIZE} - python server.py --model_name meta-llama/Llama-2-${PARAM_SIZE}-hf -d ${DEPLOYMENT_NAME} -m ${TP} -b ${RAGGED_BATCH_SIZE} start +MODELS=(meta-llama/Llama-2-7b-hf meta-llama/Llama-2-13b-hf meta-llama/Llama-2-70b-hf tiiuae/falcon-40B tiiuae/falcon-180B microsoft/phi-2 mistralai/Mixtral-8x7B-v0.1) - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=128 bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=128 bash ./run_benchmark_client.sh - - echo "Stopping server" - python server.py -d ${DEPLOYMENT_NAME} stop - sleep 120 - done +for MODEL in ${MODELS[@]}; do + python ./run_benchmark.py --model ${MODEL} --stream + python ./run_benchmark.py --model ${MODEL} --stream --vllm done + +# Extra runs for Mixtral with non-default settings +python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 +python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --vllm \ No newline at end of file diff --git a/benchmarks/inference/mii/run_all_replica.sh b/benchmarks/inference/mii/run_all_replica.sh deleted file mode 100644 index b3fba0408..000000000 --- a/benchmarks/inference/mii/run_all_replica.sh +++ /dev/null @@ -1,25 +0,0 @@ -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b) -REPLICA_NUMS=(1) - -declare -A TP_SIZES -TP_SIZES["7b"]="4" -TP_SIZES["13b"]="1" -TP_SIZES["70b"]="4" - -for PARAM_SIZE in ${PARAM_SIZES[@]}; do - IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]} - for TP in ${TP_VALUES[@]}; do - for REPL in ${REPLICA_NUMS[@]}; do - DEPLOYMENT_NAME=llama2-${PARAM_SIZE}-tp${TP}-b${RAGGED_BATCH_SIZE}_repl${REPL} - python server.py --model_name meta-llama/Llama-2-${PARAM_SIZE}-hf -d ${DEPLOYMENT_NAME} -m ${TP} -r ${REPL} -b ${RAGGED_BATCH_SIZE} start - - REQUEST_NUM=$((256 * ${REPL})) - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 CLIENT_NUMS=$((16 * ${REPL})) REQUEST_NUM=$((256 * ${REPL})) bash ./run_bench_client_num.sh - - echo "Stopping server" - python server.py -d ${DEPLOYMENT_NAME} stop - sleep 120 - done - done -done diff --git a/benchmarks/inference/mii/run_all_vllm.sh b/benchmarks/inference/mii/run_all_vllm.sh deleted file mode 100644 index 572377f13..000000000 --- a/benchmarks/inference/mii/run_all_vllm.sh +++ /dev/null @@ -1,26 +0,0 @@ -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b 13b 70b) - -declare -A TP_SIZES -TP_SIZES["7b"]="1" -TP_SIZES["13b"]="1:2:4" -TP_SIZES["70b"]="4:8" - -for PARAM_SIZE in ${PARAM_SIZES[@]}; do - - IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]} - for TP in ${TP_VALUES[@]}; do - DEPLOYMENT_NAME=vllm-llama2-${PARAM_SIZE}-tp${TP} - python -m vllm.entrypoints.api_server --host 127.0.0.1 --port 26500 --tensor-parallel-size ${TP} --model meta-llama/Llama-2-${PARAM_SIZE}-hf & - sleep 60 - - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 VLLM="--vllm" bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=128 VLLM="--vllm" bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=60 VLLM="--vllm" bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=128 VLLM="--vllm" bash ./run_benchmark_client.sh - - echo "Stopping server" - pkill -u ${USER} -f vllm.entrypoints.api_server - sleep 30 - done -done diff --git a/benchmarks/inference/mii/run_benchmark.py b/benchmarks/inference/mii/run_benchmark.py new file mode 100644 index 000000000..96e88155f --- /dev/null +++ b/benchmarks/inference/mii/run_benchmark.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from src.client import run_client +from src.server import start_server, stop_server +from src.utils import ( + get_args_product, + parse_args, + print_summary, + results_exist, + save_json_results, + CLIENT_PARAMS, + SERVER_PARAMS, +) + + +def run_benchmark() -> None: + args = parse_args(server_args=True, client_args=True) + + for server_args in get_args_product(args, which=SERVER_PARAMS): + start_server(server_args) + + for client_args in get_args_product(server_args, which=CLIENT_PARAMS): + if results_exist(client_args) and not args.overwrite_results: + print( + f"Found existing results and skipping current setting. To ignore existing results, use --overwrite_results" + ) + continue + + response_details = run_client(client_args) + print_summary(client_args, response_details) + save_json_results(client_args, response_details) + + stop_server(server_args) + + +if __name__ == "__main__": + run_benchmark() diff --git a/benchmarks/inference/mii/run_benchmark_client.sh b/benchmarks/inference/mii/run_benchmark_client.sh deleted file mode 100644 index 318e9092e..000000000 --- a/benchmarks/inference/mii/run_benchmark_client.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -DEPLOYMENT_NAME=${DEPLOYMENT_NAME:-llama2-7b} -VLLM=${VLLM:-""} - -CLIENT_NUMS=${CLIENT_NUMS:-1 2 4 6 8 12 16 20 24 28 32} -MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-60} -PROMPT_LENGTH=${PROMPT_LENGTH:-3072} -REQUEST_NUM=${REQUEST_NUM:-512} - -LOG_DIR=logs.${DEPLOYMENT_NAME} -mkdir -p ${LOG_DIR} - -for client_num in ${CLIENT_NUMS[@]}; do - RESULT_FILE=${DEPLOYMENT_NAME}_c${client_num}_p${PROMPT_LENGTH}_g${MAX_NEW_TOKENS}.json - - python run_benchmark_client.py -w 1 \ - -d ${DEPLOYMENT_NAME} -n ${REQUEST_NUM} -c ${client_num} \ - -k ${MAX_NEW_TOKENS} -l ${PROMPT_LENGTH} \ - -o ${LOG_DIR}/${RESULT_FILE} \ - ${VLLM} --stream \ - 2>&1 | tee ${LOG_DIR}/bench_client_num_c${client_num}_p${PROMPT_LENGTH}_g${MAX_NEW_TOKENS}.log -done diff --git a/benchmarks/inference/mii/run_example.sh b/benchmarks/inference/mii/run_example.sh index ece8393ed..e80253828 100644 --- a/benchmarks/inference/mii/run_example.sh +++ b/benchmarks/inference/mii/run_example.sh @@ -1,19 +1,19 @@ -### Run the server -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b) -DEPLOYMENT_NAME=llama2-7b-tp1-b768 -python server.py --model_name meta-llama/Llama-2-7b-hf -d llama2-7b-tp1-b768 -m 1 -b 768 start +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 -### This command will run the client with 60 generation steps and input prompt length of 2600 -DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh +# DeepSpeed Team -### Stop the server -echo "Stopping server" -python server.py -d ${DEPLOYMENT_NAME} stop -sleep 120 +# Run benchmark +python ./run_benchmark.py \ + --model meta-llama/Llama-2-7b-hf \ + --tp_size 1 \ + --num_replicas 1 \ + --max_ragged_batch_size 768 \ + --mean_prompt_length 2600 \ + --mean_max_new_tokens 60 \ + --stream ### Gernerate the plots -python plot_th_lat.py --log_dir . --test --no_vllm -python plot_effective_throughput.py --log_dir . --test --no_vllm +python ./src/plot_th_lat.py -echo "Find the plots in the charts directory and the logs inside logs.llama2-7b-tp1-b768" +echo "Find figures in ./plots/ and log outputs in ./results/" \ No newline at end of file diff --git a/benchmarks/inference/mii/server.py b/benchmarks/inference/mii/server.py deleted file mode 100644 index 2e6164187..000000000 --- a/benchmarks/inference/mii/server.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team -import mii -import argparse - -from mii.constants import DeploymentType - -from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig -from deepspeed.inference.v2.ragged import DSStateManagerConfig - -def start_server(model_name, - deployment_name, - task, - tensor_parallel, - replica_num, - max_ragged_batch_size): - tp_config = DeepSpeedTPConfig(tp_size=tensor_parallel) - mgr_config = DSStateManagerConfig(max_ragged_batch_size=max_ragged_batch_size, max_ragged_sequence_count=max_ragged_batch_size) - inference_config = RaggedInferenceEngineConfig(tensor_parallel=tp_config, - state_manager=mgr_config) - - mii.serve( - model_name, - deployment_name=deployment_name, - tensor_parallel=tensor_parallel, - task=task, - inference_engine_config=inference_config, - replica_num=replica_num - ) - -def stop_server(deployment_name): - mii.client(deployment_name).terminate_server() - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model_name", - type=str, - default="meta-llama/Llama-2-7b-hf", - help="Name of the model in the model_files to benchmark") - parser.add_argument("-d", - "--deployment_name", - type=str, - default="benchmark_deployment") - parser.add_argument("-t", "--task", type=str, - help="Task type. Currently only text-generation is supported", - default="text-generation") - parser.add_argument("-m", - "--tensor_parallel", - type=int, - help="Degree of tensor (model) parallelism", - default=1) - parser.add_argument("-b", - "--ragged_batch_size", - type=int, - help="Max batch size for ragged batching", - default=768) - parser.add_argument("-r", - "--replica_num", - type=int, - help="Number of replicas for load balancing", - default=1) - parser.add_argument("cmd", help="start, stop, or restart") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - if args.cmd == "start": - start_server(args.model_name, - args.deployment_name, - args.task, - args.tensor_parallel, - args.replica_num, - args.ragged_batch_size) - elif args.cmd == "stop": - print("running stop") - stop_server(args.deployment_name) - else: - raise ValueError(f"Unknown command: {args.cmd}") diff --git a/benchmarks/inference/mii/src/__init__.py b/benchmarks/inference/mii/src/__init__.py new file mode 100644 index 000000000..208299fb8 --- /dev/null +++ b/benchmarks/inference/mii/src/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/benchmarks/inference/mii/run_benchmark_client.py b/benchmarks/inference/mii/src/client.py similarity index 51% rename from benchmarks/inference/mii/run_benchmark_client.py rename to benchmarks/inference/mii/src/client.py index caf20351e..c440d0b63 100644 --- a/benchmarks/inference/mii/run_benchmark_client.py +++ b/benchmarks/inference/mii/src/client.py @@ -1,70 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import asyncio +import json +import multiprocessing import os -import time -import random -import argparse import queue -import multiprocessing +import random +import requests import threading -from statistics import mean -from dataclasses import dataclass, asdict +import time from typing import List, Iterable -from pathlib import Path -from datetime import datetime -import numpy as np +import numpy as np from transformers import AutoTokenizer -from random_query_generator import RandomQueryGenerator -from sample_input import all_text -import time -import json -import asyncio -import requests -from postprocess_results import get_summary, ResponseDetails - -MAX_PROMPT_LENGTH = 4000 -PROMPT_LENGTH_VAR = 0.3 -MAX_NEW_TOKENS_VAR = 0.3 - -def parse_args(): - parser = argparse.ArgumentParser(description="Benchmark MII services") - parser.add_argument("-k", - "--max_new_tokens", - type=int, - default=60, - help="min and max num tokens argument for huggingface") - parser.add_argument("-d", - "--deployment_name", - type=str, - default="benchmark_deployment") - parser.add_argument("-n", - "--num_queries", - type=int, - help="number of queries to run", - default=10) - parser.add_argument("-w", - "--warmup", - type=int, - help="number of queries for warming up", - default=1) - parser.add_argument("-c", - "--client_num", - type=int, - help="number of parallel client processes", - default=2) - parser.add_argument("-l", - "--prompt_length", - type=int, - default=2600) - parser.add_argument('--use_thread', action='store_true', - help='use thread to run parallel clients, otherwise use multiprocessing', - default=False) - parser.add_argument('--stream', action='store_true', default=True) - parser.add_argument('--vllm', action='store_true', default=False) - parser.add_argument('-o', '--out_json_path', type=Path, default=None) - - args = parser.parse_args() - return args +from .postprocess_results import ResponseDetails +from .random_query_generator import RandomQueryGenerator +from .sample_input import all_text +from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS def call_mii(client, input_tokens, max_new_tokens, stream): @@ -85,11 +41,10 @@ def callback(response): if stream: output_tokens = [] client.generate( - input_tokens, max_new_tokens=max_new_tokens, - streaming_fn=callback) + input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback + ) else: - result = client.generate( - input_tokens, max_new_tokens=max_new_tokens) + result = client.generate(input_tokens, max_new_tokens=max_new_tokens) output_tokens = result[0].generated_text return ResponseDetails( @@ -98,7 +53,8 @@ def callback(response): start_time=start_time, end_time=time.time(), model_time=0, - token_gen_time=token_gen_time) + token_gen_time=token_gen_time, + ) def call_vllm(input_tokens, max_new_tokens, stream=True): @@ -114,15 +70,19 @@ def call_vllm(input_tokens, max_new_tokens, stream=True): "ignore_eos": False, "stream": stream, } + def clear_line(n: int = 1) -> None: - LINE_UP = '\033[1A' - LINE_CLEAR = '\x1b[2K' + LINE_UP = "\033[1A" + LINE_CLEAR = "\x1b[2K" for _ in range(n): print(LINE_UP, end=LINE_CLEAR, flush=True) - def get_streaming_response(response: requests.Response, time_last_token) -> Iterable[List[str]]: - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, - delimiter=b"\0"): + def get_streaming_response( + response: requests.Response, time_last_token + ) -> Iterable[List[str]]: + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): if chunk: data = json.loads(chunk.decode("utf-8")) output = data["text"][0] @@ -149,13 +109,23 @@ def get_response(response: requests.Response) -> List[str]: start_time=start_time, end_time=time.time(), model_time=0, - token_gen_time=token_gen_time) + token_gen_time=token_gen_time, + ) else: output = get_response(response) raise NotImplementedError("Not implemented for non-streaming") -def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, client_num, stream, vllm): +def _run_parallel( + deployment_name, + warmup, + barrier, + query_queue, + result_queue, + num_clients, + stream, + vllm, +): pid = os.getpid() session_id = f"test_session_p{pid}_t{threading.get_ident()}" @@ -163,6 +133,7 @@ def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, c asyncio.set_event_loop(event_loop) if not vllm: import mii + client = mii.client(deployment_name) barrier.wait() @@ -178,7 +149,7 @@ def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, c barrier.wait() - time.sleep(random.uniform(0, client_num) * 0.01) + time.sleep(random.uniform(0, num_clients) * 0.01) try: while not query_queue.empty(): print(f"queue size: {query_queue.qsize()} ({pid})", flush=True) @@ -197,18 +168,33 @@ def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, c print(f"Worker ({pid}) finished. session_id: {session_id}") -def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_queries, warmup, stream, vllm, use_thread=False): +def run_client(args): """ Run MII client for benchmarking. The scenario is a bit complicated: - 1. The main process puts `num_queries` queries into the input queue + 1. The main process puts `num_requests` queries into the input queue 2. Each client runs `warmup` iterations () taking the queries from the input queue 3. --- barrier --- 4. The main process marks the start time - 5a. All clients send `num_queries' query in total and put the results into the result queue + 5a. All clients send `num_requests' query in total and put the results into the result queue 5b. The main process takes the results from the result queue (in parallel with 5a) - 6. The main process marks the end time after receiving `num_queries' results + 6. The main process marks the end time after receiving `num_requests' results """ + # Unpack arguments + model = args.model + deployment_name = args.deployment_name + mean_prompt_length = args.mean_prompt_length + mean_max_new_tokens = args.mean_max_new_tokens + num_clients = args.num_clients + num_requests = args.num_requests + warmup = args.warmup + max_prompt_length = args.max_prompt_length + prompt_length_var = args.prompt_length_var + max_new_tokens_var = args.max_new_tokens_var + stream = args.stream + vllm = args.vllm + use_thread = args.use_thread + if use_thread: runnable_cls = threading.Thread barrier_cls = threading.Barrier @@ -218,23 +204,44 @@ def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_q barrier_cls = multiprocessing.Barrier queue_cls = multiprocessing.Queue - barrier = barrier_cls(client_num + 1) + barrier = barrier_cls(num_clients + 1) query_queue = queue_cls() result_queue = queue_cls() - processes = [runnable_cls(target=_run_parallel, - args=(deployment_name, warmup, barrier, query_queue, result_queue, client_num, stream, vllm)) - for i in range(client_num)] + processes = [ + runnable_cls( + target=_run_parallel, + args=( + deployment_name, + warmup, + barrier, + query_queue, + result_queue, + num_clients, + stream, + vllm, + ), + ) + for i in range(num_clients) + ] for p in processes: p.start() - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + tokenizer = AutoTokenizer.from_pretrained(model) query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42) - MAX_PROMPT_LENGTH = 4000 - request_text = query_generator.get_random_request_text(prompt_length, prompt_length*PROMPT_LENGTH_VAR, MAX_PROMPT_LENGTH, num_queries + warmup*client_num) + request_text = query_generator.get_random_request_text( + mean_prompt_length, + mean_prompt_length * prompt_length_var, + max_prompt_length, + num_requests + warmup * num_clients, + ) for t in request_text: - req_max_new_tokens = int(np.random.normal(max_new_tokens, MAX_NEW_TOKENS_VAR*max_new_tokens)) + req_max_new_tokens = int( + np.random.normal( + mean_max_new_tokens, max_new_tokens_var * mean_max_new_tokens + ) + ) query_queue.put((t, req_max_new_tokens)) # Tokenizers must be initialized after fork. @@ -245,41 +252,21 @@ def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_q barrier.wait() response_details = [] - while len(response_details) < num_queries: + while len(response_details) < num_requests: res = result_queue.get() # vLLM returns concatinated tokens if vllm: all_tokens = tokenizer.tokenize(res.generated_tokens) - res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)):] + res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :] response_details.append(res) return response_details + if __name__ == "__main__": - args = parse_args() - print(args) - - if args.out_json_path is not None and not args.out_json_path.parent.exists(): - raise ValueError(f"Parent directory of {args.out_json_path}") - - response_details = run_client(args.client_num, args.deployment_name, - args.prompt_length, - args.max_new_tokens, args.num_queries, args.warmup, - args.stream, args.vllm, args.use_thread) - - args_dict = vars(args) - ps = get_summary(args_dict, response_details) - print(f"Deployment: {args.deployment_name} Clients: {args.client_num}, " - + f"Prompt (mean): {args.prompt_length} tokens, " - + f"Generation (mean): {args.max_new_tokens} tokens, " - + f"Query throughput: {ps.throughput:.3f} queries/s, " - + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " - + f"Query latency: {ps.latency:.3f} s, " - + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " - + f"First token received: {ps.first_token_latency:.3f} s") - - if args.out_json_path is not None: - with open(args.out_json_path, "w") as f: - args_dict["out_json_path"] = str(args.out_json_path) # Path is not JSON serializable - data = {"args": args_dict, "time": str(datetime.now()), "response_details": [asdict(r) for r in response_details]} - json.dump(data, f, indent=2) + args = parse_args(client_args=True) + + for client_args in get_args_product(args, which=CLIENT_PARAMS): + response_details = run_client(client_args) + + print_summary(client_args, response_details) diff --git a/benchmarks/inference/mii/src/defaults.py b/benchmarks/inference/mii/src/defaults.py new file mode 100644 index 000000000..79ce91c97 --- /dev/null +++ b/benchmarks/inference/mii/src/defaults.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +ARG_DEFAULTS = { + "tp_size": 1, + "max_ragged_batch_size": 768, + "num_replicas": 1, + "max_prompt_length": 4000, + "mean_prompt_length": 2600, + "mean_max_new_tokens": 60, +} + +MODEL_DEFAULTS = { + "meta-llama/Llama-2-7b-hf": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": 1, + }, + "meta-llama/Llama-13b-hf": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": (1, 2, 4), + }, + "meta-llama/Llama-2-70b-hf": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": (4, 8), + }, + "tiiuae/falcon-40B": { + "max_prompt_length": 2000, + "mean_prompt_length": (1200, 1900), + "mean_max_new_tokens": (60, 128), + "tp_size": (2, 4), + }, + "tiiuae/falcon-180B": { + "max_prompt_length": 2000, + "mean_prompt_length": (1200, 1900), + "mean_max_new_tokens": (60, 128), + "tp_size": 8, + }, + "microsoft/phi-2": { + "max_prompt_length": 2000, + "mean_prompt_length": (1200, 1900), + "mean_max_new_tokens": (60, 128), + "tp_size": 1, + }, + "mistralai/Mixtral-8x7B-v0.1": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": 4, + }, +} diff --git a/benchmarks/inference/mii/plot_effective_throughput.py b/benchmarks/inference/mii/src/plot_effective_throughput.py similarity index 53% rename from benchmarks/inference/mii/plot_effective_throughput.py rename to benchmarks/inference/mii/src/plot_effective_throughput.py index 350c269c3..efa471c76 100644 --- a/benchmarks/inference/mii/plot_effective_throughput.py +++ b/benchmarks/inference/mii/src/plot_effective_throughput.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import argparse from pathlib import Path import glob @@ -5,21 +10,16 @@ import numpy as np import pandas as pd -from postprocess_results import read_json, get_tokenizer +from .postprocess_results import read_json, get_tokenizer RAGGED_BATCH_SIZE = 768 SLA_PROMPT_TOKENS_PER_SEC = 512 SLA_GEN_TOKENS_PER_SEC = [1, 2, 3, 4, 6, 8] EMA_SPAN = 16 -tp_sizes_all = { - "7b": [1], - "70b": [4, 8] -} +tp_sizes_all = {"7b": [1], "70b": [4, 8]} -tp_sizes_test = { - "7b": [1] -} +tp_sizes_test = {"7b": [1]} prompt_gen_pairs_all = [ (1200, 60), @@ -28,9 +28,8 @@ (2600, 128), ] -prompt_gen_pairs_test = [ - (2600, 60) -] +prompt_gen_pairs_test = [(2600, 60)] + def get_args(): parser = argparse.ArgumentParser() @@ -43,23 +42,54 @@ def get_args(): def check_token_latency_step(response_details, token_index): - P50_token_latency = np.percentile([r.token_gen_time[token_index] for r in response_details if len(r.token_gen_time) > token_index], 50) - P90_token_latency = np.percentile([r.token_gen_time[token_index] for r in response_details if len(r.token_gen_time) > token_index], 90) - P99_token_latency = np.percentile([r.token_gen_time[token_index] for r in response_details if len(r.token_gen_time) > token_index], 99) + P50_token_latency = np.percentile( + [ + r.token_gen_time[token_index] + for r in response_details + if len(r.token_gen_time) > token_index + ], + 50, + ) + P90_token_latency = np.percentile( + [ + r.token_gen_time[token_index] + for r in response_details + if len(r.token_gen_time) > token_index + ], + 90, + ) + P99_token_latency = np.percentile( + [ + r.token_gen_time[token_index] + for r in response_details + if len(r.token_gen_time) > token_index + ], + 99, + ) return P50_token_latency, P90_token_latency, P99_token_latency def validate_token_cum_latency_SLA(response_detail, sla_token_gen): cumsum_latencies = np.cumsum(np.array(response_detail.token_gen_time[1:])) - return all([cumsum_latencies[i] <= (1 / sla_token_gen) * (i + 1) for i in range(len(cumsum_latencies))]) + return all( + [ + cumsum_latencies[i] <= (1 / sla_token_gen) * (i + 1) + for i in range(len(cumsum_latencies)) + ] + ) def validate_token_ema_latency_SLA(response_detail, sla_token_gen, ema_span): - ema_latency = pd.Series(response_detail.token_gen_time[1:]).ewm(span=ema_span).mean().values.tolist() - return all([t < 1. / sla_token_gen for t in ema_latency]) + ema_latency = ( + pd.Series(response_detail.token_gen_time[1:]) + .ewm(span=ema_span) + .mean() + .values.tolist() + ) + return all([t < 1.0 / sla_token_gen for t in ema_latency]) + - def validate_prompt_latency_SLA(response_detail, sla_token_gen, f): tokenizer = get_tokenizer() prompt_length = len(tokenizer.tokenize(response_detail.prompt)) @@ -71,14 +101,14 @@ def validate_prompt_latency_SLA(response_detail, sla_token_gen, f): return True return f[0](response_detail, sla_token_gen, *f[1]) - + def calc_throughput(response_details): start_time = min([r.start_time for r in response_details]) end_time = max([r.end_time for r in response_details]) return len(response_details) / (end_time - start_time) - + def extract_values(file_pattern, sla_token_gen, validate_func): files = glob.glob(file_pattern) print(f"Found {len(files)} files") @@ -87,8 +117,16 @@ def extract_values(file_pattern, sla_token_gen, validate_func): for f in files: prof_args, response_details = read_json(f) client_num = prof_args["client_num"] - num_req_ok = len([r for r in response_details if validate_prompt_latency_SLA(r, sla_token_gen, validate_func)]) - goodputs[client_num] = calc_throughput(response_details) * (num_req_ok / len(response_details)) + num_req_ok = len( + [ + r + for r in response_details + if validate_prompt_latency_SLA(r, sla_token_gen, validate_func) + ] + ) + goodputs[client_num] = calc_throughput(response_details) * ( + num_req_ok / len(response_details) + ) good_ratios[client_num] = num_req_ok / len(response_details) return goodputs, good_ratios @@ -98,11 +136,13 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out if not log_dir.exists(): print(f"Log directory {log_dir} does not exist") return - + if not out_dir.exists(): out_dir.mkdir(parents=True, exist_ok=True) - - print(f"model: {model_size} Prompt: {prompt}, Generation: {gen}, TP: {tp} sla_token_gen: {sla_token_gen}") + + print( + f"model: {model_size} Prompt: {prompt}, Generation: {gen}, TP: {tp} sla_token_gen: {sla_token_gen}" + ) mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}/llama2-{model_size}-tp{tp}-b{bs}_c*_p{prompt}_g{gen}.json" if not args.no_vllm: @@ -110,55 +150,91 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out validate_funcs = [ (validate_token_cum_latency_SLA, (), "cum"), - (validate_token_ema_latency_SLA, (EMA_SPAN, ), f"ema{EMA_SPAN}"), + (validate_token_ema_latency_SLA, (EMA_SPAN,), f"ema{EMA_SPAN}"), ] for f in validate_funcs: - - mii_goodputs, mii_good_ratios = extract_values(mii_file_pattern, sla_token_gen, f) + + mii_goodputs, mii_good_ratios = extract_values( + mii_file_pattern, sla_token_gen, f + ) client_num_list = sorted(list(mii_goodputs.keys())) mii_goodputs_list = [mii_goodputs[client_num] for client_num in client_num_list] if not args.no_vllm: - vllm_goodputs, vllm_good_ratios = extract_values(vllm_file_pattern, sla_token_gen, f) - vllm_goodputs_list = [vllm_goodputs[client_num] for client_num in client_num_list] + vllm_goodputs, vllm_good_ratios = extract_values( + vllm_file_pattern, sla_token_gen, f + ) + vllm_goodputs_list = [ + vllm_goodputs[client_num] for client_num in client_num_list + ] # print(f"MII {mii_goodputs_list} ratio={mii_good_ratios}") # print(f"vLLM {vllm_goodputs_list} ratio={vllm_good_ratios}") # Plotting the scatter plot plt.figure(figsize=(7, 4)) - plt.scatter(client_num_list, mii_goodputs_list, label=f"DeepSpeed-FastGen", marker="o", color="blue") + plt.scatter( + client_num_list, + mii_goodputs_list, + label=f"DeepSpeed-FastGen", + marker="o", + color="blue", + ) if not args.no_vllm: - plt.scatter(client_num_list, vllm_goodputs_list, label=f"vLLM", marker="x", color="orange") + plt.scatter( + client_num_list, + vllm_goodputs_list, + label=f"vLLM", + marker="x", + color="orange", + ) fit_x_list = np.arange(min(client_num_list), max(client_num_list), 0.1) mii_fit_model = np.polyfit(client_num_list, mii_goodputs_list, 4) mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_x_list, mii_model_fn(fit_x_list), color="blue", alpha=0.5, linestyle="--") + plt.plot( + fit_x_list, + mii_model_fn(fit_x_list), + color="blue", + alpha=0.5, + linestyle="--", + ) if not args.no_vllm: vllm_fit_model = np.polyfit(client_num_list, vllm_goodputs_list, 4) vllm_model_fn = np.poly1d(vllm_fit_model) - plt.plot(fit_x_list, vllm_model_fn(fit_x_list), color="orange", alpha=0.5, linestyle="--") - - title = f"Effective throughput (SLA prompt: {SLA_PROMPT_TOKENS_PER_SEC} tokens/s, generation: {sla_token_gen} tokens/s)\n" \ - + f'Llama 2 {model_size.upper()} Prompt: {prompt}, Generation: {gen}, TP: {tp}' + plt.plot( + fit_x_list, + vllm_model_fn(fit_x_list), + color="orange", + alpha=0.5, + linestyle="--", + ) + + title = ( + f"Effective throughput (SLA prompt: {SLA_PROMPT_TOKENS_PER_SEC} tokens/s, generation: {sla_token_gen} tokens/s)\n" + + f"Llama 2 {model_size.upper()} Prompt: {prompt}, Generation: {gen}, TP: {tp}" + ) plt.title(title, fontsize=10) - plt.xlabel('Number of clients', fontsize=10) - plt.ylabel('Effective throughput (queries/s)', fontsize=10) + plt.xlabel("Number of clients", fontsize=10) + plt.ylabel("Effective throughput (queries/s)", fontsize=10) # plt.rcParams['figure.subplot.bottom'] = 0.30 plt.ylim(bottom=-0.05) plt.legend() plt.grid(True) # plt.show() - out_file = out_dir / f"goodput_llama{model_size}_SLAp{SLA_PROMPT_TOKENS_PER_SEC}g{sla_token_gen}_tp{tp}_b{bs}_p{prompt}g{gen}_{f[2]}.png" + out_file = ( + out_dir + / f"goodput_llama{model_size}_SLAp{SLA_PROMPT_TOKENS_PER_SEC}g{sla_token_gen}_tp{tp}_b{bs}_p{prompt}g{gen}_{f[2]}.png" + ) plt.savefig(out_file) plt.clf() print(f"Saved {out_file}") - + if __name__ == "__main__": + raise NotImplementedError("This script is not up to date") args = get_args() if args.test: @@ -172,5 +248,13 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out for tp in tps: for prompt, gen in prompt_gen_pairs: for sla_token_gen in SLA_GEN_TOKENS_PER_SEC: - display_results(model_size, tp, RAGGED_BATCH_SIZE, sla_token_gen, prompt, gen, args.log_dir, args.out_dir) - + display_results( + model_size, + tp, + RAGGED_BATCH_SIZE, + sla_token_gen, + prompt, + gen, + args.log_dir, + args.out_dir, + ) diff --git a/benchmarks/inference/mii/plot_latency_percentile.py b/benchmarks/inference/mii/src/plot_latency_percentile.py similarity index 72% rename from benchmarks/inference/mii/plot_latency_percentile.py rename to benchmarks/inference/mii/src/plot_latency_percentile.py index c91c78bf1..9b08f12da 100644 --- a/benchmarks/inference/mii/plot_latency_percentile.py +++ b/benchmarks/inference/mii/src/plot_latency_percentile.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import argparse import glob from pathlib import Path @@ -5,12 +10,12 @@ import numpy as np import itertools -from postprocess_results import read_json, get_token_latency +from .postprocess_results import read_json, get_token_latency bs = 768 SKIP_HEAD_TOKEN_NUM = 2 SKIP_REQUEST_NUM = 100 - + tp_sizes = { "70b": [4], } @@ -23,14 +28,16 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--log_dir", type=Path, default=".") - parser.add_argument("--out_dir", type=Path, default="charts/percentile_token_latency") + parser.add_argument( + "--out_dir", type=Path, default="charts/percentile_token_latency" + ) args = parser.parse_args() return args def extract_values(file_pattern): files = glob.glob(file_pattern) - + latencies = {} for f in files: prof_args, response_details = read_json(f) @@ -38,18 +45,20 @@ def extract_values(file_pattern): response_details.sort(key=lambda r: r.start_time) response_details = response_details[SKIP_REQUEST_NUM:-SKIP_REQUEST_NUM] - token_latencies = [r.token_gen_time[SKIP_HEAD_TOKEN_NUM:-1] for r in response_details] + token_latencies = [ + r.token_gen_time[SKIP_HEAD_TOKEN_NUM:-1] for r in response_details + ] flat_latency_list = list(itertools.chain(*token_latencies)) latencies[client_num] = flat_latency_list return latencies -def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): +def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): if not log_dir.exists(): print(f"Log directory {log_dir} does not exist") return - + if not out_dir.exists(): out_dir.mkdir(parents=True, exist_ok=True) @@ -79,7 +88,10 @@ def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): # print(f"P95_vllm_val={P95_vllm_val}") # print(f"P95_mii_val={P95_mii_val}") - out_file = out_dir / f"p{percentile}_token_latency_llama{model_size}_c{client_num}_tp{tp}_p{prompt}g{gen}.png" + out_file = ( + out_dir + / f"p{percentile}_token_latency_llama{model_size}_c{client_num}_tp{tp}_p{prompt}g{gen}.png" + ) x1 = [1, 2, 3] y1 = [P50_vllm_val, P90_vllm_val, P95_vllm_val] @@ -87,11 +99,13 @@ def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): x2 = [1.3, 2.3, 3.3] y2 = [P50_mii_val, P90_mii_val, P95_mii_val] - label_x = ['P50', 'P90', 'P95'] + label_x = ["P50", "P90", "P95"] - plt.bar(x1, y1, width=0.3, label='vLLM', align="center", color="orange") - plt.bar(x2, y2, width=0.3, label="DeepSpeed-FastGen", align="center", color="blue") - plt.ylabel('Latency', fontsize=14) + plt.bar(x1, y1, width=0.3, label="vLLM", align="center", color="orange") + plt.bar( + x2, y2, width=0.3, label="DeepSpeed-FastGen", align="center", color="blue" + ) + plt.ylabel("Latency", fontsize=14) plt.legend(loc=2) plt.xticks([1.15, 2.15, 3.15], label_x) @@ -101,10 +115,12 @@ def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): if __name__ == "__main__": + raise NotImplementedError("This script is not up to date") args = get_args() - + for model_size, tps in tp_sizes.items(): for tp in tps: for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir) - + output_charts( + model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir + ) diff --git a/benchmarks/inference/mii/plot_repl_scale.py b/benchmarks/inference/mii/src/plot_repl_scale.py similarity index 81% rename from benchmarks/inference/mii/plot_repl_scale.py rename to benchmarks/inference/mii/src/plot_repl_scale.py index 394c54588..7791be0ca 100644 --- a/benchmarks/inference/mii/plot_repl_scale.py +++ b/benchmarks/inference/mii/src/plot_repl_scale.py @@ -1,10 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import glob import matplotlib.pyplot as plt import argparse from pathlib import Path import numpy as np -from postprocess_results import read_json, get_summary +from .postprocess_results import read_json, get_summary bs = 768 @@ -18,6 +23,7 @@ (2600, 60), ] + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--log_dir", type=Path, default=".") @@ -46,7 +52,7 @@ def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): if not log_dir.exists(): print(f"Log directory {log_dir} does not exist") return - + if not out_dir.exists(): out_dir.mkdir(parents=True, exist_ok=True) @@ -67,17 +73,19 @@ def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): # Plotting the scatter plot plt.figure(figsize=(6, 4)) - + plt.bar(REPLICA_NUMS, throughputs[c], color="blue", alpha=0.9) fit_x_list = np.arange(min(REPLICA_NUMS), max(REPLICA_NUMS), 0.1) mii_fit_model = np.polyfit(REPLICA_NUMS, throughputs[c], 1) mii_model_fn = np.poly1d(mii_fit_model) plt.plot(fit_x_list, mii_model_fn(fit_x_list), color="blue", linestyle="--") - - plt.title(f'Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tp}') - plt.xlabel('Number of replicas', fontsize=14) - plt.ylabel('Throughput (queries/s)', fontsize=14) + + plt.title( + f"Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tp}" + ) + plt.xlabel("Number of replicas", fontsize=14) + plt.ylabel("Throughput (queries/s)", fontsize=14) plt.grid(True) plt.tight_layout() # plt.show() @@ -86,10 +94,12 @@ def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): if __name__ == "__main__": + raise NotImplementedError("This script is not up to date") args = get_args() - + for model_size, tps in tp_sizes.items(): for tp in tps: for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir) - + output_charts( + model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir + ) diff --git a/benchmarks/inference/mii/src/plot_th_lat.py b/benchmarks/inference/mii/src/plot_th_lat.py new file mode 100644 index 000000000..9aa292ca6 --- /dev/null +++ b/benchmarks/inference/mii/src/plot_th_lat.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import glob +import os +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from postprocess_results import read_json, get_summary + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--log_dir", type=Path, default="./results") + parser.add_argument("--out_dir", type=Path, default="./plots/throughput_latency") + args = parser.parse_args() + return args + + +def extract_values(file_pattern): + files = glob.glob(file_pattern) + + print(f"Found {len(files)}") + print("\n".join(files)) + + clients = [] + throughputs = [] + latencies = [] + for f in files: + prof_args, response_details = read_json(f) + summary = get_summary(prof_args, response_details) + clients.append(prof_args["num_clients"]) + throughputs.append(summary.throughput) + latencies.append(summary.latency) + + return clients, throughputs, latencies + + +def output_charts(model, tp_size, bs, replicas, prompt, gen, log_dir, out_dir): + out_dir.mkdir(parents=True, exist_ok=True) + + result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json" + mii_file_pattern = f"{log_dir}/fastgen/{result_file_pattern}" + vllm_file_pattern = f"{log_dir}/vllm/{result_file_pattern}" + + _, mii_throughputs, mii_latencies = extract_values(mii_file_pattern) + _, vllm_throughputs, vllm_latencies = extract_values(vllm_file_pattern) + + # Plotting the scatter plot + plt.figure(figsize=(6, 4)) + + if len(vllm_throughputs) > 0: + plt.scatter( + vllm_throughputs, vllm_latencies, label=f"vLLM", marker="x", color="orange" + ) + fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01) + vllm_vllm_model = np.polyfit(vllm_throughputs, vllm_latencies, 3) + vllm_model_fn = np.poly1d(vllm_vllm_model) + plt.plot( + fit_vllm_x_list, + vllm_model_fn(fit_vllm_x_list), + color="orange", + alpha=0.5, + linestyle="--", + ) + + plt.scatter( + mii_throughputs, + mii_latencies, + label=f"DeepSpeed FastGen", + marker="o", + color="blue", + ) + fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01) + mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3) + mii_model_fn = np.poly1d(mii_fit_model) + plt.plot( + fit_mii_x_list, + mii_model_fn(fit_mii_x_list), + color="blue", + alpha=0.5, + linestyle="--", + ) + + plt.title(f"Model {model}, Prompt: {prompt}, Generation: {gen}, TP: {tp_size}") + plt.xlabel("Throughput (queries/s)", fontsize=14) + plt.ylabel("Latency", fontsize=14) + plt.legend() + plt.grid(True) + plt.tight_layout() + out_file = ( + out_dir + / f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}.png" + ) + print(f"Saving {out_file}") + plt.savefig(out_file) + + +if __name__ == "__main__": + args = get_args() + + if not args.log_dir.exists(): + raise ValueError(f"Log dir {args.log_dir} does not exist") + + result_params = set() + result_re = re.compile( + r"(.+)-tp(\d+)-bs(\d+)-replicas(\d+)-prompt(\d+)-gen(\d+)-clients.*.json" + ) + for f in os.listdir(os.path.join(args.log_dir, "fastgen")): + match = result_re.match(f) + if match: + result_params.add(match.groups()) + + for model, tp_size, bs, replicas, prompt, gen in result_params: + output_charts( + model=model, + tp_size=tp_size, + bs=bs, + replicas=replicas, + prompt=prompt, + gen=gen, + log_dir=args.log_dir, + out_dir=args.out_dir, + ) diff --git a/benchmarks/inference/mii/plot_tp_sizes.py b/benchmarks/inference/mii/src/plot_tp_sizes.py similarity index 73% rename from benchmarks/inference/mii/plot_tp_sizes.py rename to benchmarks/inference/mii/src/plot_tp_sizes.py index 546310258..f02b643f2 100644 --- a/benchmarks/inference/mii/plot_tp_sizes.py +++ b/benchmarks/inference/mii/src/plot_tp_sizes.py @@ -1,13 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import glob import matplotlib.pyplot as plt import argparse from pathlib import Path import numpy as np -from postprocess_results import read_json, get_summary +from .postprocess_results import read_json, get_summary bs = 768 - + tp_sizes = { # "7b": [1], "13b": [1, 2, 4], @@ -22,6 +27,7 @@ (2600, 256), ] + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--log_dir", type=Path, default="logs.release") @@ -34,7 +40,7 @@ def extract_values(file_pattern): files = glob.glob(file_pattern) print(f"Found {len(files)}") - print('\n'.join(files)) + print("\n".join(files)) clients = [] throughputs = [] @@ -53,7 +59,7 @@ def output_charts(model_size, tps, bs, prompt, gen, log_dir, out_dir): if not log_dir.exists(): print(f"Log directory {log_dir} does not exist") return - + if not out_dir.exists(): out_dir.mkdir(parents=True, exist_ok=True) @@ -73,26 +79,39 @@ def output_charts(model_size, tps, bs, prompt, gen, log_dir, out_dir): tflops_per_query = n_params * (prompt + gen) * 2 * 1e-3 mii_tflops = [th * tflops_per_query / tp for th in mii_throughputs] - plt.scatter(mii_tflops, mii_latencies, label=f"TP={tp}", marker="o", color=color) + plt.scatter( + mii_tflops, mii_latencies, label=f"TP={tp}", marker="o", color=color + ) fit_mii_x_list = np.arange(min(mii_tflops), max(mii_tflops), 0.01) mii_fit_model = np.polyfit(mii_tflops, mii_latencies, 3) mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_mii_x_list, mii_model_fn(fit_mii_x_list), color=color, alpha=0.5, linestyle="--") - - plt.title(f'Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tps}') - plt.xlabel('TFLOPs (per GPU)', fontsize=14) - plt.ylabel('Latency', fontsize=14) + plt.plot( + fit_mii_x_list, + mii_model_fn(fit_mii_x_list), + color=color, + alpha=0.5, + linestyle="--", + ) + + plt.title( + f"Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tps}" + ) + plt.xlabel("TFLOPs (per GPU)", fontsize=14) + plt.ylabel("Latency", fontsize=14) plt.legend() plt.grid(True) # plt.show() - out_file = out_dir / f"tp_sizes_llama{model_size}_tp{'_'.join([str(tp) for tp in tps])}_p{prompt}g{gen}.png" + out_file = ( + out_dir + / f"tp_sizes_llama{model_size}_tp{'_'.join([str(tp) for tp in tps])}_p{prompt}g{gen}.png" + ) plt.savefig(out_file) if __name__ == "__main__": + raise NotImplementedError("This script is not up to date") args = get_args() - + for model_size, tps in tp_sizes.items(): for prompt, gen in prompt_gen_pairs: output_charts(model_size, tps, bs, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/postprocess_results.py b/benchmarks/inference/mii/src/postprocess_results.py similarity index 53% rename from benchmarks/inference/mii/postprocess_results.py rename to benchmarks/inference/mii/src/postprocess_results.py index cb2000d5f..7e25bfddc 100644 --- a/benchmarks/inference/mii/postprocess_results.py +++ b/benchmarks/inference/mii/src/postprocess_results.py @@ -1,12 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import argparse -from pathlib import Path import json -import numpy as np -from statistics import mean -from functools import reduce from dataclasses import dataclass +from functools import reduce +from pathlib import Path +from statistics import mean from typing import List +import numpy as np from transformers import AutoTokenizer @@ -31,10 +36,10 @@ class ProfilingSummary: first_token_latency: float tokens_per_sec: float - + def parse_args(): parser = argparse.ArgumentParser(description="Postprocess results") - parser.add_argument('-i', '--input_path', type=Path, default="results.json") + parser.add_argument("-i", "--input_path", type=Path, default="results.json") args = parser.parse_args() return args @@ -44,13 +49,13 @@ def get_tokenizer(): global tokenizer if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - return tokenizer + return tokenizer def read_json(file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) - + args = data["args"] response_details = [] @@ -61,34 +66,56 @@ def read_json(file_path): def get_summary(args, response_details): - client_num = args["client_num"] + num_clients = args["num_clients"] # Calculate latency and throughput using P95 latency latency = mean([r.end_time - r.start_time for r in response_details]) - throughput = client_num / latency - - tokens_per_sec = mean([(len(get_tokenizer().tokenize(r.prompt)) + len(r.generated_tokens)) / (r.end_time - r.start_time) for r in response_details]) + throughput = num_clients / latency + + tokens_per_sec = mean( + [ + (len(get_tokenizer().tokenize(r.prompt)) + len(r.generated_tokens)) + / (r.end_time - r.start_time) + for r in response_details + ] + ) first_token_latency = mean([r.token_gen_time[0] for r in response_details]) - token_gen_latency_flat = reduce(list.__add__, [r.token_gen_time[1:-1] for r in response_details if len(r.token_gen_time) > 2]) + token_gen_latency_flat = reduce( + list.__add__, + [r.token_gen_time[1:-1] for r in response_details if len(r.token_gen_time) > 2], + ) token_gen_latency = mean([t for t in token_gen_latency_flat]) - return ProfilingSummary(throughput, latency, token_gen_latency, first_token_latency, tokens_per_sec) + return ProfilingSummary( + throughput, latency, token_gen_latency, first_token_latency, tokens_per_sec + ) -def get_token_latency(response_details, percentile=None, variance=False, cumulative=False): +def get_token_latency( + response_details, percentile=None, variance=False, cumulative=False +): req_latencies = [r.token_gen_time for r in response_details] if cumulative: - req_latencies = [np.cumsum(np.array(r.token_gen_time)).tolist() for r in response_details] + req_latencies = [ + np.cumsum(np.array(r.token_gen_time)).tolist() for r in response_details + ] max_gen_length = max([len(r.generated_tokens) for r in response_details]) latency = [] for i in range(max_gen_length): if variance: - token_latency_step = np.var([latency[i] for latency in req_latencies if len(latency) > i]) + token_latency_step = np.var( + [latency[i] for latency in req_latencies if len(latency) > i] + ) if percentile is None: - token_latency_step = [latency[i] for latency in req_latencies if len(latency) > i] + token_latency_step = [ + latency[i] for latency in req_latencies if len(latency) > i + ] else: - token_latency_step = np.percentile([latency[i] for latency in req_latencies if len(latency) > i], percentile) + token_latency_step = np.percentile( + [latency[i] for latency in req_latencies if len(latency) > i], + percentile, + ) latency.append(token_latency_step) @@ -104,9 +131,11 @@ def get_token_acc_latency(response_details, percentile=99): prof_args, response_details = read_json(args.input_path) ps = get_summary(prof_args, response_details) - print(f"Deployment: {prof_args['deployment_name']} Clients: {prof_args['client_num']}, " - + f"Query throughput: {ps.throughput:.3f} queries/s, " - + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " - + f"Query latency: {ps.latency:.3f} s, " - + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " - + f"First token received: {ps.first_token_latency:.3f} s") + print( + f"Deployment: {prof_args['deployment_name']} Clients: {prof_args['num_clients']}, " + + f"Query throughput: {ps.throughput:.3f} queries/s, " + + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " + + f"Query latency: {ps.latency:.3f} s, " + + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " + + f"First token received: {ps.first_token_latency:.3f} s" + ) diff --git a/benchmarks/inference/mii/random_query_generator.py b/benchmarks/inference/mii/src/random_query_generator.py similarity index 72% rename from benchmarks/inference/mii/random_query_generator.py rename to benchmarks/inference/mii/src/random_query_generator.py index b8442af4f..eca16d8ff 100644 --- a/benchmarks/inference/mii/random_query_generator.py +++ b/benchmarks/inference/mii/src/random_query_generator.py @@ -1,7 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy as np import torch import random -import numpy as np -import time + class RandomQueryGenerator: def __init__(self, input_text, tokenizer, seed): @@ -14,9 +19,9 @@ def __init__(self, input_text, tokenizer, seed): def get_random_request_text(self, length, variance, max_length, batch): request_text = [] - tokenized_input = self.tokenizer.batch_encode_plus([self.input_text], - return_tensors="pt", - padding=False) + tokenized_input = self.tokenizer.batch_encode_plus( + [self.input_text], return_tensors="pt", padding=False + ) offset = list(range(512)) random.shuffle(offset) @@ -25,6 +30,6 @@ def get_random_request_text(self, length, variance, max_length, batch): # Set max_new_tokens following normal distribution with mean=max_new_tokens and std=0.3*max_new_tokens req_prompt_length = min(int(np.random.normal(length, variance)), max_length) - text = self.tokenizer.decode(text_ids[i:req_prompt_length+i]) + text = self.tokenizer.decode(text_ids[i : req_prompt_length + i]) request_text.append(text) return request_text diff --git a/benchmarks/inference/mii/sample_input.py b/benchmarks/inference/mii/src/sample_input.py similarity index 99% rename from benchmarks/inference/mii/sample_input.py rename to benchmarks/inference/mii/src/sample_input.py index 77d02af5f..bae18ce62 100644 --- a/benchmarks/inference/mii/sample_input.py +++ b/benchmarks/inference/mii/src/sample_input.py @@ -1,8 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team # This is a sample input consisting of: # Code & Text -all_text = '''Deep learning involves the use of neural networks, which are computational models inspired by the structure and functioning of the human brain. These networks consist of interconnected nodes called neurons. Each neuron takes input, performs a computation, and produces an output. +all_text = """Deep learning involves the use of neural networks, which are computational models inspired by the structure and functioning of the human brain. These networks consist of interconnected nodes called neurons. Each neuron takes input, performs a computation, and produces an output. During training, the neural network learns to make accurate predictions by adjusting its internal parameters. This adjustment is done using an optimization algorithm called gradient descent. Gradient descent calculates the gradients of a loss function, which measures the discrepancy between the predicted output of the network and the desired output. These gradients indicate the direction and magnitude of parameter updates that will minimize the loss. The learning rate is an important hyperparameter in gradient descent. It determines the step size taken during parameter updates. A higher learning rate can lead to faster convergence, but it risks overshooting the optimal solution. On the other hand, a lower learning rate may converge more slowly, but it can result in more precise updates. Activation functions are applied to the output of each neuron in a neural network. They introduce non-linearities, enabling the network to learn complex patterns and relationships in the data. Popular activation functions include the rectified linear unit (ReLU), sigmoid, and hyperbolic tangent (tanh). @@ -218,4 +222,4 @@ def top_p_sampling(self, logits, p=0.9): print("Top-k Sampling:", top_k_text) print("Top-p Sampling:", top_p_text) Make sure to adjust the server_url with the appropriate URL of your HTTP server, and ensure that the server is running and accessible before making requests through the API. - ''' \ No newline at end of file + """ diff --git a/benchmarks/inference/mii/src/server.py b/benchmarks/inference/mii/src/server.py new file mode 100644 index 000000000..d0ecabaf3 --- /dev/null +++ b/benchmarks/inference/mii/src/server.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import subprocess +import time + +import mii +from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig +from deepspeed.inference.v2.ragged import DSStateManagerConfig + +from .utils import parse_args, SERVER_PARAMS + + +def start_server(args): + vllm = args.vllm + model = args.model + deployment_name = args.deployment_name + tp_size = args.tp_size + num_replicas = args.num_replicas + max_ragged_batch_size = args.max_ragged_batch_size + + if vllm: + start_vllm_server(model=model, tp_size=tp_size) + else: + start_mii_server( + model=model, + deployment_name=deployment_name, + tp_size=tp_size, + num_replicas=num_replicas, + max_ragged_batch_size=max_ragged_batch_size, + ) + + +def start_vllm_server(model: str, tp_size: int) -> None: + vllm_cmd = ( + "python", + "-m", + "vllm.entrypoints.api_server", + "--host", + "127.0.0.1", + "--port", + "26500", + "--tensor-parallel-size", + str(tp_size), + "--model", + model, + ) + p = subprocess.Popen( + vllm_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, close_fds=True + ) + start_time = time.time() + timeout_after = 60 * 5 # 5 minutes + while True: + line = p.stderr.readline().decode("utf-8") + if "Application startup complete" in line: + break + if "error" in line.lower(): + p.terminate() + stop_vllm_server() + raise RuntimeError(f"Error starting VLLM server: {line}") + if time.time() - start_time > timeout_after: + p.terminate() + stop_vllm_server() + raise TimeoutError("Timed out waiting for VLLM server to start") + time.sleep(0.01) + + +def start_mii_server( + model, deployment_name, tp_size, num_replicas, max_ragged_batch_size +): + tp_config = DeepSpeedTPConfig(tp_size=tp_size) + mgr_config = DSStateManagerConfig( + max_ragged_batch_size=max_ragged_batch_size, + max_ragged_sequence_count=max_ragged_batch_size, + ) + inference_config = RaggedInferenceEngineConfig( + tensor_parallel=tp_config, state_manager=mgr_config + ) + + mii.serve( + model, + deployment_name=deployment_name, + tensor_parallel=tp_size, + inference_engine_config=inference_config, + replica_num=num_replicas, + ) + + +def stop_server(args): + vllm = args.vllm + deployment_name = args.deployment_name + + if vllm: + stop_vllm_server() + else: + stop_mii_server(deployment_name) + + +def stop_vllm_server(): + vllm_cmd = ("pkill", "-f", "vllm.entrypoints.api_server") + p = subprocess.Popen(vllm_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + + +def stop_mii_server(deployment_name): + mii.client(deployment_name).terminate_server() + + +if __name__ == "__main__": + args = parse_args(server_args=True) + + if args.cmd == "start": + start_server(args) + elif args.cmd == "stop": + stop_server(args) + elif args.cmd == "restart": + stop_server(args) + start_server(args) + else: + raise ValueError(f"Invalid command {args.cmd}") diff --git a/benchmarks/inference/mii/src/utils.py b/benchmarks/inference/mii/src/utils.py new file mode 100644 index 000000000..6499a54b4 --- /dev/null +++ b/benchmarks/inference/mii/src/utils.py @@ -0,0 +1,235 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import copy +import itertools +import json +import os + +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Iterator, List + +from .defaults import ARG_DEFAULTS, MODEL_DEFAULTS +from .postprocess_results import get_summary, ResponseDetails + +# For these arguments, users can provide multiple values when running the +# benchmark. The benchmark will iterate over all possible combinations. +SERVER_PARAMS = ["tp_size", "max_ragged_batch_size", "num_replicas"] +CLIENT_PARAMS = ["mean_prompt_length", "mean_max_new_tokens", "num_clients"] + + +def parse_args( + server_args: bool = False, client_args: bool = False +) -> argparse.Namespace: + if not (server_args or client_args): + raise ValueError("Must specify server_args or client_args or both") + + # Server args + server_parser = argparse.ArgumentParser(add_help=False) + server_parser.add_argument( + "--tp_size", type=int, nargs="+", default=None, help="Tensor parallelism size" + ) + server_parser.add_argument( + "--max_ragged_batch_size", + type=int, + nargs="+", + default=None, + help="Max batch size for ragged batching", + ) + server_parser.add_argument( + "--num_replicas", + type=int, + nargs="+", + default=None, + help="Number of MII model replicas", + ) + server_parser.add_argument( + "cmd", + type=str, + nargs="?", + choices=["start", "stop", "restart"], + help="Command for running server.py to manually start/stop/restart a server", + ) + + # Client args + client_parser = argparse.ArgumentParser(add_help=False) + client_parser.add_argument( + "--max_prompt_length", type=int, default=None, help="Max length a prompt can be" + ) + client_parser.add_argument( + "--mean_prompt_length", + type=int, + nargs="+", + default=None, + help="Mean prompt length in tokens", + ) + client_parser.add_argument( + "--mean_max_new_tokens", + type=int, + nargs="+", + default=None, + help="Mean number of new tokens to generate per prompt", + ) + client_parser.add_argument( + "--num_clients", + type=int, + nargs="+", + default=[1, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + help="Number of concurrent clients", + ) + client_parser.add_argument( + "--num_requests", + type=int, + default=512, + help="Number of requests to process by clients", + ) + client_parser.add_argument( + "--prompt_length_var", type=float, default=0.3, help="Variance of prompt length" + ) + client_parser.add_argument( + "--max_new_tokens_var", + type=float, + default=0.3, + help="Variance of max new tokens", + ) + client_parser.add_argument( + "--warmup", type=int, default=1, help="Number of warmup requests to process" + ) + client_parser.add_argument( + "--use_thread", action="store_true", help="Use threads instead of processes" + ) + client_parser.add_argument( + "--stream", action="store_true", help="Stream generated tokens" + ) + client_parser.add_argument( + "--out_json_dir", + type=Path, + default="./results/", + help="Directory to save result JSON files", + ) + + # Create the parser, inheriting from the server and/or client parsers + parents = [] + if server_args: + parents.append(server_parser) + if client_args: + parents.append(client_parser) + + # Common args + parser = argparse.ArgumentParser(parents=parents) + parser.add_argument( + "--model", type=str, default="meta-llama/Llama-2-7b-hf", help="Model name" + ) + parser.add_argument( + "--deployment_name", + type=str, + default="mii-benchmark-deployment", + help="Deployment name for MII server", + ) + parser.add_argument("--vllm", action="store_true", help="Use VLLM instead of MII") + parser.add_argument( + "--overwrite_results", action="store_true", help="Overwrite existing results" + ) + + # Parse arguments + args = parser.parse_args() + + # Set default values for model-specific parameters + if args.model in MODEL_DEFAULTS: + for k, v in MODEL_DEFAULTS[args.model].items(): + if hasattr(args, k) and getattr(args, k) is None: + setattr(args, k, v) + + # Grab any remaining default values not specified for a model + for k, v in ARG_DEFAULTS.items(): + if hasattr(args, k) and getattr(args, k) is None: + setattr(args, k, v) + + if server_args and not client_args: + # If we are not running the benchmark, we need to make sure to only have one value for the server args + for k in SERVER_PARAMS: + if not isinstance(getattr(args, k), int): + setattr(args, k, getattr(args, k)[0]) + + return args + + +def get_args_product( + args: argparse.Namespace, which: List[str] = None +) -> Iterator[argparse.Namespace]: + if which is None: + return copy.deepcopy(args) + for k in which: + if isinstance(getattr(args, k), int): + setattr(args, k, [getattr(args, k)]) + arg_values_product = itertools.product(*[getattr(args, k) for k in which]) + for arg_values in arg_values_product: + args_copy = copy.deepcopy(args) + for k, v in zip(which, arg_values): + setattr(args_copy, k, v) + yield args_copy + + +def get_results_path(args: argparse.Namespace) -> Path: + if args.vllm: + lib_path = "vllm" + else: + lib_path = "fastgen" + return Path( + args.out_json_dir, + f"{lib_path}/", + "-".join( + ( + args.model.replace("/", "_"), + f"tp{args.tp_size}", + f"bs{args.max_ragged_batch_size}", + f"replicas{args.num_replicas}", + f"prompt{args.mean_prompt_length}", + f"gen{args.mean_max_new_tokens}", + f"clients{args.num_clients}", + ) + ) + + ".json", + ) + + +def print_summary( + args: argparse.Namespace, response_details: List[ResponseDetails] +) -> None: + ps = get_summary(vars(args), response_details) + print( + f"Deployment: {args.deployment_name} Clients: {args.num_clients}, " + + f"Prompt (mean): {args.mean_prompt_length} tokens, " + + f"Generation (mean): {args.mean_max_new_tokens} tokens, " + + f"Query throughput: {ps.throughput:.3f} queries/s, " + + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " + + f"Query latency: {ps.latency:.3f} s, " + + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " + + f"First token received: {ps.first_token_latency:.3f} s" + ) + + +def save_json_results( + args: argparse.Namespace, response_details: List[ResponseDetails] +) -> None: + args_dict = vars(args) + out_json_path = get_results_path(args) + os.makedirs(out_json_path.parent, exist_ok=True) + + with open(out_json_path, "w") as f: + args_dict["out_json_dir"] = str(out_json_path) # Path is not JSON serializable + data = { + "args": args_dict, + "time": str(datetime.now()), + "response_details": [asdict(r) for r in response_details], + } + json.dump(data, f, indent=2) + + +def results_exist(args: argparse.Namespace) -> bool: + return get_results_path(args).exists() diff --git a/training/cifar/README.md b/training/cifar/README.md index 7c58f3b98..878b28157 100644 --- a/training/cifar/README.md +++ b/training/cifar/README.md @@ -1,21 +1,22 @@ Thanks Gopi Kumar for contributing this example, demonstrating how to apply DeepSpeed to CIFAR-10 model. -cifar10_tutorial.py +`cifar10_tutorial.py` Baseline CIFAR-10 model. -cifar10_deepspeed.py +`cifar10_deepspeed.py` DeepSpeed applied CIFAR-10 model. -ds_config.json - DeepSpeed configuration file. - -run_ds.sh +`run_ds.sh` Script for running DeepSpeed applied model. -run_ds_moe.sh +`run_ds_moe.sh` Script for running DeepSpeed model with Mixture of Experts (MoE) integration. -* To run baseline CIFAR-10 model - "python cifar10_tutorial.py" -* To run DeepSpeed CIFAR-10 model - "bash run_ds.sh" -* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - "bash run_ds_moe.sh" -* To run with different data type (default='fp16') and zero stages (default=0) - "bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}" +`run_ds_prmoe.sh` + Script for running DeepSpeed model with Pyramid Residual MoE (PR-MoE) integration. + +* To run baseline CIFAR-10 model - `python cifar10_tutorial.py` +* To run DeepSpeed CIFAR-10 model - `bash run_ds.sh` +* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - `bash run_ds_moe.sh` +* To run DeepSpeed CIFAR-10 model with Pyramid Residual MoE (PR-MoE) - `bash run_ds_prmoe.sh` +* To run with different data type (default=`fp16`) and zero stages (default=`0`) - `bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}` diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index da82e60db..521a75cdf 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -1,112 +1,105 @@ +import argparse + +import deepspeed import torch +import torch.nn as nn +import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -import argparse -import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer def add_argument(): + parser = argparse.ArgumentParser(description="CIFAR") - parser = argparse.ArgumentParser(description='CIFAR') - - #data - # cuda - parser.add_argument('--with_cuda', - default=False, - action='store_true', - help='use CPU in case there\'s no GPU support') - parser.add_argument('--use_ema', - default=False, - action='store_true', - help='whether use exponential moving average') - - # train - parser.add_argument('-b', - '--batch_size', - default=32, - type=int, - help='mini-batch size (default: 32)') - parser.add_argument('-e', - '--epochs', - default=30, - type=int, - help='number of total epochs (default: 30)') - parser.add_argument('--local_rank', - type=int, - default=-1, - help='local rank passed from distributed launcher') - - parser.add_argument('--log-interval', - type=int, - default=2000, - help="output logging information at a given interval") - - parser.add_argument('--moe', - default=False, - action='store_true', - help='use deepspeed mixture of experts (moe)') - - parser.add_argument('--ep-world-size', - default=1, - type=int, - help='(moe) expert parallel world size') - parser.add_argument('--num-experts', - type=int, - nargs='+', - default=[ - 1, - ], - help='number of experts list, MoE related.') + # For train. parser.add_argument( - '--mlp-type', - type=str, - default='standard', - help= - 'Only applicable when num-experts > 1, accepts [standard, residual]') - parser.add_argument('--top-k', - default=1, - type=int, - help='(moe) gating top 1 and 2 supported') + "-e", + "--epochs", + default=30, + type=int, + help="number of total epochs (default: 30)", + ) parser.add_argument( - '--min-capacity', - default=0, + "--local_rank", type=int, - help= - '(moe) minimum capacity of an expert regardless of the capacity_factor' + default=-1, + help="local rank passed from distributed launcher", ) parser.add_argument( - '--noisy-gate-policy', - default=None, + "--log-interval", + type=int, + default=2000, + help="output logging information at a given interval", + ) + + # For mixed precision training. + parser.add_argument( + "--dtype", + default="fp16", type=str, - help= - '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter' + choices=["bf16", "fp16", "fp32"], + help="Datatype used for training", + ) + + # For ZeRO Optimization. + parser.add_argument( + "--stage", + default=0, + type=int, + choices=[0, 1, 2, 3], + help="Datatype used for training", ) + + # For MoE (Mixture of Experts). parser.add_argument( - '--moe-param-group', + "--moe", default=False, - action='store_true', - help= - '(moe) create separate moe param groups, required when using ZeRO w. MoE' + action="store_true", + help="use deepspeed mixture of experts (moe)", + ) + parser.add_argument( + "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size" + ) + parser.add_argument( + "--num-experts", + type=int, + nargs="+", + default=[ + 1, + ], + help="number of experts list, MoE related.", ) parser.add_argument( - '--dtype', - default='fp16', + "--mlp-type", type=str, - choices=['bf16', 'fp16', 'fp32'], - help= - 'Datatype used for training' + default="standard", + help="Only applicable when num-experts > 1, accepts [standard, residual]", + ) + parser.add_argument( + "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported" ) parser.add_argument( - '--stage', + "--min-capacity", default=0, type=int, - choices=[0, 1, 2, 3], - help= - 'Datatype used for training' + help="(moe) minimum capacity of an expert regardless of the capacity_factor", + ) + parser.add_argument( + "--noisy-gate-policy", + default=None, + type=str, + help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter", + ) + parser.add_argument( + "--moe-param-group", + default=False, + action="store_true", + help="(moe) create separate moe param groups, required when using ZeRO w. MoE", ) - # Include DeepSpeed configuration arguments + # Include DeepSpeed configuration arguments. parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() @@ -114,110 +107,87 @@ def add_argument(): return args -deepspeed.init_distributed() - -######################################################################## -# The output of torchvision datasets are PILImage images of range [0, 1]. -# We transform them to Tensors of normalized range [-1, 1]. -# .. note:: -# If running on Windows and you get a BrokenPipeError, try setting -# the num_worker of torch.utils.data.DataLoader() to 0. - -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) -]) - -if torch.distributed.get_rank() != 0: - # might be downloading cifar data, let rank 0 download first - torch.distributed.barrier() - -trainset = torchvision.datasets.CIFAR10(root='./data', - train=True, - download=True, - transform=transform) - -if torch.distributed.get_rank() == 0: - # cifar data is downloaded, indicate other ranks can proceed - torch.distributed.barrier() - -trainloader = torch.utils.data.DataLoader(trainset, - batch_size=16, - shuffle=True, - num_workers=2) - -testset = torchvision.datasets.CIFAR10(root='./data', - train=False, - download=True, - transform=transform) -testloader = torch.utils.data.DataLoader(testset, - batch_size=4, - shuffle=False, - num_workers=2) - -classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', - 'ship', 'truck') - -######################################################################## -# Let us show some of the training images, for fun. - -import matplotlib.pyplot as plt -import numpy as np - -# functions to show an image - - -def imshow(img): - img = img / 2 + 0.5 # unnormalize - npimg = img.numpy() - plt.imshow(np.transpose(npimg, (1, 2, 0))) - plt.show() - - -# get some random training images -dataiter = iter(trainloader) -images, labels = next(dataiter) - -# show images -imshow(torchvision.utils.make_grid(images)) -# print labels -print(' '.join('%5s' % classes[labels[j]] for j in range(4))) - -######################################################################## -# 2. Define a Convolutional Neural Network -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Copy the neural network from the Neural Networks section before and modify it to -# take 3-channel images (instead of 1-channel images as it was defined). +def create_moe_param_groups(model): + """Create separate parameter groups for each expert.""" + parameters = {"params": [p for p in model.parameters()], "name": "parameters"} + return split_params_into_different_moe_groups_for_optimizer(parameters) -import torch.nn as nn -import torch.nn.functional as F -args = add_argument() +def get_ds_config(args): + """Get the DeepSpeed configuration dictionary.""" + ds_config = { + "train_batch_size": 16, + "steps_per_print": 2000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.8, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + }, + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000, + }, + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "bf16": {"enabled": args.dtype == "bf16"}, + "fp16": { + "enabled": args.dtype == "fp16", + "fp16_master_weights_and_grads": False, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 15, + }, + "wall_clock_breakdown": False, + "zero_optimization": { + "stage": args.stage, + "allgather_partitions": True, + "reduce_scatter": True, + "allgather_bucket_size": 50000000, + "reduce_bucket_size": 50000000, + "overlap_comm": True, + "contiguous_gradients": True, + "cpu_offload": False, + }, + } + return ds_config class Net(nn.Module): - def __init__(self): + def __init__(self, args): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) - if args.moe: + self.moe = args.moe + if self.moe: fc3 = nn.Linear(84, 84) self.moe_layer_list = [] for n_e in args.num_experts: - # create moe layers based on the number of experts + # Create moe layers based on the number of experts. self.moe_layer_list.append( deepspeed.moe.layer.MoE( hidden_size=84, expert=fc3, num_experts=n_e, ep_size=args.ep_world_size, - use_residual=args.mlp_type == 'residual', + use_residual=args.mlp_type == "residual", k=args.top_k, min_capacity=args.min_capacity, - noisy_gate_policy=args.noisy_gate_policy)) + noisy_gate_policy=args.noisy_gate_policy, + ) + ) self.moe_layer_list = nn.ModuleList(self.moe_layer_list) self.fc4 = nn.Linear(84, 10) else: @@ -229,7 +199,7 @@ def forward(self, x): x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) - if args.moe: + if self.moe: for layer in self.moe_layer_list: x, _, _ = layer(x) x = self.fc4(x) @@ -238,214 +208,192 @@ def forward(self, x): return x -net = Net() +def test(model_engine, testset, local_device, target_dtype, test_batch_size=4): + """Test the network on the test data. + + Args: + model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine. + testset (torch.utils.data.Dataset): the test dataset. + local_device (str): the local device name. + target_dtype (torch.dtype): the target datatype for the test data. + test_batch_size (int): the test batch size. + + """ + # The 10 classes for CIFAR10. + classes = ( + "plane", + "car", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", + ) + # Define the test dataloader. + testloader = torch.utils.data.DataLoader( + testset, batch_size=test_batch_size, shuffle=False, num_workers=0 + ) -def create_moe_param_groups(model): - from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer + # For total accuracy. + correct, total = 0, 0 + # For accuracy per class. + class_correct = list(0.0 for i in range(10)) + class_total = list(0.0 for i in range(10)) + + # Start testing. + model_engine.eval() + with torch.no_grad(): + for data in testloader: + images, labels = data + if target_dtype != None: + images = images.to(target_dtype) + outputs = model_engine(images.to(local_device)) + _, predicted = torch.max(outputs.data, 1) + # Count the total accuracy. + total += labels.size(0) + correct += (predicted == labels.to(local_device)).sum().item() + + # Count the accuracy per class. + batch_correct = (predicted == labels.to(local_device)).squeeze() + for i in range(test_batch_size): + label = labels[i] + class_correct[label] += batch_correct[i].item() + class_total[label] += 1 + + if model_engine.local_rank == 0: + print( + f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %" + ) + + # For all classes, print the accuracy. + for i in range(10): + print( + f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %" + ) + + +def main(args): + # Initialize DeepSpeed distributed backend. + deepspeed.init_distributed() + + ######################################################################## + # Step1. Data Preparation. + # + # The output of torchvision datasets are PILImage images of range [0, 1]. + # We transform them to Tensors of normalized range [-1, 1]. + # + # Note: + # If running on Windows and you get a BrokenPipeError, try setting + # the num_worker of torch.utils.data.DataLoader() to 0. + ######################################################################## + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) - parameters = { - 'params': [p for p in model.parameters()], - 'name': 'parameters' - } + if torch.distributed.get_rank() != 0: + # Might be downloading cifar data, let rank 0 download first. + torch.distributed.barrier() - return split_params_into_different_moe_groups_for_optimizer(parameters) + # Load or download cifar data. + trainset = torchvision.datasets.CIFAR10( + root="./data", train=True, download=True, transform=transform + ) + testset = torchvision.datasets.CIFAR10( + root="./data", train=False, download=True, transform=transform + ) + if torch.distributed.get_rank() == 0: + # Cifar data is downloaded, indicate other ranks can proceed. + torch.distributed.barrier() + + ######################################################################## + # Step 2. Define the network with DeepSpeed. + # + # First, we define a Convolution Neural Network. + # Then, we define the DeepSpeed configuration dictionary and use it to + # initialize the DeepSpeed engine. + ######################################################################## + net = Net(args) + + # Get list of parameters that require gradients. + parameters = filter(lambda p: p.requires_grad, net.parameters()) + + # If using MoE, create separate param groups for each expert. + if args.moe_param_group: + parameters = create_moe_param_groups(net) + + # Initialize DeepSpeed to use the following features. + # 1) Distributed model. + # 2) Distributed data loader. + # 3) DeepSpeed optimizer. + ds_config = get_ds_config(args) + model_engine, optimizer, trainloader, __ = deepspeed.initialize( + args=args, + model=net, + model_parameters=parameters, + training_data=trainset, + config=ds_config, + ) -parameters = filter(lambda p: p.requires_grad, net.parameters()) -if args.moe_param_group: - parameters = create_moe_param_groups(net) - -# Initialize DeepSpeed to use the following features -# 1) Distributed model -# 2) Distributed data loader -# 3) DeepSpeed optimizer -ds_config = { - "train_batch_size": 16, - "steps_per_print": 2000, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [ - 0.8, - 0.999 - ], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": 0, - "warmup_max_lr": 0.001, - "warmup_num_steps": 1000 - } - }, - "gradient_clipping": 1.0, - "prescale_gradients": False, - "bf16": { - "enabled": args.dtype == "bf16" - }, - "fp16": { - "enabled": args.dtype == "fp16", - "fp16_master_weights_and_grads": False, - "loss_scale": 0, - "loss_scale_window": 500, - "hysteresis": 2, - "min_loss_scale": 1, - "initial_scale_power": 15 - }, - "wall_clock_breakdown": False, - "zero_optimization": { - "stage": args.stage, - "allgather_partitions": True, - "reduce_scatter": True, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": True, - "contiguous_gradients": True, - "cpu_offload": False - } -} - -model_engine, optimizer, trainloader, __ = deepspeed.initialize( - args=args, model=net, model_parameters=parameters, training_data=trainset, config=ds_config) - -local_device = get_accelerator().device_name(model_engine.local_rank) -local_rank = model_engine.local_rank - -# For float32, target_dtype will be None so no datatype conversion needed -target_dtype = None -if model_engine.bfloat16_enabled(): - target_dtype=torch.bfloat16 -elif model_engine.fp16_enabled(): - target_dtype=torch.half - -#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -#net.to(device) -######################################################################## -# 3. Define a Loss function and optimizer -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Let's use a Classification Cross-Entropy loss and SGD with momentum. - -import torch.optim as optim - -criterion = nn.CrossEntropyLoss() -#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - -######################################################################## -# 4. Train the network -# ^^^^^^^^^^^^^^^^^^^^ -# -# This is when things start to get interesting. -# We simply have to loop over our data iterator, and feed the inputs to the -# network and optimize. - -for epoch in range(args.epochs): # loop over the dataset multiple times - - running_loss = 0.0 - for i, data in enumerate(trainloader): - # get the inputs; data is a list of [inputs, labels] - inputs, labels = data[0].to(local_device), data[1].to(local_device) - if target_dtype != None: - inputs = inputs.to(target_dtype) - outputs = model_engine(inputs) - loss = criterion(outputs, labels) - - model_engine.backward(loss) - model_engine.step() - - # print statistics - running_loss += loss.item() - if local_rank == 0 and i % args.log_interval == ( - args.log_interval - - 1): # print every log_interval mini-batches - print('[%d, %5d] loss: %.3f' % - (epoch + 1, i + 1, running_loss / args.log_interval)) - running_loss = 0.0 - -print('Finished Training') - -######################################################################## -# 5. Test the network on the test data -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# We have trained the network for 2 passes over the training dataset. -# But we need to check if the network has learnt anything at all. -# -# We will check this by predicting the class label that the neural network -# outputs, and checking it against the ground-truth. If the prediction is -# correct, we add the sample to the list of correct predictions. -# -# Okay, first step. Let us display an image from the test set to get familiar. - -dataiter = iter(testloader) -images, labels = next(dataiter) - -# print images -imshow(torchvision.utils.make_grid(images)) -print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) - -######################################################################## -# Okay, now let us see what the neural network thinks these examples above are: -if target_dtype != None: - images = images.to(target_dtype) -outputs = net(images.to(local_device)) - -######################################################################## -# The outputs are energies for the 10 classes. -# The higher the energy for a class, the more the network -# thinks that the image is of the particular class. -# So, let's get the index of the highest energy: -_, predicted = torch.max(outputs, 1) - -print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) - -######################################################################## -# The results seem pretty good. -# -# Let us look at how the network performs on the whole dataset. - -correct = 0 -total = 0 -with torch.no_grad(): - for data in testloader: - images, labels = data - if target_dtype != None: - images = images.to(target_dtype) - outputs = net(images.to(local_device)) - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels.to(local_device)).sum().item() - -print('Accuracy of the network on the 10000 test images: %d %%' % - (100 * correct / total)) - -######################################################################## -# That looks way better than chance, which is 10% accuracy (randomly picking -# a class out of 10 classes). -# Seems like the network learnt something. -# -# Hmmm, what are the classes that performed well, and the classes that did -# not perform well: - -class_correct = list(0. for i in range(10)) -class_total = list(0. for i in range(10)) -with torch.no_grad(): - for data in testloader: - images, labels = data - if target_dtype != None: - images = images.to(target_dtype) - outputs = net(images.to(local_device)) - _, predicted = torch.max(outputs, 1) - c = (predicted == labels.to(local_device)).squeeze() - for i in range(4): - label = labels[i] - class_correct[label] += c[i].item() - class_total[label] += 1 - -for i in range(10): - print('Accuracy of %5s : %2d %%' % - (classes[i], 100 * class_correct[i] / class_total[i])) + # Get the local device name (str) and local rank (int). + local_device = get_accelerator().device_name(model_engine.local_rank) + local_rank = model_engine.local_rank + + # For float32, target_dtype will be None so no datatype conversion needed. + target_dtype = None + if model_engine.bfloat16_enabled(): + target_dtype = torch.bfloat16 + elif model_engine.fp16_enabled(): + target_dtype = torch.half + + # Define the Classification Cross-Entropy loss function. + criterion = nn.CrossEntropyLoss() + + ######################################################################## + # Step 3. Train the network. + # + # This is when things start to get interesting. + # We simply have to loop over our data iterator, and feed the inputs to the + # network and optimize. (DeepSpeed handles the distributed details for us!) + ######################################################################## + + for epoch in range(args.epochs): # loop over the dataset multiple times + running_loss = 0.0 + for i, data in enumerate(trainloader): + # Get the inputs. ``data`` is a list of [inputs, labels]. + inputs, labels = data[0].to(local_device), data[1].to(local_device) + + # Try to convert to target_dtype if needed. + if target_dtype != None: + inputs = inputs.to(target_dtype) + + outputs = model_engine(inputs) + loss = criterion(outputs, labels) + + model_engine.backward(loss) + model_engine.step() + + # Print statistics + running_loss += loss.item() + if local_rank == 0 and i % args.log_interval == ( + args.log_interval - 1 + ): # Print every log_interval mini-batches. + print( + f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}" + ) + running_loss = 0.0 + print("Finished Training") + + ######################################################################## + # Step 4. Test the network on the test data. + ######################################################################## + test(model_engine, testset, local_device, target_dtype) + + +if __name__ == "__main__": + args = add_argument() + main(args) diff --git a/training/cifar/run_ds_moe.sh b/training/cifar/run_ds_moe.sh index b7dcb7fa7..f87a29628 100755 --- a/training/cifar/run_ds_moe.sh +++ b/training/cifar/run_ds_moe.sh @@ -15,7 +15,6 @@ deepspeed --num_nodes=${NUM_NODES}\ cifar10_deepspeed.py \ --log-interval 100 \ --deepspeed \ - --deepspeed_config ds_config.json \ --moe \ --ep-world-size ${EP_SIZE} \ --num-experts ${EXPERTS} \ diff --git a/training/cifar/run_ds_prmoe.sh b/training/cifar/run_ds_prmoe.sh index 72731b0d5..d9755a331 100644 --- a/training/cifar/run_ds_prmoe.sh +++ b/training/cifar/run_ds_prmoe.sh @@ -12,7 +12,6 @@ EXPERTS='2 4' deepspeed --num_nodes=${NUM_NODES} --num_gpus=${NUM_GPUS} cifar10_deepspeed.py \ --log-interval 100 \ --deepspeed \ - --deepspeed_config ds_config.json \ --moe \ --ep-world-size ${EP_SIZE} \ --num-experts ${EXPERTS} \