SGLang + Verl #3852
base: main
Changes from all commits
@@ -0,0 +1,271 @@
# TODO temporarily here, remove it last

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import List

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from sglang.srt.entrypoints.verl_engine import VerlEngine


def main():
    assert torch.cuda.is_available(), "CUDA must be present to run FSDP vLLM example"
    local_rank, rank, world_size = initialize_global_process_group()

    # NOTE MODIFIED path-related logic
    # local_cache_path = '~/.cache/verl/rlhf'
    # local_cache_path = os.path.expanduser(local_cache_path)
    # hdfs_path = "Qwen/Qwen2-7B-Instruct"
    hdfs_path = "meta-llama/Llama-3.2-1B-Instruct"
    local_model_path = hdfs_path
    # from verl.utils.fs import copy_local_path_from_hdfs
    # local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
    actor_model_config = AutoConfig.from_pretrained(
        local_model_path, trust_remote_code=True
    )
    with torch.device("cuda"):
        actor_model = AutoModelForCausalLM.from_pretrained(
            local_model_path, trust_remote_code=True
        )
        actor_model.to(torch.bfloat16)

    max_prompt_length = 16
    response_length = 32
    preencode_prompts = [
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    tokenizer.pad_token = tokenizer.eos_token
    prompts = tokenizer(
        preencode_prompts, return_tensors="pt", padding=True, padding_side="left"
    )  # NOTE MODIFIED add
    input_ids = prompts["input_ids"]
    attention_mask = prompts["attention_mask"]
    # from verl.utils.torch_functional import pad_sequence_to_length
    input_ids = pad_sequence_to_length(
        input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True
    ).cuda()
    attention_mask = pad_sequence_to_length(
        attention_mask, max_prompt_length, 0, left_pad=True
    ).cuda()

    from transformers import GenerationConfig

    generation_config = GenerationConfig(do_sample=False)
    actor_model.cuda()
    output = actor_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=32,
        # max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        generation_config=generation_config,
        # renormalize_logits=True,
        output_scores=False,  # this is potentially very large
        return_dict_in_generate=True,
        use_cache=False,
    )  # may OOM when use_cache = True
    seq = output.sequences
    response = seq[:, max_prompt_length:]

    print(f"hf response: {tokenizer.batch_decode(response)}")

    tensor_model_parallel_size = 4
This runs on TP=4, so it needs four GPUs to run? We only have two GPUs for testing right now. If needed, we can create one for this use case.

Currently I name it "adhoc" and will remove it; this is modified from guangming's #2736. It is currently here both for extra testing and because I guess @ocss884 may need this as a reference for the verl side. There is another example, offline_batch_inference_torchrun.py, for testing, and also a test_verl_engine.py containing things like update weights, comparison tests, etc. There are a lot of things like: does this need to be a real example? If so, surely this script needs a big refactor.

@fzyzcjy Yeah I got your point. Thanks!

We should PR this to verl?

Btw test_verl_engine.py somehow mimics this adhoc_verl_torchrun.py, doing things like comparison tests and update weights.
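For a two-GPU test run, presumably only the hard-coded parallel sizes and the launch command need to change; the sketch below is hypothetical (these values are not in the PR) and reuses the same calls the script already makes:

# Hypothetical two-GPU setup (not part of this PR); launch with:
#   torchrun --nproc_per_node=2 adhoc_verl_torchrun.py
from torch.distributed.device_mesh import init_device_mesh

tp_size, dp_size = 2, 1  # instead of the script's hard-coded 4, 1
kwargs = dict(mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"])
inference_device_mesh_cpu = init_device_mesh("cpu", **kwargs)  # total mesh size must equal world_size (2 here)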
    device_mesh = init_device_mesh(
        "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
    )

    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )
    fsdp_model = FSDP(
        actor_model,
        use_orig_params=True,
        auto_wrap_policy=None,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        cpu_offload=CPUOffload(offload_params=False),
        sync_module_states=False,
        device_mesh=device_mesh,
    )

    FSDP.set_state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(),
    )

    state_dict = fsdp_model.state_dict()

    # for debug
    # if rank == 0:
    #     lines = ["------------------------ state_dict ------------------------"]
    #     for k, v in state_dict.items():
    #         v_local = v.to_local()
    #         lines.append(
    #             f"{k}\t: {v.shape=} {v_local.shape=} {v.dtype=} {v_local.dtype=} {type(v)=} {type(v_local)=}"
    #         )
    #     print("\n".join(lines))

    # NOTE MODIFIED
    # sampling_params = SamplingParams(temperature=0,
    #                                  top_p=1,
    #                                  n=1,
    #                                  max_tokens=response_length,
    #                                  logprobs=1,
    #                                  ignore_eos=True,
    #                                  detokenize=False)
Comment on lines +131 to +148

Could we remove this?

(see above)
    sampling_params = dict(
        temperature=0, top_p=1, n=1, max_new_tokens=response_length, ignore_eos=True
    )

    tp_size, dp_size = 4, 1
It is tp=4, dp=1 in this case. We can also test tp=2, dp=2; this test can run longer.

(see above)
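A rough sketch of the tp=2, dp=2 variant suggested above (hypothetical, not part of this PR; it needs four ranks):

# Hypothetical tp=2, dp=2 layout (4 ranks); these values are not in this script.
from torch.distributed.device_mesh import init_device_mesh

tp_size, dp_size = 2, 2
kwargs = dict(mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"])
inference_device_mesh_cpu = init_device_mesh("cpu", **kwargs)
dp_rank = inference_device_mesh_cpu.get_local_rank("dp")
# With dp > 1, each dp group would presumably need its own GPUs, following the
# pattern of offline_batch_inference_torchrun.py below:
#   base_gpu_id=dp_rank, gpu_id_step=dp_size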
    kwargs = dict(mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"])
    inference_device_mesh_cpu = init_device_mesh("cpu", **kwargs)
    tp_rank = inference_device_mesh_cpu.get_local_rank("tp")
    print(f"{inference_device_mesh_cpu=}")

    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]

    print(actor_model_config)
    # llm = LLM(model=None,
    #           tokenizer=tokenizer,
    #           model_hf_config=actor_model_config,
    #           tensor_parallel_size=tensor_model_parallel_size,
    #           enforce_eager=True,
    #           dtype='bfloat16',
    #           load_format='dummy_dtensor',
    #           gpu_memory_utilization=0.1,
    #           trust_remote_code=True)
Comment on lines +164 to +172

delete this?

(see above)
    changed_model_path = local_model_path.replace("-Instruct", "")
    assert changed_model_path != local_model_path
    print(f"{changed_model_path=}")
    llm = VerlEngine(
        model_path=changed_model_path,  # use model of same type but different weight to test update_weights
        dtype="bfloat16",
        mem_fraction_static=0.2,
        device_mesh_cpu=inference_device_mesh_cpu["tp"],
        base_gpu_id=0,
        gpu_id_step=1,
    )

    t = time.time()
    if 0:
What's the meaning of this line?

(see above - this is a to-be-deleted adhoc script)
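For readers of this thread: `if 0:` disables the "most naive way" below (gathering every sharded DTensor into a full tensor before the update), so the `else` branch runs and passes the sharded state_dict directly to update_weights_from_tensor. A hypothetical, more readable form of the same toggle (the flag name is invented here):

# Hypothetical rewrite of the `if 0:` toggle; `gather_full_tensors` is an invented name.
gather_full_tensors = False
if gather_full_tensors:
    # "Most naive way": materialize each sharded DTensor into a full tensor first.
    state_dict_full = {k: v.full_tensor() for k, v in state_dict.items()}
    llm.update_weights_from_tensor(list(state_dict_full.items()))
else:
    # Pass the sharded tensors directly (the path this script actually exercises).
    llm.update_weights_from_tensor(list(state_dict.items()))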
        # most naive way
        state_dict_full = {k: v.full_tensor() for k, v in state_dict.items()}
        print(f"gather full tensor: {time.time() - t:.2f}")
        llm.update_weights_from_tensor([(k, v) for k, v in state_dict_full.items()])
    else:
        llm.update_weights_from_tensor([(k, v) for k, v in state_dict.items()])
    print(f"[{tp_rank=}] gather + update weights: {time.time() - t:.2f}")

    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    idx_list = []
    batch_size = input_ids.shape[0]

    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    # from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs
    for i in range(batch_size):
        idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
    print("start generation")
    # outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)
    outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)

    # vllm_output = outputs[0].cuda()
    if torch.distributed.get_rank() == 0:
        print(f"hf response: {tokenizer.batch_decode(response)}")
        # print(f'vllm response: {tokenizer.batch_decode(vllm_output)}')
        print(f'vllm response: {[o["text"] for o in outputs]}')

    llm.shutdown()


# COPIED FROM verl
def initialize_global_process_group(timeout_second=36000):
    from datetime import timedelta

    import torch.distributed

    # NOTE MODIFIED should provide backend=None to have nccl+gloo
    # torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
    torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size


# COPIED FROM verl
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
    """
    pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
    input shape: [bs, seq_length]
    output shape: [bs, max_seq_length]
    (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
    """
    if tensors.shape[-1] >= max_seq_len:
        return tensors
    pad_tuple = (
        (max_seq_len - tensors.shape[-1], 0)
        if left_pad
        else (0, max_seq_len - tensors.shape[-1])
    )
    return F.pad(tensors, pad_tuple, "constant", pad_token_id)


# COPIED FROM verl
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    # remove the left padding in the prompt token_id
    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][
        0
    ]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


if __name__ == "__main__":
    main()
@@ -0,0 +1,81 @@
import datetime
import os
import sys

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine


def run():
    """
    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```
    """

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    def _log(text):
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )

    tp_size = 4
    dp_size = 2
    assert world_size == tp_size * dp_size

    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")

    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9  # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8

    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]

    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP Attention
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    _log(f"{fragment=}")

    prompt_all = [
        ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
        ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
    ]
    prompt = prompt_all[dp_rank]

    output = fragment.generate(
        prompt=prompt,
        sampling_params=dict(max_new_tokens=16, temperature=0.0),
    )
    _log(f"{prompt=} {output=}")

    fragment.shutdown()
    _log(f"End script")


if __name__ == "__main__":
    run()
Sorry, I don't quite understand why this is running on verl-vLLM. Where is the SGLang one?

(see above)