
Commit e1c4a12

SigureMo and Copilot authored
[Graph Optimization][CINN] Use CINN in PaddleOCR-VL ViT part (#5223)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8d99bac commit e1c4a12

File tree: 8 files changed, +120 −10 lines

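For context, the CINN path added here is switched on through the existing --graph-optimization-config flag: graph_opt_level=2 compiles the PaddleOCR-VL ViT encoder with CINN, while level 0 keeps the dynamic graph with fused ops. A minimal launch sketch follows; only the model directory and the config JSON are taken from this commit's e2e test, while the entrypoint module and the remaining flags are illustrative assumptions.

    import json
    import subprocess
    import sys

    # Illustrative only: entrypoint module and flag set are assumptions, not from this commit.
    graph_opt_config = {"graph_opt_level": 2, "use_cudagraph": True}  # 2 => CINN for the ViT part
    cmd = [
        sys.executable,
        "-m", "fastdeploy.entrypoints.openai.api_server",  # assumed server entrypoint
        "--model", "./PaddleOCR-VL-0.9B",
        "--graph-optimization-config", json.dumps(graph_opt_config),
    ]
    server = subprocess.Popen(cmd)  # the e2e test below starts the server in the same general way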

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 0 deletions

@@ -1585,6 +1585,7 @@ def _setting_environ_variables(self):
             "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
             "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
             "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
+            "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
             "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
             "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
             "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions

@@ -464,6 +464,7 @@ def _setting_environ_variables(self):
             "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
             "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
             "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
+            "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
             "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
             "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
             "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(

fastdeploy/model_executor/models/paddleocr_vl/siglip.py

Lines changed: 32 additions & 4 deletions

@@ -281,7 +281,6 @@ class SiglipMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.activation_fn = get_activation_fn(config.hidden_act)
         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
         self.fc1.weight.weight_loader = self.weight_loader
         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -304,7 +303,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
 
     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         hidden_states = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states[0])
+        hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states[0])
         hidden_states = self.fc2(hidden_states)
         return hidden_states
 
@@ -318,7 +317,6 @@ def __init__(self, config):
         self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
 
-    # @paddle.jit.to_static
     def forward(
         self,
         hidden_states,
@@ -527,7 +525,37 @@ def forward(
         else:
             attn_cu_seqlens = cu_seqlens
 
-        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()
+        return self._run_encoder_layer(
+            encoder_states=encoder_states,
+            all_attentions=all_attentions,
+            attn_cu_seqlens=attn_cu_seqlens,
+            output_hidden_states=output_hidden_states,
+            reversed_window_indices=reversed_window_indices if output_hidden_states else None,
+            use_window_attn=use_window_attn,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            cos_emb=cos_emb,
+            sin_emb=sin_emb,
+        )
+
+    # This function will be compiled with CINN when graph_opt_level >= 2
+    # TODO(SigureMo): Use a new decorator to mark the function for CINN compilation
+    def _run_encoder_layer(
+        self,
+        encoder_states: Optional[Tuple[()]],
+        all_attentions: Optional[Tuple[()]],
+        attn_cu_seqlens: Optional[paddle.Tensor],
+        output_hidden_states: Optional[bool],
+        reversed_window_indices: paddle.Tensor,
+        use_window_attn: bool,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor],
+        output_attentions: bool,
+        cos_emb: Optional[paddle.Tensor],
+        sin_emb: Optional[paddle.Tensor],
+    ) -> paddle.Tensor:
+        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().cpu()
 
         for encoder_layer in self.layers:
             if output_hidden_states:
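The key change above is that the per-layer encoder loop now lives in a separate method, _run_encoder_layer, which the GPU model runner later rebinds to a paddle.jit.to_static-wrapped version (see gpu_model_runner.py below). A minimal sketch of that rebind-at-class-level pattern on a toy layer, assuming to_static accepts the same full_graph argument used in this commit:

    import paddle

    # Toy stand-in for SiglipEncoder: forward stays a thin dynamic-graph wrapper,
    # while the compilable body is an ordinary method that can be swapped out.
    class ToyEncoder(paddle.nn.Layer):
        def forward(self, x):
            return self._run_body(x)

        def _run_body(self, x):
            return paddle.nn.functional.relu(x) * 2

    # Rebind the unbound method to a to_static-wrapped version, mirroring apply_compile().
    ToyEncoder._run_body = paddle.jit.to_static(ToyEncoder._run_body, full_graph=False)
    print(ToyEncoder()(paddle.ones([2, 3])))  # expected: a [2, 3] tensor of 2.0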

fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py

Lines changed: 11 additions & 2 deletions

@@ -27,6 +27,8 @@
 
 def rotate_half(x):
     Dh = x.shape[-1]
+    if Dh == -1:
+        Dh = paddle.shape(x)[-1]
     x1 = x[..., : Dh // 2]
     x2 = x[..., Dh // 2 :]
     return paddle.concat([-x2, x1], axis=-1)
@@ -41,6 +43,8 @@ def apply_rotary_pos_emb_vision(x, cos, sin):
 
 def native_neox_rope_embedding(qkv, cos, sin, num_heads):
     B, seq_length, D = qkv.shape
+    if seq_length == -1:
+        _, seq_length, _ = paddle.shape(qkv)
     qkv = qkv.reshape(
         [
             seq_length,
@@ -55,18 +59,23 @@ def native_neox_rope_embedding(qkv, cos, sin, num_heads):
     return q, k, v
 
 
+jit_unified_marker = paddle.jit.marker.unified if hasattr(paddle.jit.marker, "unified") else lambda fn: fn
+
+
+@jit_unified_marker
 def neox_rope_embedding(
     qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
 ) -> List[paddle.Tensor]:
-    if current_platform.is_cuda():
+    if current_platform.is_cuda() and paddle.in_dynamic_mode():
         return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
     else:
         return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
 
 
+@jit_unified_marker
 def get_activation_fn(hidden_act: str):
     if hidden_act == "gelu_pytorch_tanh":
-        if current_platform.is_cuda():
+        if current_platform.is_cuda() and paddle.in_dynamic_mode():
             return gelu_tanh
         else:
             return ACT2FN["gelu_new"]
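These guards exist because, under dynamic-to-static compilation, an unknown dimension shows up as -1 in Tensor.shape, so the ops fall back to paddle.shape(x), which returns the shape as a runtime tensor; the fused CUDA kernels are likewise only taken in dynamic mode. A small sketch of the same guard on a toy function, assuming a to_static wrapper with a dynamic last dimension:

    import paddle
    from paddle.static import InputSpec

    def halve_last_dim(x):
        # Mirrors the guard in rotate_half: -1 marks a dim unknown at compile time.
        Dh = x.shape[-1]
        if Dh == -1:
            Dh = paddle.shape(x)[-1]  # runtime shape tensor
        return x[..., : Dh // 2]

    static_fn = paddle.jit.to_static(
        halve_last_dim, input_spec=[InputSpec(shape=[None, None], dtype="float32")]
    )
    print(static_fn(paddle.rand([4, 8])).shape)  # expected: [4, 4]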

fastdeploy/worker/gpu_model_runner.py

Lines changed: 58 additions & 0 deletions

@@ -2183,6 +2183,30 @@ def capture_model(self) -> None:
         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
 
+    def vision_encoder_compile(self):
+        if self.graph_opt_config.graph_opt_level == 0:
+            return
+        # Currently only PaddleOCR-VL model is supported for vision encoder layer
+        if self.model_config.model_type != "paddleocr_vl":
+            return
+
+        # Compile for paddleocr_vl vision encoder layer
+        def apply_compile(fn):
+            backend = "CINN" if self.graph_opt_config.graph_opt_level >= 2 else None
+            return paddle.jit.to_static(
+                fn,
+                full_graph=False,
+                backend=backend,
+            )
+
+        from fastdeploy.model_executor.models.paddleocr_vl.siglip import SiglipEncoder
+
+        SiglipEncoder._run_encoder_layer = apply_compile(SiglipEncoder._run_encoder_layer)
+
+        # Warmup for paddleocr_vl vision encoder layer
+        logger.info(f"Warmup for {self.model_config.model_type} compile...")
+        self._dummy_run_extract_vision_features()
+
     @sot_warmup_guard(True)
     def sot_warmup(self) -> None:
         start_time = time.perf_counter()
@@ -2891,6 +2915,40 @@ def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
         else:
             raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
 
+    @paddle.no_grad()
+    def _dummy_run_extract_vision_features(self):
+        grid_thw_list = ([(1, 10, 88), (1, 10, 80)], [(1, 14, 62), (1, 20, 42), (1, 14, 60)])
+        for grid_thw in grid_thw_list:
+            images = []
+            position_ids = []
+            cu_seqlens = [0]
+            for idx, thw in enumerate(grid_thw):
+                numel = np.prod(np.array(thw))
+                images.append(paddle.uniform(shape=[numel, 3, 14, 14], dtype="float32", min=0.0, max=1.0))
+                position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
+                cu_seqlens.append(cu_seqlens[-1] + numel)
+
+            images = paddle.concat(images, axis=0)
+            position_ids = paddle.concat(position_ids, axis=0).to(images.place)
+            cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
+
+            with paddle.amp.auto_cast(
+                True,
+                custom_black_list=self.amp_black,
+                custom_white_list=self.amp_white,
+                level="O2",
+                dtype=self.model_config.dtype,
+            ):
+                self.model.visual(
+                    pixel_values=images,
+                    image_grid_thw=grid_thw,
+                    position_ids=position_ids,
+                    interpolate_pos_encoding=True,
+                    cu_seqlens=cu_seqlens,
+                    use_rope=True,
+                    window_size=-1,
+                )
+
     @paddle.no_grad()
     def prepare_rope3d(
         self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
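To make the warmup shapes concrete, here is the arithmetic for the first grid_thw group used by _dummy_run_extract_vision_features (grid values copied from the diff above; the rest is plain NumPy):

    import numpy as np

    grid_thw = [(1, 10, 88), (1, 10, 80)]
    numels = [int(np.prod(thw)) for thw in grid_thw]   # [880, 800] patches per image
    cu_seqlens = np.cumsum([0] + numels).tolist()      # [0, 880, 1680]
    pixel_values_shape = [sum(numels), 3, 14, 14]      # [1680, 3, 14, 14]
    print(numels, cu_seqlens, pixel_values_shape)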

fastdeploy/worker/gpu_worker.py

Lines changed: 2 additions & 0 deletions

@@ -209,6 +209,8 @@ def graph_optimize_and_warm_up_model(self) -> None:
         """
         if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
             self.model_runner.sot_warmup()
+        if self.fd_config.graph_opt_config.graph_opt_level >= 1:
+            self.model_runner.vision_encoder_compile()
         # Trigger cuda graph capture
         self.model_runner.capture_model()
 

fastdeploy/worker/model_runner_base.py

Lines changed: 6 additions & 0 deletions

@@ -93,3 +93,9 @@ def profile_run(self) -> None:
         Execute a forward pass with dummy inputs to profile the memory usage of the model."
         """
         raise NotImplementedError
+
+    def vision_encoder_compile(self):
+        """
+        Compile the vision encoder if applicable
+        """
+        logger.info(f"No vision encoder compilation for base {self.__class__.__name__}")

tests/e2e/test_paddleocr_vl_serving.py

Lines changed: 9 additions & 4 deletions

@@ -33,10 +33,14 @@
 os.environ["FD_USE_MACHETE"] = "0"
 
 
-@pytest.fixture(scope="session", autouse=True)
-def setup_and_run_server():
+@pytest.fixture(scope="session", autouse=True, params=[0, 2])
+def setup_and_run_server(request):
     """
-    Pytest fixture that runs once per test session:
+    Pytest fixture that runs once per test session, parameterized by `graph_opt_level`:
+    - Runs tests with graph_opt_level=0 (dynamic graph with fused ops)
+    - Runs tests with graph_opt_level=2 (CINN compilation)
+
+    This ensures the API server is tested under both graph optimization configurations.
     - Cleans ports before tests
     - Starts the API server as a subprocess
     - Waits for server port to open (up to 30 seconds)
@@ -55,6 +59,7 @@ def setup_and_run_server():
     model_path = "./PaddleOCR-VL-0.9B"
 
     log_path = "server.log"
+    graph_opt_level = request.param
 
     cmd = [
         sys.executable,
@@ -80,7 +85,7 @@ def setup_and_run_server():
         "--gpu-memory-utilization",
         "0.9",
         "--graph-optimization-config",
-        '{"graph_opt_level":0, "use_cudagraph":true}',
+        f'{{"graph_opt_level":{graph_opt_level}, "use_cudagraph":true}}',
     ]
 
     # Start subprocess in new process group
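The fixture change means the whole e2e suite now runs twice, once per graph_opt_level. A toy illustration (not from the repo) of that pytest pattern:

    import pytest

    # A session-scoped, parametrized fixture: every dependent test is collected
    # once per param, e.g. test_runs_under_both_levels[0] and [2].
    @pytest.fixture(scope="session", params=[0, 2])
    def graph_opt_level(request):
        return request.param

    def test_runs_under_both_levels(graph_opt_level):
        assert graph_opt_level in (0, 2)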
