Skip to content

Commit 0bb0472

Browse files
committed
Introduce New and Delete constructs for object management, updating the parser, type checker, and IR generation accordingly. Enhance code generation for malloc and free operations in userspace.
1 parent c9d22b3 commit 0bb0472

16 files changed

+613
-296
lines changed

examples/object_allocation.ks

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Simple XDP packet inspector with object allocation
2+
// Demonstrates new/delete for connection tracking
3+
4+
// XDP context struct (from BTF)
5+
struct xdp_md {
6+
data: u64,
7+
data_end: u64,
8+
data_meta: u64,
9+
ingress_ifindex: u32,
10+
rx_queue_index: u32,
11+
egress_ifindex: u32,
12+
}
13+
14+
// XDP action enum (from BTF)
15+
enum xdp_action {
16+
XDP_ABORTED = 0,
17+
XDP_DROP = 1,
18+
XDP_PASS = 2,
19+
XDP_REDIRECT = 3,
20+
XDP_TX = 4,
21+
}
22+
23+
struct ConnStats {
24+
packet_count: u64,
25+
byte_count: u64,
26+
first_seen: u64,
27+
last_seen: u64,
28+
}
29+
30+
// Map to store connection statistics
31+
map<u32, *ConnStats> conn_tracker : HashMap(1024)
32+
33+
@xdp fn packet_inspector(ctx: *xdp_md) -> xdp_action {
34+
// Simple source IP extraction (in real code, would parse ethernet/IP headers)
35+
var src_ip: u32 = 0x08080808 // Simulated source IP
36+
var packet_size: u32 = 64 // Simulated packet size
37+
38+
// Look up existing connection stats
39+
var stats = conn_tracker[src_ip]
40+
41+
if (stats == none) {
42+
// First packet from this IP - allocate new stats object
43+
stats = new ConnStats()
44+
if (stats == null) {
45+
return XDP_DROP // Allocation failed
46+
}
47+
48+
// Initialize new connection stats
49+
stats->packet_count = 1
50+
stats->byte_count = packet_size
51+
stats->first_seen = 12345 // Fake timestamp
52+
stats->last_seen = 12345
53+
54+
// Store in map
55+
conn_tracker[src_ip] = stats
56+
} else {
57+
// Update existing stats
58+
stats->packet_count = stats->packet_count + 1
59+
stats->byte_count = stats->byte_count + packet_size
60+
stats->last_seen = 12346 // Updated timestamp
61+
}
62+
63+
// Simple rate limiting: drop if too many packets
64+
if (stats->packet_count > 100) {
65+
return XDP_DROP
66+
}
67+
68+
return XDP_PASS
69+
}
70+
71+
fn main() -> i32 {
72+
// Test userspace allocation
73+
var test_stats = new ConnStats()
74+
if (test_stats == null) {
75+
return 1
76+
}
77+
78+
test_stats->packet_count = 42
79+
test_stats->byte_count = 2048
80+
81+
// Clean up
82+
delete test_stats
83+
84+
// Load and attach the XDP program
85+
var prog = load(packet_inspector)
86+
attach(prog, "eth0", 0)
87+
88+
return 0
89+
}

src/ast.ml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ and expr_desc =
170170
| UnaryOp of unary_op * expr
171171
| StructLiteral of string * (string * expr) list
172172
| Match of expr * match_arm list (* match (expr) { arms } *)
173+
| New of bpf_type (* new Type() - object allocation *)
173174

