Skip to content

Commit 2fd2a46

Browse files
committed
Add LocalSymbol dataclass
1 parent 1a66887 commit 2fd2a46

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

pythonbpf/functions_pass.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
from llvmlite import ir
22
import ast
3+
import logging
34
from typing import Any
5+
from dataclasses import dataclass
46

57
from .helper import HelperHandlerRegistry, handle_helper_call
68
from .type_deducer import ctypes_to_ir
79
from .binary_ops import handle_binary_op
810
from .expr_pass import eval_expr, handle_expr
911

1012
local_var_metadata: dict[str | Any, Any] = {}
13+
logger = logging.getLogger(__name__)
14+
15+
16+
@dataclass
17+
class LocalSymbol:
18+
var: ir.AllocaInstr
19+
ir_type: ir.Type
20+
metadata: Any = None
1121

1222

1323
def get_probe_string(func_node):
@@ -83,16 +93,19 @@ def handle_assign(
8393
elif isinstance(rval, ast.Constant):
8494
if isinstance(rval.value, bool):
8595
if rval.value:
86-
builder.store(ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name][0])
96+
builder.store(ir.Constant(ir.IntType(1), 1),
97+
local_sym_tab[var_name][0])
8798
else:
88-
builder.store(ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name][0])
99+
builder.store(ir.Constant(ir.IntType(1), 0),
100+
local_sym_tab[var_name][0])
89101
print(f"Assigned constant {rval.value} to {var_name}")
90102
elif isinstance(rval.value, int):
91103
# Assume c_int64 for now
92104
# var = builder.alloca(ir.IntType(64), name=var_name)
93105
# var.align = 8
94106
builder.store(
95-
ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name][0]
107+
ir.Constant(ir.IntType(64),
108+
rval.value), local_sym_tab[var_name][0]
96109
)
97110
# local_sym_tab[var_name] = var
98111
print(f"Assigned constant {rval.value} to {var_name}")
@@ -107,7 +120,8 @@ def handle_assign(
107120
global_str.linkage = "internal"
108121
global_str.global_constant = True
109122
global_str.initializer = str_const
110-
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
123+
str_ptr = builder.bitcast(
124+
global_str, ir.PointerType(ir.IntType(8)))
111125
builder.store(str_ptr, local_sym_tab[var_name][0])
112126
print(f"Assigned string constant '{rval.value}' to {var_name}")
113127
else:
@@ -126,7 +140,8 @@ def handle_assign(
126140
# var = builder.alloca(ir_type, name=var_name)
127141
# var.align = ir_type.width // 8
128142
builder.store(
129-
ir.Constant(ir_type, rval.args[0].value), local_sym_tab[var_name][0]
143+
ir.Constant(
144+
ir_type, rval.args[0].value), local_sym_tab[var_name][0]
130145
)
131146
print(
132147
f"Assigned {call_type} constant "
@@ -172,7 +187,8 @@ def handle_assign(
172187
ir_type = struct_info.ir_type
173188
# var = builder.alloca(ir_type, name=var_name)
174189
# Null init
175-
builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name][0])
190+
builder.store(ir.Constant(ir_type, None),
191+
local_sym_tab[var_name][0])
176192
local_var_metadata[var_name] = call_type
177193
print(f"Assigned struct {call_type} to {var_name}")
178194
# local_sym_tab[var_name] = var
@@ -243,7 +259,8 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
243259
print(f"Undefined variable {cond.id} in condition")
244260
return None
245261
elif isinstance(cond, ast.Compare):
246-
lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0]
262+
lhs = eval_expr(func, module, builder, cond.left,
263+
local_sym_tab, map_sym_tab)[0]
247264
if len(cond.ops) != 1 or len(cond.comparators) != 1:
248265
print("Unsupported complex comparison")
249266
return None
@@ -296,7 +313,8 @@ def handle_if(
296313
else:
297314
else_block = None
298315

299-
cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab)
316+
cond = handle_cond(func, module, builder, stmt.test,
317+
local_sym_tab, map_sym_tab)
300318
if else_block:
301319
builder.cbranch(cond, then_block, else_block)
302320
else:
@@ -441,7 +459,8 @@ def allocate_mem(
441459
ir_type = ctypes_to_ir(call_type)
442460
var = builder.alloca(ir_type, name=var_name)
443461
var.align = ir_type.width // 8
444-
print(f"Pre-allocated variable {var_name} of type {call_type}")
462+
print(
463+
f"Pre-allocated variable {var_name} of type {call_type}")
445464
elif HelperHandlerRegistry.has_handler(call_type):
446465
# Assume return type is int64 for now
447466
ir_type = ir.IntType(64)
@@ -662,7 +681,8 @@ def _expr_type(e):
662681
if found_type is None:
663682
found_type = t
664683
elif found_type != t:
665-
raise ValueError("Conflicting return types:" f"{found_type} vs {t}")
684+
raise ValueError("Conflicting return types:" f"{
685+
found_type} vs {t}")
666686
return found_type or "None"
667687

668688

@@ -699,7 +719,8 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
699719
char = builder.load(src_ptr)
700720

701721
# Store character in target
702-
dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
722+
dst_ptr = builder.gep(
723+
target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
703724
builder.store(char, dst_ptr)
704725

705726
# Increment counter
@@ -710,5 +731,6 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
710731

711732
# Ensure null termination
712733
last_idx = ir.Constant(ir.IntType(32), array_length - 1)
713-
null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
734+
null_ptr = builder.gep(
735+
target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
714736
builder.store(ir.Constant(ir.IntType(8), 0), null_ptr)

pythonbpf/maps/maps_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def create_bpf_map(module, map_name, map_params):
8585

8686

8787
def create_map_debug_info(module, map_global, map_name, map_params):
88-
"""Generate debug information metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
88+
"""Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
8989
generator = DebugInfoGenerator(module)
9090

9191
uint_type = generator.get_uint32_type()
@@ -158,15 +158,17 @@ def create_ringbuf_debug_info(module, map_global, map_name, map_params):
158158
type_ptr = generator.create_pointer_type(type_array, 64)
159159
type_member = generator.create_struct_member("type", type_ptr, 0)
160160

161-
max_entries_array = generator.create_array_type(int_type, map_params["max_entries"])
161+
max_entries_array = generator.create_array_type(
162+
int_type, map_params["max_entries"])
162163
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
163164
max_entries_member = generator.create_struct_member(
164165
"max_entries", max_entries_ptr, 64
165166
)
166167

167168
elements_arr = [type_member, max_entries_member]
168169

169-
struct_type = generator.create_struct_type(elements_arr, 128, is_distinct=True)
170+
struct_type = generator.create_struct_type(
171+
elements_arr, 128, is_distinct=True)
170172

171173
global_var = generator.create_global_var_debug_info(
172174
map_name, struct_type, is_local=False

0 commit comments

Comments
 (0)