
Merge branch 'master' into dev
tjruwase authored Jan 27, 2024
2 parents 7888212 + 107681e commit c87e8ab
Showing 28 changed files with 1,379 additions and 937 deletions.
@@ -258,8 +258,17 @@ def main():
     # the LN that precedes it.
     force_optimize_params = []
     if "bigscience/bloom-" in args.model_name_or_path:
-        torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
-        torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
+        zero_init_enabled = (args.zero_stage == 3)
+        params = [
+            rm_model.rwtransformer.ln_f.weight,
+            rm_model.rwtransformer.ln_f.bias
+        ]
+        with deepspeed.zero.GatheredParameters(params,
+                                               modifier_rank=0,
+                                               enabled=zero_init_enabled):
+            if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
+                torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
+                torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
         force_optimize_params.extend(
             ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])

49 changes: 35 additions & 14 deletions benchmarks/inference/mii/README.md
@@ -2,38 +2,59 @@
 
 ## Run the Benchmark
 
-The benchmarking scripts use DeepSpeed-FastGen in the persistent mode.
-You can start the server with the command below:
+The benchmarking scripts use DeepSpeed-FastGen in persistent mode. You can
+run the benchmark using `run_benchmark.py`. This script will run several
+combinations of inference servers and clients with different values for tensor
+parallel size, number of model replicas (MII only), number of clients, prompt
+length, and max new tokens. By default, the benchmark will run with the
+`meta-llama/Llama-2-7b-hf` model.
 
 ```bash
-python server.py [options] start
+python run_benchmark.py
 ```
 
-Use the -h option to view all available options. To stop the server, use this command:
+Use the -h option to view all available options. Several models have pre-defined
+default values, including `meta-llama/Llama-2-{7|13|70}b-hf`,
+`tiiuae/falcon-{40|180}B`, `microsoft/phi-2`, and `mistralai/Mixtral-8x7B-v0.1`.
+These defaults can be overridden if provided to the `run_benchmark.py` script.
+For example, to run `meta-llama/Llama-2-13b-hf` with a tensor parallel size of
+`1` and `2` (instead of the default `1`, `2`, and `4`):
 
-```bash
-python server.py stop
+```bash
+python run_benchmark.py --tp_size 1 2
 ```
 
-Once the server is up and running, initiate the client using the command below. The -h option will display all the possible options.
+By default, the benchmark runs with DeepSpeed-MII as the backend inference
+server. To change the backend to vLLM, provide the `--vllm` flag:
 
 ```bash
-python run_benchmark_client.py [options]
+python run_benchmark.py --vllm
 ```
 
-The run_all.sh script performs benchmarks across various model sizes and client numbers. For VLLM benchmarks, use the run_all_vllm.sh script. Results are logged in a directory named logs.[BENCHMARK_PARAMETERS].
+The `run_all.sh` script performs benchmarks across various models, client numbers,
+tensor parallel sizes, etc. This script is intended to be run on a system with
+8xA100 (80GB) GPUs available. It will run all the benchmarks (including vLLM)
+and collect the data used in our [DeepSpeed-FastGen
+blogs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen).
+Results are collected in `./results/`.
 
 ## Analyze the Benchmark Results
 
-The scripts mentioned below were used for generating the plots featured in our blog. Specify the root directory for log files using --log_dir.
+The scripts mentioned below were used to generate the plots featured in our
+blog. Specify the root directory for log files using `--log_dir`. The generated
+figures will be saved to `./plots/`.
 
-- `plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts.
-- `plot_effective_throughput.py`: Use this to chart effective throughput.
-- `plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies.
+- `src/plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts.
+- `src/plot_effective_throughput.py`: Use this to chart effective throughput.
+- `src/plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies.
 
 ## Running an End-to-End Example
 
-To quickly experience the end-to-end process of running our benchmark and getting results, you can use the `run_example.sh`. This script is designed to execute the benchmark with a specific configuration. The plots below will be generated in the charts directory. These plots show the performance as depicted in figure 8 of our blog [post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms)
+To quickly experience the end-to-end process of running our benchmark and
+getting results, you can use the `run_example.sh` script. This script is designed
+to execute the benchmark with a specific configuration. The plots below will be
+generated in the `./plots/` directory. These plots show the performance as
+depicted in figure 8 of our blog
+[post](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms).
 
 ```bash
 bash run_example.sh
116 changes: 0 additions & 116 deletions benchmarks/inference/mii/plot_th_lat.py

This file was deleted.

5 changes: 5 additions & 0 deletions benchmarks/inference/mii/requirements.txt
@@ -0,0 +1,5 @@
transformers
matplotlib
deepspeed-mii>=0.2.0
vllm>=0.2.7
numpy
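
These are the benchmark's Python dependencies. Before running the scripts they can be installed with a standard pip invocation; the repository-relative path is taken from the file listing above, and the rest is a usual-setup assumption rather than part of this commit:

```bash
# Install the MII benchmark dependencies (transformers, matplotlib, deepspeed-mii, vllm, numpy)
pip install -r benchmarks/inference/mii/requirements.txt
```
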
32 changes: 11 additions & 21 deletions benchmarks/inference/mii/run_all.sh
@@ -1,25 +1,15 @@
-RAGGED_BATCH_SIZE=768
-PARAM_SIZES=(7b 13b 70b)
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
 
-declare -A TP_SIZES
-TP_SIZES["7b"]="1"
-TP_SIZES["13b"]="1:2:4"
-TP_SIZES["70b"]="4:8"
+# DeepSpeed Team
 
-for PARAM_SIZE in ${PARAM_SIZES[@]}; do
-
-    IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]}
-    for TP in ${TP_VALUES[@]}; do
-        DEPLOYMENT_NAME=llama2-${PARAM_SIZE}-tp${TP}-b${RAGGED_BATCH_SIZE}
-        python server.py --model_name meta-llama/Llama-2-${PARAM_SIZE}-hf -d ${DEPLOYMENT_NAME} -m ${TP} -b ${RAGGED_BATCH_SIZE} start
+MODELS=(meta-llama/Llama-2-7b-hf meta-llama/Llama-2-13b-hf meta-llama/Llama-2-70b-hf tiiuae/falcon-40B tiiuae/falcon-180B microsoft/phi-2 mistralai/Mixtral-8x7B-v0.1)
 
-        DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh
-        DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=128 bash ./run_benchmark_client.sh
-        DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh
-        DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=128 bash ./run_benchmark_client.sh
-
-        echo "Stopping server"
-        python server.py -d ${DEPLOYMENT_NAME} stop
-        sleep 120
-    done
+for MODEL in ${MODELS[@]}; do
+    python ./run_benchmark.py --model ${MODEL} --stream
+    python ./run_benchmark.py --model ${MODEL} --stream --vllm
 done
+
+# Extra runs for Mixtral with non-default settings
+python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024
+python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --vllm
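
For reference, the rewritten sweep is launched directly from the benchmark directory; a minimal sketch of a typical run (the directory path comes from the file listings above, everything else is an assumption rather than part of this commit):

```bash
# Run the full MII and vLLM benchmark sweep used for the blog results
cd benchmarks/inference/mii
bash run_all.sh
```
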
25 changes: 0 additions & 25 deletions benchmarks/inference/mii/run_all_replica.sh

This file was deleted.

26 changes: 0 additions & 26 deletions benchmarks/inference/mii/run_all_vllm.sh

This file was deleted.

40 changes: 40 additions & 0 deletions benchmarks/inference/mii/run_benchmark.py
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from src.client import run_client
from src.server import start_server, stop_server
from src.utils import (
    get_args_product,
    parse_args,
    print_summary,
    results_exist,
    save_json_results,
    CLIENT_PARAMS,
    SERVER_PARAMS,
)


def run_benchmark() -> None:
    args = parse_args(server_args=True, client_args=True)

    for server_args in get_args_product(args, which=SERVER_PARAMS):
        start_server(server_args)

        for client_args in get_args_product(server_args, which=CLIENT_PARAMS):
            if results_exist(client_args) and not args.overwrite_results:
                print(
                    f"Found existing results and skipping current setting. To ignore existing results, use --overwrite_results"
                )
                continue

            response_details = run_client(client_args)
            print_summary(client_args, response_details)
            save_json_results(client_args, response_details)

        stop_server(server_args)


if __name__ == "__main__":
    run_benchmark()
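
As a quick reference, the entry point above is driven entirely by command-line flags; a sketch of a typical invocation, using only flags that appear elsewhere in this commit and assuming it is run from the benchmarks/inference/mii directory:

```bash
# Sweep Llama-2-7B over tensor parallel sizes 1 and 2 with the default MII backend,
# then repeat the same sweep with the vLLM backend.
python run_benchmark.py --model meta-llama/Llama-2-7b-hf --tp_size 1 2 --stream
python run_benchmark.py --model meta-llama/Llama-2-7b-hf --tp_size 1 2 --stream --vllm
```
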
23 changes: 0 additions & 23 deletions benchmarks/inference/mii/run_benchmark_client.sh

This file was deleted.
