Skip to content

Commit 274d7fc

Browse files
committed
Prevent helper functions from being converted to tail calls.
1 parent ec525f3 commit 274d7fc

File tree

2 files changed

+145
-10
lines changed

2 files changed

+145
-10
lines changed

src/ir_generator.ml

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ type ir_context = {
4848
mutable current_function: string option;
4949
(* Symbol table reference *)
5050
symbol_table: Symbol_table.symbol_table;
51+
(* Helper function names to avoid tail call conversion *)
52+
helper_functions: (string, unit) Hashtbl.t;
5153
(* Assignment optimization info *)
5254
mutable assignment_optimizations: Map_assignment.optimization_info option;
5355
(* Constant environment for loop analysis *)
@@ -70,7 +72,7 @@ type ir_context = {
7072
}
7173

7274
(** Create new IR generation context *)
73-
let create_context ?(global_variables = []) symbol_table = {
75+
let create_context ?(global_variables = []) ?(helper_functions = []) symbol_table = {
7476
variables = Hashtbl.create 32;
7577
next_register = 0;
7678
current_block = [];
@@ -93,6 +95,9 @@ let create_context ?(global_variables = []) symbol_table = {
9395
tbl);
9496
map_origin_variables = Hashtbl.create 32;
9597
variable_types = Hashtbl.create 32;
98+
helper_functions = (let tbl = Hashtbl.create 16 in
99+
List.iter (fun helper_name -> Hashtbl.add tbl helper_name ()) helper_functions;
100+
tbl);
96101
}
97102

98103
(** Allocate a new register for intermediate values *)
@@ -1602,9 +1607,14 @@ and lower_statement ctx stmt =
16021607
(* Check if this is a simple function call that could be a tail call *)
16031608
(match callee_expr.expr_desc with
16041609
| Ast.Identifier name ->
1605-
(* This will be converted to tail call by tail call analyzer *)
1606-
let arg_vals = List.map (lower_expression ctx) args in
1607-
IRReturnCall (name, arg_vals)
1610+
(* Check if this is a helper function - if so, treat as regular call *)
1611+
if Hashtbl.mem ctx.helper_functions name then
1612+
let ret_val = lower_expression ctx expr in
1613+
IRReturnValue ret_val
1614+
else
1615+
(* This will be converted to tail call by tail call analyzer *)
1616+
let arg_vals = List.map (lower_expression ctx) args in
1617+
IRReturnCall (name, arg_vals)
16081618
| _ ->
16091619
(* Function pointer call - treat as regular return *)
16101620
let ret_val = lower_expression ctx expr in
@@ -1627,8 +1637,13 @@ and lower_statement ctx stmt =
16271637
(* Check if this is a simple function call that could be a tail call *)
16281638
(match callee_expr.expr_desc with
16291639
| Ast.Identifier name ->
1630-
let arg_vals = List.map (lower_expression ctx) args in
1631-
IRReturnCall (name, arg_vals)
1640+
(* Check if this is a helper function - if so, treat as regular call *)
1641+
if Hashtbl.mem ctx.helper_functions name then
1642+
let ret_val = lower_expression ctx return_expr in
1643+
IRReturnValue ret_val
1644+
else
1645+
let arg_vals = List.map (lower_expression ctx) args in
1646+
IRReturnCall (name, arg_vals)
16321647
| _ ->
16331648
(* Function pointer call - treat as regular return *)
16341649
let ret_val = lower_expression ctx return_expr in
@@ -1645,8 +1660,13 @@ and lower_statement ctx stmt =
16451660
| Ast.Call (callee_expr, args) ->
16461661
(match callee_expr.expr_desc with
16471662
| Ast.Identifier name ->
1648-
let arg_vals = List.map (lower_expression ctx) args in
1649-
IRReturnCall (name, arg_vals)
1663+
(* Check if this is a helper function - if so, treat as regular call *)
1664+
if Hashtbl.mem ctx.helper_functions name then
1665+
let ret_val = lower_expression ctx expr in
1666+
IRReturnValue ret_val
1667+
else
1668+
let arg_vals = List.map (lower_expression ctx) args in
1669+
IRReturnCall (name, arg_vals)
16501670
| _ ->
16511671
let ret_val = lower_expression ctx expr in
16521672
IRReturnValue ret_val)
@@ -2953,8 +2973,11 @@ let lower_multi_program ast symbol_table source_name =
29532973
(* Combine regular kernel functions with helper functions *)
29542974
let all_kernel_shared_functions = kernel_shared_functions @ helper_functions in
29552975

2976+
(* Extract helper function names for context *)
2977+
let helper_function_names = List.map (fun func -> func.Ast.func_name) helper_functions in
2978+
29562979
(* Lower kernel functions once - they are shared across all programs *)
2957-
let kernel_ctx = create_context ~global_variables:ir_global_variables symbol_table in
2980+
let kernel_ctx = create_context ~global_variables:ir_global_variables ~helper_functions:helper_function_names symbol_table in
29582981
(* Copy maps from main context to kernel context *)
29592982
Hashtbl.iter (fun map_name map_def ->
29602983
Hashtbl.add kernel_ctx.maps map_name map_def
@@ -2964,7 +2987,7 @@ let lower_multi_program ast symbol_table source_name =
29642987
(* Lower each program *)
29652988
let ir_programs = List.map (fun prog_def ->
29662989
(* Create a fresh context for each program *)
2967-
let prog_ctx = create_context ~global_variables:ir_global_variables symbol_table in
2990+
let prog_ctx = create_context ~global_variables:ir_global_variables ~helper_functions:helper_function_names symbol_table in
29682991
(* Copy maps from main context to program context *)
29692992
Hashtbl.iter (fun map_name map_def ->
29702993
Hashtbl.add prog_ctx.maps map_name map_def

tests/test_tail_call.ml

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,116 @@ let test_tail_calls_in_if_statements _ =
380380
(* Verify index mapping contains the target *)
381381
check bool "drop_handler should be in mapping" true (Hashtbl.mem analysis.index_mapping "drop_handler")
382382

383+
(** Test helper functions are NOT converted to tail calls - regression test for helper tail call bug *)
384+
let test_helper_functions_not_tail_called _ =
385+
(* Create helper functions with @helper attribute *)
386+
let rate_limit_syn_helper = make_test_func "rate_limit_syn" [("ip", U32)] (Some (make_unnamed_return Xdp_action)) [
387+
make_stmt (Return (Some (make_expr (Identifier "XDP_PASS") make_test_position))) make_test_position
388+
] in
389+
390+
let rate_limit_dns_helper = make_test_func "rate_limit_dns" [("ip", U32)] (Some (make_unnamed_return Xdp_action)) [
391+
make_stmt (Return (Some (make_expr (Identifier "XDP_PASS") make_test_position))) make_test_position
392+
] in
393+
394+
(* Create eBPF function that calls helpers in return position within match expression *)
395+
let protocol_var = make_expr (Identifier "protocol") make_test_position in
396+
let src_ip_var = make_expr (Identifier "src_ip") make_test_position in
397+
398+
(* Helper calls in return position - this was the problematic pattern *)
399+
let syn_helper_call = make_expr (Call (make_expr (Identifier "rate_limit_syn") make_test_position, [src_ip_var])) make_test_position in
400+
let dns_helper_call = make_expr (Call (make_expr (Identifier "rate_limit_dns") make_test_position, [src_ip_var])) make_test_position in
401+
let xdp_drop_const = make_expr (Identifier "XDP_DROP") make_test_position in
402+
403+
let match_arms = [
404+
{ arm_pattern = ConstantPattern (IntLit (6, None)); arm_body = SingleExpr syn_helper_call; arm_pos = make_test_position }; (* TCP *)
405+
{ arm_pattern = ConstantPattern (IntLit (17, None)); arm_body = SingleExpr dns_helper_call; arm_pos = make_test_position }; (* UDP *)
406+
{ arm_pattern = DefaultPattern; arm_body = SingleExpr xdp_drop_const; arm_pos = make_test_position };
407+
] in
408+
409+
let match_expr = make_expr (Match (protocol_var, match_arms)) make_test_position in
410+
411+
let ddos_protection = make_test_func "ddos_protection" [("ctx", Xdp_md)] (Some (make_unnamed_return Xdp_action)) [
412+
make_stmt (Declaration ("protocol", Some U32, Some (make_expr (Literal (IntLit (6, None))) make_test_position))) make_test_position;
413+
make_stmt (Declaration ("src_ip", Some U32, Some (make_expr (Literal (IntLit (0xc0a80101, None))) make_test_position))) make_test_position;
414+
make_stmt (Return (Some match_expr)) make_test_position
415+
] in
416+
417+
(* Mark helpers with @helper attribute - this is critical *)
418+
let attr_syn_helper = make_test_attr_func [SimpleAttribute "helper"] rate_limit_syn_helper in
419+
let attr_dns_helper = make_test_attr_func [SimpleAttribute "helper"] rate_limit_dns_helper in
420+
let attr_ddos_protection = make_test_attr_func [SimpleAttribute "xdp"] ddos_protection in
421+
422+
let ast = [AttributedFunction attr_syn_helper; AttributedFunction attr_dns_helper; AttributedFunction attr_ddos_protection] in
423+
let analysis = analyze_tail_calls ast in
424+
425+
(* Critical assertions: helper functions should NOT create tail call dependencies *)
426+
check int "helper functions should not create tail call dependencies" 0 (List.length analysis.dependencies);
427+
428+
(* No prog_array should be needed since only helpers are called *)
429+
check int "prog_array_size should be 0 when only helpers called" 0 analysis.prog_array_size;
430+
431+
(* Verify that helper function names are not in the index mapping *)
432+
check bool "rate_limit_syn should NOT be in tail call mapping" false (Hashtbl.mem analysis.index_mapping "rate_limit_syn");
433+
check bool "rate_limit_dns should NOT be in tail call mapping" false (Hashtbl.mem analysis.index_mapping "rate_limit_dns");
434+
() (* Close the function *)
435+
436+
(** Test mixed scenario: helpers and actual eBPF programs - regression test for proper differentiation *)
437+
let test_mixed_helpers_and_tail_calls _ =
438+
(* Helper function *)
439+
let log_packet_helper = make_test_func "log_packet" [("ctx", Xdp_md)] (Some (make_unnamed_return Xdp_action)) [
440+
make_stmt (Return (Some (make_expr (Identifier "XDP_PASS") make_test_position))) make_test_position
441+
] in
442+
443+
(* Actual eBPF program that can be tail called *)
444+
let process_tcp_program = make_test_func "process_tcp" [("ctx", Xdp_md)] (Some (make_unnamed_return Xdp_action)) [
445+
make_stmt (Return (Some (make_expr (Identifier "XDP_PASS") make_test_position))) make_test_position
446+
] in
447+
448+
(* Main eBPF function that calls both helper and eBPF program *)
449+
let protocol_var = make_expr (Identifier "protocol") make_test_position in
450+
let ctx_var = make_expr (Identifier "ctx") make_test_position in
451+
452+
let helper_call = make_expr (Call (make_expr (Identifier "log_packet") make_test_position, [ctx_var])) make_test_position in
453+
let program_call = make_expr (Call (make_expr (Identifier "process_tcp") make_test_position, [ctx_var])) make_test_position in
454+
let xdp_drop_const = make_expr (Identifier "XDP_DROP") make_test_position in
455+
456+
let match_arms = [
457+
{ arm_pattern = ConstantPattern (IntLit (1, None)); arm_body = SingleExpr helper_call; arm_pos = make_test_position }; (* Call helper *)
458+
{ arm_pattern = ConstantPattern (IntLit (6, None)); arm_body = SingleExpr program_call; arm_pos = make_test_position }; (* Call eBPF program *)
459+
{ arm_pattern = DefaultPattern; arm_body = SingleExpr xdp_drop_const; arm_pos = make_test_position };
460+
] in
461+
462+
let match_expr = make_expr (Match (protocol_var, match_arms)) make_test_position in
463+
464+
let packet_classifier = make_test_func "packet_classifier" [("ctx", Xdp_md)] (Some (make_unnamed_return Xdp_action)) [
465+
make_stmt (Declaration ("protocol", Some U32, Some (make_expr (Literal (IntLit (6, None))) make_test_position))) make_test_position;
466+
make_stmt (Return (Some match_expr)) make_test_position
467+
] in
468+
469+
(* Mark helper with @helper, others with @xdp *)
470+
let attr_helper = make_test_attr_func [SimpleAttribute "helper"] log_packet_helper in
471+
let attr_program = make_test_attr_func [SimpleAttribute "xdp"] process_tcp_program in
472+
let attr_classifier = make_test_attr_func [SimpleAttribute "xdp"] packet_classifier in
473+
474+
let ast = [AttributedFunction attr_helper; AttributedFunction attr_program; AttributedFunction attr_classifier] in
475+
let analysis = analyze_tail_calls ast in
476+
477+
(* Should have exactly 1 dependency: packet_classifier -> process_tcp (NOT to the helper) *)
478+
check int "should have 1 tail call dependency (only to eBPF program)" 1 (List.length analysis.dependencies);
479+
480+
(* prog_array should have size 1 (only for the eBPF program) *)
481+
check int "prog_array_size should be 1 (only eBPF program)" 1 analysis.prog_array_size;
482+
483+
(* Verify specific dependency *)
484+
let dep = List.hd analysis.dependencies in
485+
check string "caller should be packet_classifier" "packet_classifier" dep.caller;
486+
check string "target should be process_tcp (NOT helper)" "process_tcp" dep.target;
487+
488+
(* Verify mapping contents *)
489+
check bool "process_tcp should be in tail call mapping" true (Hashtbl.mem analysis.index_mapping "process_tcp");
490+
check bool "log_packet helper should NOT be in tail call mapping" false (Hashtbl.mem analysis.index_mapping "log_packet");
491+
() (* Close the function *)
492+
383493
let suite = [
384494
"test_tail_call_detection", `Quick, test_tail_call_detection;
385495
"test_program_type_compatibility", `Quick, test_program_type_compatibility;
@@ -392,6 +502,8 @@ let suite = [
392502
"nested_match_tail_calls", `Quick, test_nested_match_tail_calls;
393503
"match_with_mixed_tail_calls", `Quick, test_match_with_mixed_tail_calls;
394504
"test_tail_calls_in_if_statements", `Quick, test_tail_calls_in_if_statements;
505+
"test_helper_functions_not_tail_called", `Quick, test_helper_functions_not_tail_called;
506+
"test_mixed_helpers_and_tail_calls", `Quick, test_mixed_helpers_and_tail_calls;
395507
]
396508

397509
let () = Alcotest.run "Tail Call Tests" [("main", suite)]

0 commit comments

Comments
 (0)