From dd8e1a9f066def4f5648061b0ee34f22a67c2709 Mon Sep 17 00:00:00 2001 From: Leon Hwang Date: Sat, 8 Mar 2025 22:08:38 +0800 Subject: [PATCH] compile: Fix compiling skb != 0 If there is no struct/union member access, the zeroing R0 insn is missing. So, add the missing `xor r0, r0`, then add label when necessary. Signed-off-by: Leon Hwang --- compile.go | 6 +++--- compile_test.go | 26 ++++++++++++++++++++++---- simple_test.go | 16 +++++++++++++++- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/compile.go b/compile.go index e57803f..2f15081 100644 --- a/compile.go +++ b/compile.go @@ -372,12 +372,12 @@ func compile(expr *cc.Expr, typ btf.Type) (asm.Instructions, error) { return nil, fmt.Errorf("failed to convert operator to instructions: %w", err) } + xorR0 := asm.Xor.Reg(asm.R0, asm.R0) if labelUsed { - insns = append(insns, - asm.Mov.Imm(asm.R0, 0).WithSymbol(labelExitFail), // r0 = 0; __exit - ) + xorR0 = xorR0.WithSymbol(labelExitFail) } insns = append(insns, + xorR0, // r0 = 0 asm.Return().WithSymbol(labelReturn), // return; __return ) diff --git a/compile_test.go b/compile_test.go index dee3f40..6b80384 100644 --- a/compile_test.go +++ b/compile_test.go @@ -49,6 +49,12 @@ func getBpfProgTypeBtf(t *testing.T) *btf.Enum { return enum.(*btf.Enum) } +func getU64Btf(t *testing.T) btf.Type { + u64, err := testBtf.AnyTypeByName("u64") + test.AssertNoErr(t, err) + return u64 +} + func TestIsMemberBitfield(t *testing.T) { test.AssertFalse(t, isMemberBitfield(nil)) test.AssertTrue(t, isMemberBitfield(&btf.Member{Offset: 1, BitfieldSize: 1})) @@ -135,6 +141,17 @@ func TestExpr2offset(t *testing.T) { test.AssertFalse(t, ast.bigEndian) }) + t.Run("(u64)skb != 0", func(t *testing.T) { + expr, err := parse("skb != 0") + test.AssertNoErr(t, err) + + u64 := getU64Btf(t) + ast, err := expr2offset(expr.Left, u64) + test.AssertNoErr(t, err) + test.AssertEmptySlice(t, ast.offsets) + test.AssertTrue(t, ast.lastField == u64) + }) + t.Run("skb->len > 1024", func(t *testing.T) { expr, err := parse("skb->len > 1024") test.AssertNoErr(t, err) @@ -661,6 +678,7 @@ func TestCompile(t *testing.T) { asm.Mov.Reg(asm.R3, asm.R1), asm.Mov.Imm(asm.R0, 1), asm.JNE.Imm(asm.R3, 0, labelReturn), + asm.Xor.Reg(asm.R0, asm.R0), asm.Return().WithSymbol(labelReturn), }) }) @@ -701,7 +719,7 @@ func TestCompile(t *testing.T) { asm.RSh.Imm(asm.R3, 32), asm.Mov.Imm(asm.R0, 1), asm.JEq.Imm(asm.R3, 9, labelReturn), - asm.Mov.Imm(asm.R0, 0).WithSymbol(labelExitFail), + asm.Xor.Reg(asm.R0, asm.R0).WithSymbol(labelExitFail), asm.Return().WithSymbol(labelReturn), }) }) @@ -719,12 +737,12 @@ var skbLen1024Insns = asm.Instructions{ asm.RSh.Imm(asm.R3, 32), asm.Mov.Imm(asm.R0, 1), asm.JGT.Imm(asm.R3, 1024, labelReturn), - asm.Mov.Imm(asm.R0, 0).WithSymbol(labelExitFail), + asm.Xor.Reg(asm.R0, asm.R0).WithSymbol(labelExitFail), asm.Return().WithSymbol(labelReturn), } func cloneSkbLen1024InsnsWithoutExitLabel() asm.Instructions { insns := slices.Clone(skbLen1024Insns) - insns[len(insns)-2] = insns[len(insns)-1] - return insns[:len(insns)-1] + insns[len(insns)-2] = insns[len(insns)-2].WithMetadata(asm.Metadata{}) + return insns } diff --git a/simple_test.go b/simple_test.go index 0551aca..fdc2f15 100644 --- a/simple_test.go +++ b/simple_test.go @@ -7,6 +7,8 @@ import ( "testing" "github.com/cilium/ebpf" + "github.com/cilium/ebpf/asm" + "github.com/leonhwangprojects/bice/internal/test" ) @@ -29,11 +31,23 @@ func TestSimpleCompile(t *testing.T) { test.AssertStrPrefix(t, err.Error(), "failed to compile expression") }) - t.Run("success", func(t *testing.T) { + t.Run("(struct sk_buff *)skb->len > 1024", func(t *testing.T) { insns, err := SimpleCompile("skb->len > 1024", getSkbBtf(t)) test.AssertNoErr(t, err) test.AssertEqualSlice(t, insns, cloneSkbLen1024InsnsWithoutExitLabel()) }) + + t.Run("(u64)skb != 0", func(t *testing.T) { + insns, err := SimpleCompile("skb != 0", getU64Btf(t)) + test.AssertNoErr(t, err) + test.AssertEqualSlice(t, insns, asm.Instructions{ + asm.Mov.Reg(asm.R3, asm.R1), + asm.Mov.Imm(asm.R0, 1), + asm.JNE.Imm(asm.R3, 0, labelReturn), + asm.Xor.Reg(asm.R0, asm.R0), + asm.Return().WithSymbol(labelReturn), + }) + }) } func TestSimpleInjectFilter(t *testing.T) {