This document covers serving Llama 2, as presented in the original Llama repo, using the PyTorch (PT) Tensor Parallel (TP) APIs, which under the hood make use of DTensors. It takes a sharding plan for the linear layers in the MLP and Attention blocks of the Llama 2 model and turns it into a TP model distributed over multiple GPUs. In the following, we show the steps for using this to serve the Llama 2 7B-70B models with TorchServe.
Here we convert the Meta Llama 2 model, which is based on Fairscale TP layers, to PT Distributed (PT-D) compliant checkpoints, and use the PT TP (DTensor) API to run distributed inference.
Note: The following has been tested so far on A100 GPUs with 40 GB of memory.
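To give a sense of what such a sharding plan looks like, below is a minimal sketch using the PT TP API on a recent PyTorch. The module names (attention.wq/wk/wv/wo, feed_forward.w1/w2/w3) follow the original Llama code and are illustrative, and `model` is assumed to be a Llama Transformer already built on every rank launched by torchrun; the plan actually used by this example is applied inside its model code.

```python
# Minimal sketch of a TP sharding plan with the PyTorch TP (DTensor) API.
import os
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_size = int(os.environ["WORLD_SIZE"])       # set by torchrun
mesh = init_device_mesh("cuda", (tp_size,))   # 1-D device mesh over the TP ranks

# Column-wise shards keep the intermediate activations sharded; the row-wise
# shard that follows restores the full activation with a single all-reduce,
# so attention and MLP each cost one all-reduce per transformer block.
layer_plan = {
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
}

for block in model.layers:  # `model` is assumed to be the Llama Transformer
    parallelize_module(block, mesh, layer_plan)
```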
1- Make sure you have access to the Llama 2 weights on the HF model hub; there is a form you need to fill out, and within a few minutes you will get access. Any Llama 2 model name on the hub without -hf is the Meta/FAIR weight.
Make sure you are signed up on HF as well; you will need your API token, which can be accessed from your HF account settings. Use the same email for requesting the weights as the one you signed up to HF with.
Once you have access, log in to HF in your terminal:
huggingface-cli login --token YOUR_TOKEN
Make sure to have PyTorch Nightlies installed.
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install transformers fire sentencepiece
Log in to the HuggingFace hub with your token by running the command below. Make sure to specify the right name for the Llama 2 model from the HuggingFace (HF) model hub; any model name on the hub without -hf is the original Meta model/checkpoint, and those are the ones we need, not the HF-converted versions.
huggingface-cli login
Paste the token generated from the HuggingFace hub. Make sure use_auth_token=True is set in the download script.
python ../utils/Download_model.py --model_name meta-llama/Llama-2-7b
The script prints the path where the model is downloaded, as shown below.
model/models--meta-llama--Llama-2-7b/snapshots/365ffa8f1a6c455d3e2028ae658236b4b85ba824
Convert the checkpoints to PT-D compliant checkpoints as follows. Note that for 7B it is --model_parallel_size 1, for 13B it would be --model_parallel_size 2, and for 70B --model_parallel_size 8; you can also set --nproc_per_node accordingly. PT-D compliant checkpoints support a flexible world_size when loading them back into the TP(-lized) model, so you can use a larger number of processes / a larger TP size when loading the model back. For example, if you have converted the 13B checkpoints with --nproc_per_node 2, during inference you can set --nproc_per_node to anything in [2, max_num_available_gpu], which changes the world_size and effectively the TP size. The recommendation is to keep the TP size as listed for each model size, 7B (TP size = 1), 13B (TP size = 2), 70B (TP size = 8), unless your benchmarks and your batch size / compute load compensate for the communication cost.
This will save the model args in model_args.json; during the inference step you need to pass this JSON file to build the model. Make sure you set --max_seq_len, the maximum sequence length for input text (context length), and --max_batch_size, the maximum batch size for inference, to the appropriate values. These two values are used to construct the KV cache (see the sizing sketch after the command below).
torchrun --nnodes 1 --nproc_per_node 8 convert_checkpoints.py --original_ckpt_dir PATH/TO/MODEL/CHECKPOINTS --tokenizer_path PATH/TO/MODEL/CHECKPOINTS/tokenizer.model --model_parallel_size 1 --save_checkpoint_dir converted_checkpoints --max_seq_len 512 --max_batch_size 2
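As a rough sanity check on how --max_seq_len and --max_batch_size drive memory use, the sketch below estimates the fp16 KV-cache footprint for the 7B model; n_layers=32, n_heads=32, head_dim=128 are the public Llama-2-7B values, and the exact allocation is done inside the model code.

```python
# Rough KV-cache sizing: K and V each store [batch, seq, heads, head_dim]
# per layer, so the footprint grows linearly with both flags.
def kv_cache_bytes(max_batch_size, max_seq_len,
                   n_layers=32, n_heads=32, head_dim=128, bytes_per_elem=2):
    return 2 * n_layers * max_batch_size * max_seq_len * n_heads * head_dim * bytes_per_elem

# With the values from the conversion command above (--max_batch_size 2 --max_seq_len 512):
print(kv_cache_bytes(2, 512) / 2**30)  # 0.5 GiB, split across ranks once the heads are TP-sharded
```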
Let's set up the configs in model-config.yaml:
#frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
parallelType: "tp"
deviceType: "gpu"
torchrun:
    nproc-per-node: 8 # TP size
handler:
    converted_ckpt_dir: "converted_checkpoints"
    tokenizer_path: "tokenizer.model"
    model_args_path: "model_args.json"
    max_seq_len: 512
    max_batch_size: 6
    max_new_tokens: 50
    temperature: 0.6
    top_p: 0.9
    manual_seed: 40
    mode: "text_completion" # choices are text_completion, chat
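For reference, temperature and top_p in the handler section control nucleus (top-p) sampling at generation time. The snippet below is an illustrative version of that sampling step, not this example's own code (the real generation loop lives in generate.py).

```python
# Illustrative next-token sampling with temperature scaling and top-p filtering.
import torch

def sample_next_token(logits, temperature=0.6, top_p=0.9):
    # Temperature rescales the logits before softmax; lower values are more deterministic.
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens outside the smallest set whose cumulative probability exceeds top_p.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_token)
```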
Create the mar file using the following command:
torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format no-archive --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py"
Move the converted checkpoints, the tokenizer, and model_args.json into the llama folder produced by the archiver:
mv converted_checkpoints llama
mv PATH/TO/MODEL/CHECKPOINTS/tokenizer.model llama
mv model_args.json llama
torchserve --ncs --start --model-store model_store --models llama
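Optionally, confirm that the model is registered and its workers are up through TorchServe's management API; a minimal check in Python, assuming the default management port 8081 and the requests package:

```python
# Query TorchServe's management API for the registered llama model.
import requests

resp = requests.get("http://localhost:8081/models/llama")
print(resp.json())  # shows the model's workers and their status
```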
Text completion example:
curl -v "http://localhost:8080/predictions/llama" -T sample_text.txt
Chat example:
curl -v "http://localhost:8080/predictions/llama" -T dialogs.txt