SGLang + Verl #3852

Open. fzyzcjy wants to merge 113 commits into main from feat/verl_20250225.
Commits (113, all by fzyzcjy)
Feb 25, 2025:
15b1cb1  empty
b86144c  more
8dec005  more
5d3aaa3  more
9a22ee4  more
d889d9a  more
8c6e2e5  more
6827ecb  more
a245074  more
2f2221f  more
51e73a9  more
76efa04  more
8769324  more
0af4a69  more
07704ca  more
e497d46  more
562f46f  more
0f37323  more
eba2bbf  more
f451284  more
d9ff06c  more
328a3ab  more
2ec60f5  rm gather_pyobj
a0ca4a7  more
497cf2f  more
48bc84a  more
540b774  more
4b1107d  more
0c10ec8  more
6a4891e  more
b6955bc  more
80676b4  more
4761faa  cp from old
c88b0c5  more
8fb6c8a  more
19882b8  more
9312d4d  more
c873c5d  more
bb3119b  more
b36c32b  more
93d9986  more
fbe2eb3  more
f3e0ee0  more
b477db8  more
ead6296  more
ba83492  more
e0d6388  cp from old
e3fd7e9  more
9070ffa  more
cd24855  more
a59e8ca  more
e8bed71  more
5d7fa08  more
d0a1998  temp
acac43d  temp
5c4dc92  more
8e985a8  more
4069d80  more
3698ad7  more
7f0f289  rm temp
113e6e2  more
dca1882  more
3770003  more
4595bb1  more
0af9e11  more
35d3d07  more
c48b1e7  more
842b9bd  more
8748755  more
58c1d7a  more
90e7db8  more
98420f9  rm temp
99bd169  more
993c39a  more
fc1c92f  cp from old
cd71750  more
7fd2117  more
3e37bd8  more
432270b  try
94c76d8  Revert "try"
cf10aee  more
a27a6cc  temp debug
0c1235d  Revert "temp debug"
f662f5d  more
bce48ac  more
790f001  more
03f4517  more
08d61fa  fmt
5f127fe  fmt
6b8578a  more
740c7dd  cp old test
b345a3c  fmt
35ba72b  more
5cd818b  cleanup tests

Feb 26, 2025:
5104d3b  more
e16e7ef  more
502019e  more
5a08027  more
9527150  more
c661ec2  more
e14436e  more
d88940c  more
deac1b3  more
5ea4007  more
d09b224  more
07b9c1a  more
1a705c3  fmt
4630b19  Merge branch 'main' into feat/verl_20250225
2c6d20a  comments
6aa5c85  more tests
8de7c8b  fmt
b0b3d6e  fix merge
a15079f  Merge branch 'main' into feat/verl_20250225
6 changes: 6 additions & 0 deletions .github/workflows/pr-test.yml
@@ -149,6 +149,12 @@ jobs:
           cd test/srt
           python3 test_update_weights_from_distributed.py
+      - name: Test VerlEngine
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_verl_engine.py
       - name: Test expert parallelism (EP=2)
         timeout-minutes: 10
         run: |
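For local iteration, the new CI step can be reproduced outside CI; a minimal sketch (assuming the sglang repository root as the working directory and the srt test dependencies installed):

```python
# Minimal sketch: run the new VerlEngine test the same way the CI step does.
import subprocess

subprocess.run(["python3", "test_verl_engine.py"], cwd="test/srt", check=True)
```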
271 changes: 271 additions & 0 deletions examples/runtime/engine/adhoc_verl_torchrun.py
Collaborator: Sorry, I don't quite understand why this is running on verl-vLLM. Where is the SGLang one?

Collaborator (author): (see above)

@@ -0,0 +1,271 @@
# TODO temporarily here, remove it last

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import List

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from sglang.srt.entrypoints.verl_engine import VerlEngine


def main():
assert torch.cuda.is_available(), "CUDA must be present to run FSDP vLLM example"
local_rank, rank, world_size = initialize_global_process_group()

# NOTE MODIFIED path-related logic
# local_cache_path = '~/.cache/verl/rlhf'
# local_cache_path = os.path.expanduser(local_cache_path)
# hdfs_path = "Qwen/Qwen2-7B-Instruct"
hdfs_path = "meta-llama/Llama-3.2-1B-Instruct"
local_model_path = hdfs_path
# from verl.utils.fs import copy_local_path_from_hdfs
# local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
actor_model_config = AutoConfig.from_pretrained(
local_model_path, trust_remote_code=True
)
with torch.device("cuda"):
actor_model = AutoModelForCausalLM.from_pretrained(
local_model_path, trust_remote_code=True
)
actor_model.to(torch.bfloat16)

max_prompt_length = 16
response_length = 32
preencode_prompts = [
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(
preencode_prompts, return_tensors="pt", padding=True, padding_side="left"
) # NOTE MODIFIED add
input_ids = prompts["input_ids"]
attention_mask = prompts["attention_mask"]
# from verl.utils.torch_functional import pad_sequence_to_length
input_ids = pad_sequence_to_length(
input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True
).cuda()
attention_mask = pad_sequence_to_length(
attention_mask, max_prompt_length, 0, left_pad=True
).cuda()

from transformers import GenerationConfig

generation_config = GenerationConfig(do_sample=False)
actor_model.cuda()
output = actor_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=32,
# max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
generation_config=generation_config,
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False,
) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]

print(f"hf response: {tokenizer.batch_decode(response)}")

tensor_model_parallel_size = 4
Collaborator: This runs with TP=4, so it needs four GPUs to run? We only have two GPUs for testing right now. If needed, we can create one for this use case.

Collaborator (author, @fzyzcjy, Feb 26, 2025): Currently I name it "adhoc" and will remove it later; it is modified from guangming's #2736. It is here both for extra testing and because I guess @ocss884 may need it as a reference for the verl side.

There is another example, offline_batch_inference_torchrun.py, for testing, and also a test_verl_engine.py covering things like weight updates and comparison tests.

There are a lot of vLLM-flavored names in the script because guangming's original script is named that way; I tried to change as little as possible and deliberately commented out the original code instead of removing it, to keep the two easy to compare.

Does this need to be a real example? If so, this script certainly needs a big refactor.

Comment: @fzyzcjy Yeah, I got your point. Thanks!
@zhaochenyang20 As far as I know, the TP size is not important here. If we prefer keeping this script for testing, I would recommend just cleaning it up and changing to TP=2. But in fact it is more like a minimal dev example that showcases how the actor_rollout part initializes and updates weights in verl using the SGLang rollout. So I don't think it is really necessary for SGLang to contain such an example; it is more like a verl example.
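For reference, a minimal TP=2 variant of the mesh and engine setup from this script would look roughly like the sketch below (a sketch only: it assumes a 2-GPU `torchrun` launch, and the VerlEngine parameters simply mirror the call in the diff above).

```python
# Hedged sketch: the same device-mesh + VerlEngine init with TP=2, DP=1.
# Assumes this runs under `torchrun --nproc_per_node=2`, so the CPU mesh can
# initialize the default process group from the usual env vars.
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine

tp_size, dp_size = 2, 1
mesh_cpu = init_device_mesh(
    "cpu", mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
)
engine = VerlEngine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # same model as in the script
    dtype="bfloat16",
    mem_fraction_static=0.2,
    device_mesh_cpu=mesh_cpu["tp"],
    base_gpu_id=0,
    gpu_id_step=1,
)
```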

Collaborator: Should we PR this to verl?

Collaborator (author): By the way, test_verl_engine.py largely mimics this adhoc_verl_torchrun.py, doing things like comparison tests and weight updates.

device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
)

mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
fsdp_model = FSDP(
actor_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)

FSDP.set_state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)

state_dict = fsdp_model.state_dict()

# for debug
# if rank == 0:
# lines = ["------------------------ state_dict ------------------------"]
# for k, v in state_dict.items():
# v_local = v.to_local()
# lines.append(
# f"{k}\t: {v.shape=} {v_local.shape=} {v.dtype=} {v_local.dtype=} {type(v)=} {type(v_local)=}"
# )
# print("\n".join(lines))

# NOTE MODIFIED
# sampling_params = SamplingParams(temperature=0,
# top_p=1,
# n=1,
# max_tokens=response_length,
# logprobs=1,
# ignore_eos=True,
# detokenize=False)
Comment on lines +131 to +148
Collaborator: Could we remove this?
Collaborator (author): (see above)

sampling_params = dict(
temperature=0, top_p=1, n=1, max_new_tokens=response_length, ignore_eos=True
)

tp_size, dp_size = 4, 1
Collaborator: In this case, we can also test TP=2, DP=2; this test can run longer.
Collaborator (author): (see above)

kwargs = dict(mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"])
inference_device_mesh_cpu = init_device_mesh("cpu", **kwargs)
tp_rank = inference_device_mesh_cpu.get_local_rank("tp")
print(f"{inference_device_mesh_cpu=}")

for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
if k in os.environ:
del os.environ[k]

print(actor_model_config)
# llm = LLM(model=None,
# tokenizer=tokenizer,
# model_hf_config=actor_model_config,
# tensor_parallel_size=tensor_model_parallel_size,
# enforce_eager=True,
# dtype='bfloat16',
# load_format='dummy_dtensor',
# gpu_memory_utilization=0.1,
# trust_remote_code=True)
Comment on lines +164 to +172
Collaborator: Delete this?
Collaborator (author): (see above)

changed_model_path = local_model_path.replace("-Instruct", "")
assert changed_model_path != local_model_path
print(f"{changed_model_path=}")
llm = VerlEngine(
model_path=changed_model_path, # use model of same type but different weight to test update_weights
dtype="bfloat16",
mem_fraction_static=0.2,
device_mesh_cpu=inference_device_mesh_cpu["tp"],
base_gpu_id=0,
gpu_id_step=1,
)

t = time.time()
if 0:
Contributor: What's the meaning of this line?
Collaborator (author, @fzyzcjy, Feb 26, 2025): (See above; this is a to-be-deleted adhoc script.)

# most naive way
state_dict_full = {k: v.full_tensor() for k, v in state_dict.items()}
print(f"gather full tensor: {time.time() - t:.2f}")
llm.update_weights_from_tensor([(k, v) for k, v in state_dict_full.items()])
else:
llm.update_weights_from_tensor([(k, v) for k, v in state_dict.items()])
print(f"[{tp_rank=}] gather + update weights: {time.time() - t:.2f}")

input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
idx_list = []
batch_size = input_ids.shape[0]

pad_token_id = (
tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else tokenizer.eos_token_id
)
# from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs
for i in range(batch_size):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
print("start generation")
# outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)
outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)

# vllm_output = outputs[0].cuda()
if torch.distributed.get_rank() == 0:
print(f"hf response: {tokenizer.batch_decode(response)}")
# print(f'vllm response: {tokenizer.batch_decode(vllm_output)}')
print(f'vllm response: {[o["text"] for o in outputs]}')

llm.shutdown()


# COPIED FROM verl
def initialize_global_process_group(timeout_second=36000):
from datetime import timedelta

import torch.distributed

# NOTE MODIFIED should provide backend=None to have nccl+gloo
# torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))

local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

if torch.distributed.is_initialized():
torch.cuda.set_device(local_rank)
return local_rank, rank, world_size


# COPIED FROM verl
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
"""
pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
input shape: [bs, seq_length]
output shape: [bs, max_seq_length]
(0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
"""
if tensors.shape[-1] >= max_seq_len:
return tensors
pad_tuple = (
(max_seq_len - tensors.shape[-1], 0)
if left_pad
else (0, max_seq_len - tensors.shape[-1])
)
return F.pad(tensors, pad_tuple, "constant", pad_token_id)


# COPIED FROM verl
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][
0
]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids


if __name__ == "__main__":
main()
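For reference, the two weight-update paths that the `if 0:` block above toggles between look like this in isolation. This is a sketch only: `state_dict` is assumed to hold the FSDP sharded tensors and `llm` the VerlEngine created earlier, and `gather_full_tensors` is a hypothetical flag standing in for the literal `if 0:`.

```python
# Hedged sketch of the two update paths in adhoc_verl_torchrun.py.
# `state_dict` holds FSDP sharded (DTensor) weights; `llm` is a VerlEngine;
# `gather_full_tensors` is a hypothetical flag replacing the literal `if 0:`.
import time


def update_engine_weights(llm, state_dict, gather_full_tensors: bool = False):
    t = time.time()
    if gather_full_tensors:
        # "Most naive way": materialize every full tensor first, then update.
        full_state_dict = {k: v.full_tensor() for k, v in state_dict.items()}
        llm.update_weights_from_tensor(list(full_state_dict.items()))
    else:
        # Pass the sharded tensors directly (the path the script actually runs).
        llm.update_weights_from_tensor(list(state_dict.items()))
    print(f"update weights took {time.time() - t:.2f}s")
```

The script itself is launched with torchrun, with `--nproc_per_node` matching `tp_size * dp_size` (4 here), the same way as the offline batch-inference example below.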
81 changes: 81 additions & 0 deletions examples/runtime/engine/offline_batch_inference_torchrun.py
@@ -0,0 +1,81 @@
import datetime
import os
import sys

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine


def run():
    """
    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```
    """

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    def _log(text):
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )

    tp_size = 4
    dp_size = 2
    assert world_size == tp_size * dp_size

    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")

    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9  # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8

    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]

    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP Attention
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    _log(f"{fragment=}")

    prompt_all = [
        ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
        ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
    ]
    prompt = prompt_all[dp_rank]

    output = fragment.generate(
        prompt=prompt,
        sampling_params=dict(max_new_tokens=16, temperature=0.0),
    )
    _log(f"{prompt=} {output=}")

    fragment.shutdown()
    _log(f"End script")


if __name__ == "__main__":
    run()
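For a smaller machine, the same example runs with a reduced mesh; a sketch of roughly the only constants that would change (assuming a 2-GPU host and a `torchrun --nproc_per_node=2` launch):

```python
# Hedged sketch: 2-GPU variant of the mesh constants in this example.
tp_size = 2  # tensor-parallel degree of the single VerlEngine replica
dp_size = 1  # number of independent data-parallel engine replicas
# world_size must equal tp_size * dp_size, i.e. launch with --nproc_per_node=2;
# with dp_size=1 only prompt_all[0] is used, everything else stays the same.
```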