Skip to content

Commit 1593b7b

Browse files
feat:user defined struct casting
1 parent 4905649 commit 1593b7b

File tree

6 files changed

+176
-63
lines changed

6 files changed

+176
-63
lines changed

pythonbpf/allocation_pass.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,18 @@ def _allocate_for_call(
114114
# Struct constructors
115115
elif call_type in structs_sym_tab:
116116
struct_info = structs_sym_tab[call_type]
117-
var = builder.alloca(struct_info.ir_type, name=var_name)
118-
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
119-
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
117+
if len(rval.args) == 0:
118+
# Zero-arg constructor: allocate the struct itself
119+
var = builder.alloca(struct_info.ir_type, name=var_name)
120+
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
121+
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
122+
else:
123+
# Pointer cast: allocate as pointer to struct
124+
ptr_type = ir.PointerType(struct_info.ir_type)
125+
var = builder.alloca(ptr_type, name=var_name)
126+
var.align = 8
127+
local_sym_tab[var_name] = LocalSymbol(var, ptr_type, call_type)
128+
logger.info(f"Pre-allocated {var_name} for struct pointer cast to {call_type}")
120129

121130
elif VmlinuxHandlerRegistry.is_vmlinux_struct(call_type):
122131
# When calling struct_name(pointer), we're doing a cast, not construction

pythonbpf/assign_pass.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,23 @@ def handle_variable_assignment(
174174
f"Type mismatch: vmlinux struct pointer requires i64, got {var_type}"
175175
)
176176
return False
177+
# Handle user-defined struct pointer casts
178+
# val_type is a string (struct name), var_type is a pointer to the struct
179+
if isinstance(val_type, str) and val_type in structs_sym_tab:
180+
struct_info = structs_sym_tab[val_type]
181+
expected_ptr_type = ir.PointerType(struct_info.ir_type)
182+
183+
# Check if var_type matches the expected pointer type
184+
if isinstance(var_type, ir.PointerType):
185+
# val is already the correct pointer type from inttoptr/bitcast
186+
builder.store(val, var_ptr)
187+
logger.info(f"Assigned user-defined struct pointer cast to {var_name}")
188+
return True
189+
else:
190+
logger.error(
191+
f"Type mismatch: user-defined struct pointer cast requires pointer type, got {var_type}"
192+
)
193+
return False
177194
if isinstance(val_type, Field):
178195
logger.info("Handling assignment to struct field")
179196
# Special handling for struct_xdp_md i32 fields that are zero-extended to i64

pythonbpf/expr/expr_pass.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _handle_boolean_op(
618618

619619

620620
# ============================================================================
621-
# VMLinux casting
621+
# Struct casting (including vmlinux struct casting)
622622
# ============================================================================
623623

624624

@@ -667,7 +667,7 @@ def _handle_vmlinux_cast(
667667
# If arg_val is an integer type, we need to inttoptr it
668668
ptr_type = ir.PointerType()
669669
# TODO: add a field value type check here
670-
print(arg_type)
670+
# print(arg_type)
671671
if isinstance(arg_type, Field):
672672
if ctypes_to_ir(arg_type.type.__name__):
673673
# Cast integer to pointer
@@ -681,6 +681,69 @@ def _handle_vmlinux_cast(
681681
return casted_ptr, vmlinux_struct_type
682682

683683

684+
def _handle_user_defined_struct_cast(
685+
func,
686+
module,
687+
builder,
688+
expr,
689+
local_sym_tab,
690+
map_sym_tab,
691+
structs_sym_tab,
692+
):
693+
"""Handle user-defined struct cast expressions like iphdr(nh).
694+
695+
This casts a pointer/integer value to a pointer to the user-defined struct,
696+
similar to how vmlinux struct casts work but for user-defined @struct types.
697+
"""
698+
if len(expr.args) != 1:
699+
logger.info("User-defined struct cast takes exactly one argument")
700+
return None
701+
702+
# Get the struct name
703+
struct_name = expr.func.id
704+
705+
if struct_name not in structs_sym_tab:
706+
logger.error(f"Struct {struct_name} not found in structs_sym_tab")
707+
return None
708+
709+
struct_info = structs_sym_tab[struct_name]
710+
711+
# Evaluate the argument (e.g.,
712+
# an address/pointer value)
713+
arg_result = eval_expr(
714+
func,
715+
module,
716+
builder,
717+
expr.args[0],
718+
local_sym_tab,
719+
map_sym_tab,
720+
structs_sym_tab,
721+
)
722+
723+
if arg_result is None:
724+
logger.info("Failed to evaluate argument to user-defined struct cast")
725+
return None
726+
727+
arg_val, arg_type = arg_result
728+
729+
# Cast the integer/pointer value to a pointer to the struct type
730+
# The struct pointer type is a pointer to the struct's IR type
731+
struct_ptr_type = ir.PointerType(struct_info.ir_type)
732+
733+
# If arg_val is an integer type (like i64), convert to pointer using inttoptr
734+
if isinstance(arg_val.type, ir.IntType):
735+
casted_ptr = builder.inttoptr(arg_val, struct_ptr_type)
736+
logger.info(f"Cast integer to pointer for struct {struct_name}")
737+
elif isinstance(arg_val.type, ir.PointerType):
738+
# If already a pointer, bitcast to the struct pointer type
739+
casted_ptr = builder.bitcast(arg_val, struct_ptr_type)
740+
logger.info(f"Bitcast pointer to struct pointer for {struct_name}")
741+
else:
742+
logger.error(f"Unsupported type for user-defined struct cast: {arg_val.type}")
743+
return None
744+
745+
return casted_ptr, struct_name
746+
684747
# ============================================================================
685748
# Expression Dispatcher
686749
# ============================================================================
@@ -726,6 +789,16 @@ def eval_expr(
726789
map_sym_tab,
727790
structs_sym_tab,
728791
)
792+
if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab):
793+
return _handle_user_defined_struct_cast(
794+
func,
795+
module,
796+
builder,
797+
expr,
798+
local_sym_tab,
799+
map_sym_tab,
800+
structs_sym_tab,
801+
)
729802

730803
result = CallHandlerRegistry.handle_call(
731804
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab

pythonbpf/type_deducer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"c_int": ir.IntType(32),
2121
"c_ushort": ir.IntType(16),
2222
"c_short": ir.IntType(16),
23+
"c_ubyte": ir.IntType(8),
24+
"c_byte": ir.IntType(8),
2325
# Not so sure about this one
2426
"str": ir.PointerType(ir.IntType(8)),
2527
}

tests/c-form/xdp_test.bpf.c

Lines changed: 24 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,38 @@
1-
// xdp_ip_map.c
21
#include <linux/bpf.h>
3-
#include <bpf/bpf_helpers.h>
4-
#include <bpf/bpf_endian.h>
52
#include <linux/if_ether.h>
63
#include <linux/ip.h>
4+
#include <bpf/bpf_helpers.h>
75

8-
struct ip_key {
9-
__u8 family; // 4 = IPv4
10-
__u8 pad[3]; // padding for alignment
11-
__u8 addr[16]; // IPv4 uses first 4 bytes
6+
struct fake_iphdr {
7+
unsigned short useless;
8+
unsigned short tot_len;
9+
unsigned short id;
10+
unsigned short frag_off;
11+
unsigned char ttl;
12+
unsigned char protocol;
13+
unsigned short check;
14+
unsigned int saddr;
15+
unsigned int daddr;
1216
};
1317

14-
// key → packet count
15-
struct {
16-
__uint(type, BPF_MAP_TYPE_HASH);
17-
__uint(max_entries, 16384);
18-
__type(key, struct ip_key);
19-
__type(value, __u64);
20-
} ip_count_map SEC(".maps");
21-
2218
SEC("xdp")
23-
int xdp_ip_map(struct xdp_md *ctx)
19+
int xdp_prog(struct xdp_md *ctx)
2420
{
25-
void *data_end = (void *)(long)ctx->data_end;
26-
void *data = (void *)(long)ctx->data;
27-
struct ethhdr *eth = data;
28-
29-
if (eth + 1 > (struct ethhdr *)data_end)
30-
return XDP_PASS;
31-
32-
__u16 h_proto = eth->h_proto;
33-
void *nh = data + sizeof(*eth);
34-
35-
// VLAN handling: single tag
36-
if (h_proto == bpf_htons(ETH_P_8021Q) ||
37-
h_proto == bpf_htons(ETH_P_8021AD)) {
38-
39-
if (nh + 4 > data_end)
40-
return XDP_PASS;
41-
42-
h_proto = *(__u16 *)(nh + 2);
43-
nh += 4;
44-
}
45-
46-
struct ip_key key = {};
47-
48-
// IPv4
49-
if (h_proto == bpf_htons(ETH_P_IP)) {
50-
struct iphdr *iph = nh;
51-
if (iph + 1 > (struct iphdr *)data_end)
52-
return XDP_PASS;
53-
54-
key.family = 4;
55-
// Copy 4 bytes of IPv4 address
56-
__builtin_memcpy(key.addr, &iph->saddr, 4);
21+
void *data_end = (void *)(long)ctx->data_end;
22+
void *data = (void *)(long)ctx->data;
5723

58-
__u64 *val = bpf_map_lookup_elem(&ip_count_map, &key);
59-
if (val)
60-
(*val)++;
61-
else {
62-
__u64 init = 1;
63-
bpf_map_update_elem(&ip_count_map, &key, &init, BPF_ANY);
64-
}
24+
struct ethhdr *eth = data;
25+
if ((void *)(eth + 1) > data_end)
26+
return XDP_ABORTED;
27+
if (eth->h_proto != __constant_htons(ETH_P_IP))
28+
return XDP_PASS;
6529

66-
return XDP_PASS;
67-
}
30+
struct fake_iphdr *iph = (struct fake_iphdr *)(eth + 1);
31+
if ((void *)(iph + 1) > data_end)
32+
return XDP_ABORTED;
33+
bpf_printk("%d", iph->saddr);
6834

69-
return XDP_PASS;
35+
return XDP_PASS;
7036
}
7137

7238
char _license[] SEC("license") = "GPL";
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from vmlinux import XDP_PASS, XDP_DROP
2+
from vmlinux import (
3+
struct_xdp_md,
4+
struct_ethhdr,
5+
)
6+
from pythonbpf import bpf, section, bpfglobal, compile, compile_to_ir, struct
7+
from ctypes import c_int64, c_ubyte, c_ushort, c_uint32
8+
9+
@bpf
10+
@struct
11+
class iphdr:
12+
useless: c_ushort
13+
tot_len: c_ushort
14+
id: c_ushort
15+
frag_off: c_ushort
16+
ttl: c_ubyte
17+
protocol: c_ubyte
18+
check: c_ushort
19+
saddr: c_uint32
20+
daddr: c_uint32
21+
22+
@bpf
23+
@section("xdp")
24+
def ip_detector(ctx: struct_xdp_md) -> c_int64:
25+
data = ctx.data
26+
data_end = ctx.data_end
27+
eth = struct_ethhdr(ctx.data)
28+
nh = ctx.data + 14
29+
if nh + 20 > data_end:
30+
return c_int64(XDP_DROP)
31+
iph = iphdr(nh)
32+
h_proto = eth.h_proto
33+
h_proto_ext = c_int64(h_proto)
34+
ipv4 = iph.saddr
35+
print(f"ipaddress: {ipv4}")
36+
return c_int64(XDP_PASS)
37+
38+
39+
@bpf
40+
@bpfglobal
41+
def LICENSE() -> str:
42+
return "GPL"
43+
44+
45+
compile_to_ir("xdp_test_1.py", "xdp_test_1.ll")
46+
compile()

0 commit comments

Comments
 (0)