174175
(** Module function call *)
175176
and module_call = {
@@ -218,13 +219,18 @@ and stmt_desc =
218219
| For of string * expr * expr * statement list
219220
| ForIter of string * string * expr * statement list (* for (index, value) in expr.iter() { ... } *)
220221
| While of expr * statement list
221-
| Delete of expr * expr (* delete map[key] *)
222+
| Delete of delete_target (* Unified delete: map[key] or pointer *)
222223
| Break
223224
| Continue
224225
| Try of statement list * catch_clause list (* try { statements } catch clauses *)
225226
| Throw of expr (* throw integer_expression *)
226227
| Defer of expr (* defer function_call *)
227228

229+
(** Delete target - either map entry or object pointer *)
230+
and delete_target =
231+
| DeleteMapEntry of expr * expr (* delete map[key] *)
232+
| DeletePointer of expr (* delete ptr *)
233+
228234
(** Catch clause definition *)
229235
and catch_clause = {
230236
catch_pattern: catch_pattern;
@@ -677,6 +683,7 @@ let rec string_of_expr expr =
677683
| Match (expr, arms) ->
678684
let arms_str = String.concat ",\n " (List.map string_of_match_arm arms) in
679685
Printf.sprintf "match (%s) {\n %s\n}" (string_of_expr expr) arms_str
686+
| New typ -> Printf.sprintf "new %s()" (string_of_bpf_type typ)
680687

681688
and string_of_match_pattern = function
682689
| ConstantPattern lit -> string_of_literal lit
@@ -746,8 +753,10 @@ and string_of_stmt stmt =
746753
| While (cond, body) ->
747754
let body_str = String.concat " " (List.map string_of_stmt body) in
748755
Printf.sprintf "while (%s) { %s }" (string_of_expr cond) body_str
749-
| Delete (map_expr, key_expr) ->
756+
| Delete (DeleteMapEntry (map_expr, key_expr)) ->
750757
Printf.sprintf "delete %s[%s];" (string_of_expr map_expr) (string_of_expr key_expr)
758+
| Delete (DeletePointer ptr_expr) ->
759+
Printf.sprintf "delete %s;" (string_of_expr ptr_expr)
751760
| Break -> "break;"
752761
| Continue -> "continue;"
753762
| Try (statements, catch_clauses) ->

src/ebpf_c_codegen.ml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,10 @@ let rec collect_string_sizes_from_instr ir_instr =
475475
acc @ (collect_string_sizes_from_value arg)) [] args
476476
| IRStructOpsRegister (instance_val, struct_ops_val) ->
477477
(collect_string_sizes_from_value instance_val) @ (collect_string_sizes_from_value struct_ops_val)
478+
| IRObjectNew (dest_val, _) ->
479+
collect_string_sizes_from_value dest_val
480+
| IRObjectDelete ptr_val ->
481+
collect_string_sizes_from_value ptr_val
478482

479483
let collect_string_sizes_from_function ir_func =
480484
List.fold_left (fun acc block ->
@@ -1125,6 +1129,35 @@ let generate_includes ctx ?(program_types=[]) ?(include_builtin_headers=false) (
11251129
(* For non-kprobe programs, use standard processing *)
11261130
let all_includes = builtin_includes @ standard_includes @ unique_context_includes @ base_type_includes in
11271131
List.iter (emit_line ctx) all_includes;
1132+
emit_blank_line ctx;
1133+
1134+
(* Use proper kernel implementation: extern declarations and macros *)
1135+
emit_line ctx "extern void *bpf_obj_new_impl(__u64 local_type_id__k, void *meta__ign) __ksym;";
1136+
emit_line ctx "extern void bpf_obj_drop_impl(void *p__alloc, void *meta__ign) __ksym;";
1137+
emit_blank_line ctx;
1138+
1139+
(* Use exact kernel implementation for proper typeof handling *)
1140+
emit_line ctx "#define ___concat(a, b) a ## b";
1141+
emit_line ctx "#ifdef __clang__";
1142+
emit_line ctx "#define ___bpf_typeof(type) ((typeof(type) *) 0)";
1143+
emit_line ctx "#else";
1144+
emit_line ctx "#define ___bpf_typeof1(type, NR) ({ \\";
1145+
emit_line ctx " extern typeof(type) *___concat(bpf_type_tmp_, NR); \\";
1146+
emit_line ctx " ___concat(bpf_type_tmp_, NR); \\";
1147+
emit_line ctx "})";
1148+
emit_line ctx "#define ___bpf_typeof(type) ___bpf_typeof1(type, __COUNTER__)";
1149+
emit_line ctx "#endif";
1150+
emit_blank_line ctx;
1151+
1152+
(* Add BPF_TYPE_ID_LOCAL constant *)
1153+
emit_line ctx "#ifndef BPF_TYPE_ID_LOCAL";
1154+
emit_line ctx "#define BPF_TYPE_ID_LOCAL 1";
1155+
emit_line ctx "#endif";
1156+
emit_blank_line ctx;
1157+
1158+
emit_line ctx "#define bpf_core_type_id_kernel(type) __builtin_btf_type_id(*(type*)0, 0)";
1159+
emit_line ctx "#define bpf_obj_new(type) ((type *)bpf_obj_new_impl(bpf_core_type_id_kernel(type), NULL))";
1160+
emit_line ctx "#define bpf_obj_drop(ptr) bpf_obj_drop_impl(ptr, NULL)";
11281161
emit_blank_line ctx
11291162
)
11301163

@@ -2611,6 +2644,10 @@ let rec generate_c_instruction ctx ir_instr =
26112644
Option.iter collect_in_value result_val_opt
26122645
| IRTailCall (_, args, _) ->
26132646
List.iter collect_in_value args
2647+
| IRObjectNew (dest_val, _) ->
2648+
collect_in_value dest_val
2649+
| IRObjectDelete ptr_val ->
2650+
collect_in_value ptr_val
26142651
| IRJump _ | IRComment _ | IRBreak | IRContinue | IRThrow _ -> ()
26152652
in
26162653
collect_in_instr ir_instr
@@ -2813,6 +2850,18 @@ let rec generate_c_instruction ctx ir_instr =
28132850
(* For eBPF, struct_ops registration is handled by userspace loader *)
28142851
emit_line ctx (sprintf "/* struct_ops_register - handled by userspace */")
28152852

2853+
| IRObjectNew (dest_val, obj_type) ->
2854+
let dest_str = generate_c_value ctx dest_val in
2855+
let type_str = ebpf_type_from_ir_type obj_type in
2856+
(* Use proper kernel pattern: ptr = bpf_obj_new(type) *)
2857+
emit_line ctx (sprintf "%s = bpf_obj_new(%s);" dest_str type_str)
2858+
2859+
| IRObjectDelete ptr_val ->
2860+
let ptr_str = generate_c_value ctx ptr_val in
2861+
(* Use the proper kernel bpf_obj_drop(ptr) macro *)
2862+
emit_line ctx (sprintf "bpf_obj_drop(%s);" ptr_str)
2863+
2864+
28162865
(** Generate C code for basic block *)
28172866

28182867
let generate_c_basic_block ctx ir_block =
@@ -3016,6 +3065,10 @@ let collect_registers_in_function ir_func =
30163065
List.iter collect_in_value args
30173066
| IRStructOpsRegister (instance_val, struct_ops_val) ->
30183067
collect_in_value instance_val; collect_in_value struct_ops_val
3068+
| IRObjectNew (dest_val, _) ->
3069+
collect_in_value dest_val
3070+
| IRObjectDelete ptr_val ->
3071+
collect_in_value ptr_val
30193072
in
30203073
List.iter (fun block ->
30213074
List.iter collect_in_instr block.instructions

src/evaluator.ml

Lines changed: 59 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -716,74 +716,44 @@ and eval_expression ctx expr =
716716
StructValue field_values
717717

718718
| Match (matched_expr, arms) ->
719-
(* Evaluate the matched expression *)
720719
let matched_value = eval_expression ctx matched_expr in
721-
722-
(* Find the matching arm and evaluate its expression *)
723-
let rec find_matching_arm = function
724-
| [] -> eval_error "No matching arm found in match expression" expr.expr_pos
725-
| arm :: rest_arms ->
726-
(match arm.arm_pattern with
727-
| ConstantPattern lit ->
728-
let pattern_value = runtime_value_of_literal lit in
729-
if runtime_values_equal matched_value pattern_value then
730-
(match arm.arm_body with
731-
| SingleExpr expr -> eval_expression ctx expr
732-
| Block stmts ->
733-
let rec eval_stmts = function
734-
| [] -> IntValue 0 (* Default value if no return *)
735-
| stmt :: rest ->
736-
(try
737-
eval_statement ctx stmt;
738-
eval_stmts rest
739-
with Return_value value -> value)
740-
in
741-
eval_stmts stmts)
742-
else
743-
find_matching_arm rest_arms
744-
| IdentifierPattern name ->
745-
(* For identifier patterns, try to get the value from symbol table *)
746-
(match ctx.symbol_table with
747-
| Some symbol_table ->
748-
(match Symbol_table.lookup_symbol symbol_table name with
749-
| Some { kind = Symbol_table.EnumConstant (_, Some value); _ } ->
750-
let pattern_value = EnumValue (name, value) in
751-
if runtime_values_equal matched_value pattern_value then
752-
(match arm.arm_body with
753-
| SingleExpr expr -> eval_expression ctx expr
754-
| Block stmts ->
755-
let rec eval_stmts = function
756-
| [] -> IntValue 0 (* Default value if no return *)
757-
| stmt :: rest ->
758-
(try
759-
eval_statement ctx stmt;
760-
eval_stmts rest
761-
with Return_value value -> value)
762-
in
763-
eval_stmts stmts)
764-
else
765-
find_matching_arm rest_arms
766-
| _ ->
767-
(* Pattern not found or not an enum constant, treat as wildcard *)
768-
find_matching_arm rest_arms)
769-
| None ->
770-
find_matching_arm rest_arms)
771-
| DefaultPattern ->
772-
(* Default pattern always matches *)
773-
(match arm.arm_body with
774-
| SingleExpr expr -> eval_expression ctx expr
775-
| Block stmts ->
776-
let rec eval_stmts = function
777-
| [] -> IntValue 0 (* Default value if no return *)
778-
| stmt :: rest ->
779-
(try
780-
eval_statement ctx stmt;
781-
eval_stmts rest
782-
with Return_value value -> value)
783-
in
784-
eval_stmts stmts))
720+
let rec try_arms = function
721+
| [] -> eval_error "No matching pattern in match expression" expr.expr_pos
722+
| arm :: remaining_arms ->
723+
let pattern_matches = match arm.arm_pattern with
724+
| ConstantPattern lit ->
725+
let literal_value = runtime_value_of_literal lit in
726+
runtime_values_equal matched_value literal_value
727+
| IdentifierPattern name ->
728+
(* Check if this is an enum constant *)
729+
(match ctx.symbol_table with
730+
| Some symbol_table ->
731+
(match Symbol_table.lookup_symbol symbol_table name with
732+
| Some { kind = Symbol_table.EnumConstant (_, Some value); _ } ->
733+
(match matched_value with
734+
| EnumValue (_, matched_val) -> matched_val = value
735+
| IntValue matched_val -> matched_val = value
736+
| _ -> false)
737+
| _ -> false)
738+
| None -> false)
739+
| DefaultPattern -> true
740+
in
741+
742+
if pattern_matches then
743+
match arm.arm_body with
744+
| SingleExpr arm_expr -> eval_expression ctx arm_expr
745+
| Block arm_stmts ->
746+
eval_statements ctx arm_stmts;
747+
UnitValue (* Default return for block *)
748+
else
749+
try_arms remaining_arms
785750
in
786-
find_matching_arm arms
751+
try_arms arms
752+
753+
| New _ ->
754+
(* For evaluator, object allocation returns a mock pointer value *)
755+
(* This is just for testing - real allocation happens in generated code *)
756+
PointerValue (Random.int 1000000)
787757

788758
(** Evaluate statements *)
789759
and eval_statements ctx stmts =
@@ -968,24 +938,29 @@ and eval_statement ctx stmt =
968938
in
969939
loop ()
970940

971-
| Delete (map_expr, key_expr) ->
972-
let map_name = match map_expr.expr_desc with
973-
| Identifier name -> name
974-
| _ -> eval_error ("Delete requires a map identifier") stmt.stmt_pos
975-
in
976-
let key_result = eval_expression ctx key_expr in
977-
978-
(* Get the map storage *)
979-
let map_store =
980-
try Hashtbl.find ctx.map_storage map_name
981-
with Not_found -> eval_error ("Map not found: " ^ map_name) stmt.stmt_pos
982-
in
983-
984-
(* Perform the actual delete operation *)
985-
let key_str = string_of_runtime_value key_result in
986-
let existed = Hashtbl.mem map_store key_str in
987-
if existed then
988-
Hashtbl.remove map_store key_str
941+
| Delete target ->
942+
(match target with
943+
| DeleteMapEntry (map_expr, key_expr) ->
944+
let map_name = match map_expr.expr_desc with
945+
| Identifier name -> name
946+
| _ -> eval_error ("Delete requires a map identifier") stmt.stmt_pos
947+
in
948+
let key_result = eval_expression ctx key_expr in
949+
950+
(* Get the map storage *)
951+
let map_store =
952+
try Hashtbl.find ctx.map_storage map_name
953+
with Not_found -> eval_error ("Map not found: " ^ map_name) stmt.stmt_pos
954+
in
955+
956+
(* Perform the actual delete operation *)
957+
let key_str = string_of_runtime_value key_result in
958+
let existed = Hashtbl.mem map_store key_str in
959+
if existed then
960+
Hashtbl.remove map_store key_str
961+
| DeletePointer _ptr_expr ->
962+
(* For evaluator, pointer deletion is a no-op since we don't have real memory management *)
963+
())
989964

990965
| Break ->
991966
raise Break_loop

0 commit comments

Comments
 (0)