From 6f934fd253659345c17595eadccfb34c56be68c1 Mon Sep 17 00:00:00 2001 From: Nia Waldvogel Date: Wed, 31 Dec 2025 10:51:19 -0500 Subject: [PATCH] compiler: use memcmp to compare strings This replaces our runtime stringEqual and stringLess functions with calls to libc's memcmp. This has a few advantages: 1. Memory comparison functions are not duplicated between Go and C 2. LLVM can optimize the compare-to-empty case by itself, so OptimizeStringEqual is no longer needed 3. LLVM can rewrite small constant-length memcmp operations into direct loads + compares 4. Compares to constants are generally compiled into simpler IR The downside of this is that all of the comparison logic must be handled by the compiler frontend. For string equality checks, we only need to compare the lengths and memcmp if they are equal. String order checks are messier, since we need to: 1. Find the minimum length 2. Call memcmp with the minimum length 3. Compare the lengths 4. Merge the length comparison with the memory comparison --- builder/wasmbuiltins.go | 3 + compiler/compiler.go | 31 ++++++---- compiler/string.go | 83 +++++++++++++++++++++++++++ compiler/testdata/go1.21.ll | 52 ++++++++++------- compiler/testdata/string.ll | 46 +++++++++++---- interp/interp_test.go | 1 + interp/interpreter.go | 36 ++++++++++++ interp/testdata/memcmp.ll | 36 ++++++++++++ interp/testdata/memcmp.out.ll | 14 +++++ src/runtime/string.go | 34 ----------- transform/optimizer.go | 1 - transform/rtcalls.go | 28 --------- transform/rtcalls_test.go | 8 --- transform/testdata/stringequal.ll | 19 ------ transform/testdata/stringequal.out.ll | 19 ------ 15 files changed, 258 insertions(+), 153 deletions(-) create mode 100644 compiler/string.go create mode 100644 interp/testdata/memcmp.ll create mode 100644 interp/testdata/memcmp.out.ll delete mode 100644 transform/testdata/stringequal.ll delete mode 100644 transform/testdata/stringequal.out.ll diff --git a/builder/wasmbuiltins.go b/builder/wasmbuiltins.go index e08eb7fcfa..8766188e26 100644 --- a/builder/wasmbuiltins.go +++ b/builder/wasmbuiltins.go @@ -49,6 +49,9 @@ var libWasmBuiltins = Library{ "libc-top-half/musl/src/string/memmove.c", "libc-top-half/musl/src/string/memset.c", + // memcmp is used for string comparisons + "libc-top-half/musl/src/string/memcmp.c", + // exp, exp2, and log are needed for LLVM math builtin functions // like llvm.exp.*. "libc-top-half/musl/src/math/__math_divzero.c", diff --git a/compiler/compiler.go b/compiler/compiler.go index 34300af0c2..ae4bbdc72c 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -80,6 +80,7 @@ type compilerContext struct { machine llvm.TargetMachine targetData llvm.TargetData intType llvm.Type + cIntType llvm.Type dataPtrType llvm.Type // pointer in address space 0 funcPtrType llvm.Type // pointer in function address space (1 for AVR, 0 elsewhere) funcPtrAddrSpace int @@ -117,16 +118,22 @@ func newCompilerContext(moduleName string, machine llvm.TargetMachine, config *C c.dibuilder = llvm.NewDIBuilder(c.mod) } - c.uintptrType = c.ctx.IntType(c.targetData.PointerSize() * 8) - if c.targetData.PointerSize() <= 4 { + ptrSize := c.targetData.PointerSize() + c.uintptrType = c.ctx.IntType(ptrSize * 8) + if ptrSize <= 4 { // 8, 16, 32 bits targets c.intType = c.ctx.Int32Type() - } else if c.targetData.PointerSize() == 8 { + } else if ptrSize == 8 { // 64 bits target c.intType = c.ctx.Int64Type() } else { panic("unknown pointer size") } + if ptrSize < 4 { + c.cIntType = c.ctx.Int16Type() + } else { + c.cIntType = c.ctx.Int32Type() + } c.dataPtrType = llvm.PointerType(c.ctx.Int8Type(), 0) dummyFuncType := llvm.FunctionType(c.ctx.VoidType(), nil, false) @@ -2825,20 +2832,20 @@ func (b *builder) createBinOp(op token.Token, typ, ytyp types.Type, x, y llvm.Va case token.ADD: // + return b.createRuntimeCall("stringConcat", []llvm.Value{x, y}, ""), nil case token.EQL: // == - return b.createRuntimeCall("stringEqual", []llvm.Value{x, y}, ""), nil + return b.createStringEqual(x, y), nil case token.NEQ: // != - result := b.createRuntimeCall("stringEqual", []llvm.Value{x, y}, "") - return b.CreateNot(result, ""), nil + result := b.createStringEqual(x, y) + return b.CreateNot(result, "streq.not"), nil case token.LSS: // x < y - return b.createRuntimeCall("stringLess", []llvm.Value{x, y}, ""), nil + return b.createStringLess(x, y), nil case token.LEQ: // x <= y becomes NOT (y < x) - result := b.createRuntimeCall("stringLess", []llvm.Value{y, x}, "") - return b.CreateNot(result, ""), nil + result := b.createStringLess(y, x) + return b.CreateNot(result, "strlt.not"), nil case token.GTR: // x > y becomes y < x - return b.createRuntimeCall("stringLess", []llvm.Value{y, x}, ""), nil + return b.createStringLess(y, x), nil case token.GEQ: // x >= y becomes NOT (x < y) - result := b.createRuntimeCall("stringLess", []llvm.Value{x, y}, "") - return b.CreateNot(result, ""), nil + result := b.createStringLess(x, y) + return b.CreateNot(result, "strlt.not"), nil default: panic("binop on string: " + op.String()) } diff --git a/compiler/string.go b/compiler/string.go new file mode 100644 index 0000000000..6efd94aa93 --- /dev/null +++ b/compiler/string.go @@ -0,0 +1,83 @@ +package compiler + +import ( + "strconv" + + "tinygo.org/x/go-llvm" +) + +func (b *builder) createStringEqual(lhs, rhs llvm.Value) llvm.Value { + // Compare the lengths. + lhsLen := b.CreateExtractValue(lhs, 1, "streq.lhs.len") + rhsLen := b.CreateExtractValue(rhs, 1, "streq.rhs.len") + lenCmp := b.CreateICmp(llvm.IntEQ, lhsLen, rhsLen, "streq.len.eq") + + // Branch on the length comparison. + bodyCmpBlock := b.ctx.AddBasicBlock(b.llvmFn, "streq.body") + nextBlock := b.ctx.AddBasicBlock(b.llvmFn, "streq.next") + b.CreateCondBr(lenCmp, bodyCmpBlock, nextBlock) + + // Use memcmp to compare the contents if the lengths are equal. + b.SetInsertPointAtEnd(bodyCmpBlock) + lhsPtr := b.CreateExtractValue(lhs, 0, "streq.lhs.ptr") + rhsPtr := b.CreateExtractValue(rhs, 0, "streq.rhs.ptr") + memcmp := b.createMemCmp(lhsPtr, rhsPtr, lhsLen, "streq.memcmp") + memcmpEq := b.CreateICmp(llvm.IntEQ, memcmp, llvm.ConstNull(b.cIntType), "streq.memcmp.eq") + b.CreateBr(nextBlock) + + // Create a phi to join the results. + b.SetInsertPointAtEnd(nextBlock) + result := b.CreatePHI(b.ctx.Int1Type(), "") + result.AddIncoming([]llvm.Value{llvm.ConstNull(b.ctx.Int1Type()), memcmpEq}, []llvm.BasicBlock{b.currentBlockInfo.exit, bodyCmpBlock}) + b.currentBlockInfo.exit = nextBlock // adjust outgoing block for phi nodes + + return result +} + +func (b *builder) createStringLess(lhs, rhs llvm.Value) llvm.Value { + // Calculate the minimum of the two string lengths. + lhsLen := b.CreateExtractValue(lhs, 1, "strlt.lhs.len") + rhsLen := b.CreateExtractValue(rhs, 1, "strlt.rhs.len") + minFnName := "llvm.umin.i" + strconv.Itoa(b.uintptrType.IntTypeWidth()) + minFn := b.mod.NamedFunction(minFnName) + if minFn.IsNil() { + fnType := llvm.FunctionType(b.uintptrType, []llvm.Type{b.uintptrType, b.uintptrType}, false) + minFn = llvm.AddFunction(b.mod, minFnName, fnType) + } + minLen := b.CreateCall(minFn.GlobalValueType(), minFn, []llvm.Value{lhsLen, rhsLen}, "strlt.min.len") + + // Compare the common-length body. + lhsPtr := b.CreateExtractValue(lhs, 0, "strlt.lhs.ptr") + rhsPtr := b.CreateExtractValue(rhs, 0, "strlt.rhs.ptr") + memcmp := b.createMemCmp(lhsPtr, rhsPtr, minLen, "strlt.memcmp") + + // Evaluate the result as: memcmp == 0 ? lhsLen < rhsLen : memcmp < 0 + zero := llvm.ConstNull(b.cIntType) + memcmpEQ := b.CreateICmp(llvm.IntEQ, memcmp, zero, "strlt.memcmp.eq") + lenLT := b.CreateICmp(llvm.IntULT, lhsLen, rhsLen, "strlt.len.lt") + memcmpLT := b.CreateICmp(llvm.IntSLT, memcmp, zero, "strlt.memcmp.lt") + return b.CreateSelect(memcmpEQ, lenLT, memcmpLT, "strlt.result") +} + +// createMemCmp compares memory by calling the libc function memcmp. +// This function is handled specially by LLVM: +// - It can be constant-folded in some trivial cases (e.g. len 0) +// - It can be replaced with loads and compares when the length is small and known +func (b *builder) createMemCmp(lhs, rhs, len llvm.Value, name string) llvm.Value { + memcmp := b.mod.NamedFunction("memcmp") + if memcmp.IsNil() { + fnType := llvm.FunctionType(b.cIntType, []llvm.Type{b.dataPtrType, b.dataPtrType, b.uintptrType}, false) + memcmp = llvm.AddFunction(b.mod, "memcmp", fnType) + + // The memcmp call does not capture the string. + nocapture := b.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0) + memcmp.AddAttributeAtIndex(1, nocapture) + memcmp.AddAttributeAtIndex(2, nocapture) + + // The memcmp call does not modify the string. + readonly := b.ctx.CreateEnumAttribute(llvm.AttributeKindID("readonly"), 0) + memcmp.AddAttributeAtIndex(1, readonly) + memcmp.AddAttributeAtIndex(2, readonly) + } + return b.CreateCall(memcmp.GlobalValueType(), memcmp, []llvm.Value{lhs, rhs, len}, name) +} diff --git a/compiler/testdata/go1.21.ll b/compiler/testdata/go1.21.ll index 6af9776bc3..be49379a58 100644 --- a/compiler/testdata/go1.21.ll +++ b/compiler/testdata/go1.21.ll @@ -84,14 +84,22 @@ entry: %2 = insertvalue %runtime._string zeroinitializer, ptr %b.data, 0 %3 = insertvalue %runtime._string %2, i32 %b.len, 1 %stackalloc = alloca i8, align 1 - %4 = call i1 @runtime.stringLess(ptr %a.data, i32 %a.len, ptr %b.data, i32 %b.len, ptr undef) #5 - %5 = select i1 %4, %runtime._string %1, %runtime._string %3 - %6 = select i1 %4, ptr %a.data, ptr %b.data - call void @runtime.trackPointer(ptr %6, ptr nonnull %stackalloc, ptr undef) #5 - ret %runtime._string %5 + %strlt.min.len = call i32 @llvm.umin.i32(i32 %a.len, i32 %b.len) + %strlt.memcmp = call i32 @memcmp(ptr %a.data, ptr %b.data, i32 %strlt.min.len) #5 + %strlt.memcmp.eq = icmp eq i32 %strlt.memcmp, 0 + %strlt.len.lt = icmp ult i32 %a.len, %b.len + %strlt.memcmp.lt = icmp slt i32 %strlt.memcmp, 0 + %strlt.result = select i1 %strlt.memcmp.eq, i1 %strlt.len.lt, i1 %strlt.memcmp.lt + %4 = select i1 %strlt.result, %runtime._string %1, %runtime._string %3 + %5 = select i1 %strlt.result, ptr %a.data, ptr %b.data + call void @runtime.trackPointer(ptr %5, ptr nonnull %stackalloc, ptr undef) #5 + ret %runtime._string %4 } -declare i1 @runtime.stringLess(ptr readonly, i32, ptr readonly, i32, ptr) #1 +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.umin.i32(i32, i32) #3 + +declare i32 @memcmp(ptr nocapture readonly, ptr nocapture readonly, i32) ; Function Attrs: nounwind define hidden i32 @main.maxInt(i32 %a, i32 %b, ptr %context) unnamed_addr #2 { @@ -123,11 +131,16 @@ entry: %2 = insertvalue %runtime._string zeroinitializer, ptr %b.data, 0 %3 = insertvalue %runtime._string %2, i32 %b.len, 1 %stackalloc = alloca i8, align 1 - %4 = call i1 @runtime.stringLess(ptr %b.data, i32 %b.len, ptr %a.data, i32 %a.len, ptr undef) #5 - %5 = select i1 %4, %runtime._string %1, %runtime._string %3 - %6 = select i1 %4, ptr %a.data, ptr %b.data - call void @runtime.trackPointer(ptr %6, ptr nonnull %stackalloc, ptr undef) #5 - ret %runtime._string %5 + %strlt.min.len = call i32 @llvm.umin.i32(i32 %b.len, i32 %a.len) + %strlt.memcmp = call i32 @memcmp(ptr %b.data, ptr %a.data, i32 %strlt.min.len) #5 + %strlt.memcmp.eq = icmp eq i32 %strlt.memcmp, 0 + %strlt.len.lt = icmp ult i32 %b.len, %a.len + %strlt.memcmp.lt = icmp slt i32 %strlt.memcmp, 0 + %strlt.result = select i1 %strlt.memcmp.eq, i1 %strlt.len.lt, i1 %strlt.memcmp.lt + %4 = select i1 %strlt.result, %runtime._string %1, %runtime._string %3 + %5 = select i1 %strlt.result, ptr %a.data, ptr %b.data + call void @runtime.trackPointer(ptr %5, ptr nonnull %stackalloc, ptr undef) #5 + ret %runtime._string %4 } ; Function Attrs: nounwind @@ -139,7 +152,7 @@ entry: } ; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: write) -declare void @llvm.memset.p0.i32(ptr nocapture writeonly, i8, i32, i1 immarg) #3 +declare void @llvm.memset.p0.i32(ptr nocapture writeonly, i8, i32, i1 immarg) #4 ; Function Attrs: nounwind define hidden void @main.clearZeroSizedSlice(ptr %s.data, i32 %s.len, i32 %s.cap, ptr %context) unnamed_addr #2 { @@ -157,23 +170,20 @@ entry: declare void @runtime.hashmapClear(ptr dereferenceable_or_null(40), ptr) #1 ; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare i32 @llvm.smin.i32(i32, i32) #4 - -; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare i8 @llvm.umin.i8(i8, i8) #4 +declare i32 @llvm.smin.i32(i32, i32) #3 ; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare i32 @llvm.umin.i32(i32, i32) #4 +declare i8 @llvm.umin.i8(i8, i8) #3 ; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare i32 @llvm.smax.i32(i32, i32) #4 +declare i32 @llvm.smax.i32(i32, i32) #3 ; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare i32 @llvm.umax.i32(i32, i32) #4 +declare i32 @llvm.umax.i32(i32, i32) #3 attributes #0 = { allockind("alloc,zeroed") allocsize(0) "alloc-family"="runtime.alloc" "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } attributes #1 = { "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } attributes #2 = { nounwind "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } -attributes #3 = { nocallback nofree nounwind willreturn memory(argmem: write) } -attributes #4 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #3 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #4 = { nocallback nofree nounwind willreturn memory(argmem: write) } attributes #5 = { nounwind } diff --git a/compiler/testdata/string.ll b/compiler/testdata/string.ll index 8c95323ccf..19dae52f84 100644 --- a/compiler/testdata/string.ll +++ b/compiler/testdata/string.ll @@ -48,7 +48,7 @@ lookup.next: ; preds = %entry ret i8 %1 lookup.throw: ; preds = %entry - call void @runtime.lookupPanic(ptr undef) #3 + call void @runtime.lookupPanic(ptr undef) #4 unreachable } @@ -57,28 +57,51 @@ declare void @runtime.lookupPanic(ptr) #1 ; Function Attrs: nounwind define hidden i1 @main.stringCompareEqual(ptr readonly %s1.data, i32 %s1.len, ptr readonly %s2.data, i32 %s2.len, ptr %context) unnamed_addr #2 { entry: - %0 = call i1 @runtime.stringEqual(ptr %s1.data, i32 %s1.len, ptr %s2.data, i32 %s2.len, ptr undef) #3 + %streq.len.eq = icmp eq i32 %s1.len, %s2.len + br i1 %streq.len.eq, label %streq.body, label %streq.next + +streq.body: ; preds = %entry + %streq.memcmp = call i32 @memcmp(ptr %s1.data, ptr %s2.data, i32 %s1.len) #4 + %streq.memcmp.eq = icmp eq i32 %streq.memcmp, 0 + br label %streq.next + +streq.next: ; preds = %streq.body, %entry + %0 = phi i1 [ false, %entry ], [ %streq.memcmp.eq, %streq.body ] ret i1 %0 } -declare i1 @runtime.stringEqual(ptr readonly, i32, ptr readonly, i32, ptr) #1 +declare i32 @memcmp(ptr nocapture readonly, ptr nocapture readonly, i32) ; Function Attrs: nounwind define hidden i1 @main.stringCompareUnequal(ptr readonly %s1.data, i32 %s1.len, ptr readonly %s2.data, i32 %s2.len, ptr %context) unnamed_addr #2 { entry: - %0 = call i1 @runtime.stringEqual(ptr %s1.data, i32 %s1.len, ptr %s2.data, i32 %s2.len, ptr undef) #3 - %1 = xor i1 %0, true - ret i1 %1 + %streq.len.eq = icmp eq i32 %s1.len, %s2.len + br i1 %streq.len.eq, label %streq.body, label %streq.next + +streq.body: ; preds = %entry + %streq.memcmp = call i32 @memcmp(ptr %s1.data, ptr %s2.data, i32 %s1.len) #4 + %streq.memcmp.eq = icmp ne i32 %streq.memcmp, 0 + br label %streq.next + +streq.next: ; preds = %streq.body, %entry + %streq.not = phi i1 [ true, %entry ], [ %streq.memcmp.eq, %streq.body ] + ret i1 %streq.not } ; Function Attrs: nounwind define hidden i1 @main.stringCompareLarger(ptr readonly %s1.data, i32 %s1.len, ptr readonly %s2.data, i32 %s2.len, ptr %context) unnamed_addr #2 { entry: - %0 = call i1 @runtime.stringLess(ptr %s2.data, i32 %s2.len, ptr %s1.data, i32 %s1.len, ptr undef) #3 - ret i1 %0 + %strlt.min.len = call i32 @llvm.umin.i32(i32 %s2.len, i32 %s1.len) + %strlt.memcmp = call i32 @memcmp(ptr %s2.data, ptr %s1.data, i32 %strlt.min.len) #4 + %strlt.memcmp.eq = icmp eq i32 %strlt.memcmp, 0 + %strlt.len.lt = icmp ult i32 %s2.len, %s1.len + %strlt.memcmp.lt = icmp slt i32 %strlt.memcmp, 0 + %strlt.result = select i1 %strlt.memcmp.eq, i1 %strlt.len.lt, i1 %strlt.memcmp.lt + ret i1 %strlt.result } -declare i1 @runtime.stringLess(ptr readonly, i32, ptr readonly, i32, ptr) #1 +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.umin.i32(i32, i32) #3 ; Function Attrs: nounwind define hidden i8 @main.stringLookup(ptr readonly %s.data, i32 %s.len, i8 %x, ptr %context) unnamed_addr #2 { @@ -93,11 +116,12 @@ lookup.next: ; preds = %entry ret i8 %2 lookup.throw: ; preds = %entry - call void @runtime.lookupPanic(ptr undef) #3 + call void @runtime.lookupPanic(ptr undef) #4 unreachable } attributes #0 = { allockind("alloc,zeroed") allocsize(0) "alloc-family"="runtime.alloc" "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } attributes #1 = { "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } attributes #2 = { nounwind "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } -attributes #3 = { nounwind } +attributes #3 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #4 = { nounwind } diff --git a/interp/interp_test.go b/interp/interp_test.go index cac5650879..650acdf5fc 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -25,6 +25,7 @@ func TestInterp(t *testing.T) { "interface", "revert", "alloc", + "memcmp", } { name := name // make local to this closure if name == "slice-copy" && llvmVersion < 14 { diff --git a/interp/interpreter.go b/interp/interpreter.go index 629c17d56d..d4295c7490 100644 --- a/interp/interpreter.go +++ b/interp/interpreter.go @@ -407,6 +407,42 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent copy(dstBuf.buf[dst.offset():dst.offset()+nBytes], srcBuf.buf[src.offset():]) dstObj.buffer = dstBuf mem.put(dst.index(), dstObj) + case callFn.name == "memcmp": + // Compare two byte strings. + nBytes := uint32(operands[3].Uint(r)) + var cmp uint64 + if nBytes > 0 { + lhs, err := operands[1].asPointer(r) + if err != nil { + return nil, mem, r.errorAt(inst, err) + } + rhs, err := operands[2].asPointer(r) + if err != nil { + return nil, mem, r.errorAt(inst, err) + } + lhsData := mem.get(lhs.index()).buffer.asRawValue(r).buf[lhs.offset():][:nBytes] + rhsData := mem.get(rhs.index()).buffer.asRawValue(r).buf[rhs.offset():][:nBytes] + for i, left := range lhsData { + right := rhsData[i] + if left >= 256 || right >= 256 { + // Do not attempt to compare pointers. + err := r.runAtRuntime(fn, inst, locals, &mem, indent) + if err != nil { + return nil, mem, err + } + continue + } + if left != right { + if left < right { + cmp = ^uint64(0) + } else { + cmp = 1 + } + break + } + } + } + locals[inst.localIndex] = makeLiteralInt(cmp, inst.llvmInst.Type().IntTypeWidth()) case callFn.name == "runtime.typeAssert": // This function must be implemented manually as it is normally // implemented by the interface lowering pass. diff --git a/interp/testdata/memcmp.ll b/interp/testdata/memcmp.ll new file mode 100644 index 0000000000..99e5e96686 --- /dev/null +++ b/interp/testdata/memcmp.ll @@ -0,0 +1,36 @@ +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64--linux" + +@str1 = internal global [4 x i8] c"aacd" +@str2 = internal global [4 x i8] c"aazw" + +@cmpLen0 = global i32 0 +@cmp12 = global i32 0 +@cmp21 = global i32 0 +@cmp11 = global i32 0 +@cmp22 = global i32 0 +@cmp12Partial = global i32 0 +@cmp21Partial = global i32 0 + +define void @runtime.initAll() unnamed_addr { + call void @main.init() + ret void +} + +define internal void @main.init() unnamed_addr { + call void @cmpAndStore(ptr @cmpLen0, ptr null, ptr null, i64 0) + call void @cmpAndStore(ptr @cmp12, ptr @str1, ptr @str2, i64 4) + call void @cmpAndStore(ptr @cmp21, ptr @str2, ptr @str1, i64 4) + call void @cmpAndStore(ptr @cmp11, ptr @str1, ptr @str1, i64 4) + call void @cmpAndStore(ptr @cmp22, ptr @str2, ptr @str2, i64 4) + call void @cmpAndStore(ptr @cmp12Partial, ptr getelementptr inbounds (i8, ptr @str1, i32 1), ptr getelementptr inbounds (i8, ptr @str2, i32 1), i64 1) + ret void +} + +define internal void @cmpAndStore(ptr %dst, ptr %lhs, ptr %rhs, i64 %len) unnamed_addr { + %cmp = call i32 @memcmp(ptr %lhs, ptr %rhs, i64 %len) + store i32 %cmp, ptr %dst + ret void +} + +declare i32 @memcmp(ptr nocapture readonly, ptr nocapture readonly, i64) diff --git a/interp/testdata/memcmp.out.ll b/interp/testdata/memcmp.out.ll new file mode 100644 index 0000000000..67851af3db --- /dev/null +++ b/interp/testdata/memcmp.out.ll @@ -0,0 +1,14 @@ +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64--linux" + +@cmpLen0 = local_unnamed_addr global i32 0 +@cmp12 = local_unnamed_addr global i32 -1 +@cmp21 = local_unnamed_addr global i32 1 +@cmp11 = local_unnamed_addr global i32 0 +@cmp22 = local_unnamed_addr global i32 0 +@cmp12Partial = local_unnamed_addr global i32 0 +@cmp21Partial = local_unnamed_addr global i32 0 + +define void @runtime.initAll() unnamed_addr { + ret void +} diff --git a/src/runtime/string.go b/src/runtime/string.go index 1136ef94a4..c115e6ae53 100644 --- a/src/runtime/string.go +++ b/src/runtime/string.go @@ -18,40 +18,6 @@ type stringIterator struct { byteindex uintptr } -// Return true iff the strings match. -// -//go:nobounds -func stringEqual(x, y string) bool { - if len(x) != len(y) { - return false - } - for i := 0; i < len(x); i++ { - if x[i] != y[i] { - return false - } - } - return true -} - -// Return true iff x < y. -// -//go:nobounds -func stringLess(x, y string) bool { - l := len(x) - if m := len(y); m < l { - l = m - } - for i := 0; i < l; i++ { - if x[i] < y[i] { - return true - } - if x[i] > y[i] { - return false - } - } - return len(x) < len(y) -} - // Add two strings together. func stringConcat(x, y _string) _string { if x.length == 0 { diff --git a/transform/optimizer.go b/transform/optimizer.go index 54f9762bc4..5a9993b1f1 100644 --- a/transform/optimizer.go +++ b/transform/optimizer.go @@ -91,7 +91,6 @@ func Optimize(mod llvm.Module, config *compileopts.Config) []error { fmt.Fprintln(os.Stderr, pos.String()+": "+msg) }) OptimizeStringToBytes(mod) - OptimizeStringEqual(mod) } else { // Must be run at any optimization level. diff --git a/transform/rtcalls.go b/transform/rtcalls.go index 3abc1d3952..0759d3ddf6 100644 --- a/transform/rtcalls.go +++ b/transform/rtcalls.go @@ -73,34 +73,6 @@ func OptimizeStringToBytes(mod llvm.Module) { } } -// OptimizeStringEqual transforms runtime.stringEqual(...) calls into simple -// integer comparisons if at least one of the sides of the comparison is zero. -// Ths converts str == "" into len(str) == 0 and "" == "" into false. -func OptimizeStringEqual(mod llvm.Module) { - stringEqual := mod.NamedFunction("runtime.stringEqual") - if stringEqual.IsNil() { - // nothing to optimize - return - } - - builder := mod.Context().NewBuilder() - defer builder.Dispose() - - for _, call := range getUses(stringEqual) { - str1len := call.Operand(1) - str2len := call.Operand(3) - - zero := llvm.ConstInt(str1len.Type(), 0, false) - if str1len == zero || str2len == zero { - builder.SetInsertPointBefore(call) - icmp := builder.CreateICmp(llvm.IntEQ, str1len, str2len, "") - call.ReplaceAllUsesWith(icmp) - call.EraseFromParentAsInstruction() - continue - } - } -} - // OptimizeReflectImplements optimizes the following code: // // implements := someType.Implements(someInterfaceType) diff --git a/transform/rtcalls_test.go b/transform/rtcalls_test.go index 9073b0ea5b..19b8eebe33 100644 --- a/transform/rtcalls_test.go +++ b/transform/rtcalls_test.go @@ -15,14 +15,6 @@ func TestOptimizeStringToBytes(t *testing.T) { }) } -func TestOptimizeStringEqual(t *testing.T) { - t.Parallel() - testTransform(t, "testdata/stringequal", func(mod llvm.Module) { - // Run optimization pass. - transform.OptimizeStringEqual(mod) - }) -} - func TestOptimizeReflectImplements(t *testing.T) { t.Parallel() testTransform(t, "testdata/reflect-implements", func(mod llvm.Module) { diff --git a/transform/testdata/stringequal.ll b/transform/testdata/stringequal.ll deleted file mode 100644 index 0d6ed7fb20..0000000000 --- a/transform/testdata/stringequal.ll +++ /dev/null @@ -1,19 +0,0 @@ -target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" -target triple = "armv7m-none-eabi" - -@zeroString = constant [0 x i8] zeroinitializer - -declare i1 @runtime.stringEqual(ptr, i32, ptr, i32, ptr) - -define i1 @main.stringCompareEqualConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { -entry: - %0 = call i1 @runtime.stringEqual(ptr %s1.data, i32 %s1.len, ptr @zeroString, i32 0, ptr undef) - ret i1 %0 -} - -define i1 @main.stringCompareUnequalConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { -entry: - %0 = call i1 @runtime.stringEqual(ptr %s1.data, i32 %s1.len, ptr @zeroString, i32 0, ptr undef) - %1 = xor i1 %0, true - ret i1 %1 -} diff --git a/transform/testdata/stringequal.out.ll b/transform/testdata/stringequal.out.ll deleted file mode 100644 index f2aeb95aba..0000000000 --- a/transform/testdata/stringequal.out.ll +++ /dev/null @@ -1,19 +0,0 @@ -target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" -target triple = "armv7m-none-eabi" - -@zeroString = constant [0 x i8] zeroinitializer - -declare i1 @runtime.stringEqual(ptr, i32, ptr, i32, ptr) - -define i1 @main.stringCompareEqualConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { -entry: - %0 = icmp eq i32 %s1.len, 0 - ret i1 %0 -} - -define i1 @main.stringCompareUnequalConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { -entry: - %0 = icmp eq i32 %s1.len, 0 - %1 = xor i1 %0, true - ret i1 %1 -}