diff --git a/cmd/routesum/main.go b/cmd/routesum/main.go index 53d6f79..5f50df9 100644 --- a/cmd/routesum/main.go +++ b/cmd/routesum/main.go @@ -4,21 +4,66 @@ package main import ( "bufio" "bytes" + "flag" "fmt" "io" "os" + "runtime/pprof" "github.com/PatrickCronin/routesum/pkg/routesum" + "github.com/pkg/errors" ) func main() { - if err := summarize(os.Stdin, os.Stdout); err != nil { + cpuProfile := flag.String("cpuprofile", "", "write cpu profile to file") + memProfile := flag.String("memprofile", "", "write mem profile to file") + showMemStats := flag.Bool( + "show-mem-stats", + false, + "Whether or not to write memory usage stats to STDERR. (This functionity requires the use of `unsafe`, so may not be perfect.)", //nolint: lll + ) + flag.Parse() + + var cpuProfOut io.Writer + if *cpuProfile != "" { + var err error + if cpuProfOut, err = os.Create(*cpuProfile); err != nil { + fmt.Fprint(os.Stderr, errors.Wrap(err, "create cpu profile output file").Error()) + os.Exit(1) + } + } + + var memProfOut io.WriteCloser + if *memProfile != "" { + var err error + if memProfOut, err = os.Create(*memProfile); err != nil { + fmt.Fprint(os.Stderr, errors.Wrap(err, "create mem profile output file").Error()) + } + } + + var memStatsOut io.Writer + if *showMemStats { + memStatsOut = os.Stderr + } + + if err := summarize(os.Stdin, os.Stdout, memStatsOut, cpuProfOut, memProfOut); err != nil { fmt.Fprintf(os.Stderr, "summarize: %s\n", err.Error()) os.Exit(1) } } -func summarize(in io.Reader, out io.Writer) error { +func summarize( + in io.Reader, + out, memStatsOut, cpuProfOut io.Writer, + memProfOut io.WriteCloser, +) error { + if cpuProfOut != nil { + if err := pprof.StartCPUProfile(cpuProfOut); err != nil { + return errors.Wrap(err, "start cpu profiling") + } + defer pprof.StopCPUProfile() + } + rs := routesum.NewRouteSum() scanner := bufio.NewScanner(in) for scanner.Scan() { @@ -32,6 +77,32 @@ func summarize(in io.Reader, out io.Writer) error { } } + if memProfOut != nil { + if err := pprof.WriteHeapProfile(memProfOut); err != nil { + return errors.Wrap(err, "write mem profile") + } + if err := memProfOut.Close(); err != nil { + return errors.Wrap(err, "close mem profile") + } + } + + if memStatsOut != nil { + numInternalNodes, numLeafNodes, internalNodesTotalSize, leafNodesTotalSize := rs.MemUsage() + fmt.Fprintf(memStatsOut, + `Num internal nodes: %d +Num leaf nodes: %d +Size of all internal nodes: %d +Size of all leaf nodes: %d +Total size of data structure: %d +`, + numInternalNodes, + numLeafNodes, + internalNodesTotalSize, + leafNodesTotalSize, + internalNodesTotalSize+leafNodesTotalSize, + ) + } + for _, s := range rs.SummaryStrings() { if _, err := out.Write([]byte(s + "\n")); err != nil { return fmt.Errorf("write output: %w", err) diff --git a/cmd/routesum/main_test.go b/cmd/routesum/main_test.go index 50dc11e..2689f2b 100644 --- a/cmd/routesum/main_test.go +++ b/cmd/routesum/main_test.go @@ -1,6 +1,8 @@ package main import ( + "io" + "regexp" "strings" "testing" @@ -9,12 +11,42 @@ import ( ) func TestSummarize(t *testing.T) { + tests := []struct { + name string + showMemStats bool + expected *regexp.Regexp + }{ + { + name: "without memory statistics", + showMemStats: false, + expected: regexp.MustCompile(`^$`), + }, + { + name: "with memory statistics", + showMemStats: true, + expected: regexp.MustCompile(`Num internal nodes`), + }, + } + inStr := "\n192.0.2.0\n192.0.2.1\n" - in := strings.NewReader(inStr) - var out strings.Builder - err := summarize(in, &out) - require.NoError(t, err, "summarize does not throw an error") + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + in := strings.NewReader(inStr) + var out strings.Builder + + var memStatsBuilder strings.Builder + var memStatsOut io.Writer + + if test.showMemStats { + memStatsOut = &memStatsBuilder + } + + err := summarize(in, &out, memStatsOut, nil, nil) + require.NoError(t, err, "summarize does not throw an error") - assert.Equal(t, "192.0.2.0/31\n", out.String(), "read expected output") + assert.Equal(t, "192.0.2.0/31\n", out.String(), "read expected output") + assert.Regexp(t, test.expected, memStatsBuilder.String(), "read expected memory stats") + }) + } } diff --git a/pkg/routesum/routesum.go b/pkg/routesum/routesum.go index 45c68ab..9d7fb2f 100644 --- a/pkg/routesum/routesum.go +++ b/pkg/routesum/routesum.go @@ -7,21 +7,21 @@ import ( "strings" "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" - "github.com/PatrickCronin/routesum/pkg/routesum/rstrie" + "github.com/PatrickCronin/routesum/pkg/routesum/rstree" "github.com/pkg/errors" "inet.af/netaddr" ) // RouteSum has methods supporting route summarization of networks and hosts type RouteSum struct { - ipv4, ipv6 *rstrie.RSTrie + ipv4, ipv6 *rstree.RSTree } // NewRouteSum returns an initialized RouteSum object func NewRouteSum() *RouteSum { rs := new(RouteSum) - rs.ipv4 = rstrie.NewRSTrie() - rs.ipv6 = rstrie.NewRSTrie() + rs.ipv4 = rstree.NewRSTree() + rs.ipv6 = rstree.NewRSTree() return rs } @@ -142,3 +142,13 @@ func ipv6FromBits(bits bitslice.BitSlice) netaddr.IP { copy(byteArray[:], bytes[0:16]) return netaddr.IPv6Raw(byteArray) } + +// MemUsage provides information about memory usage. +func (rs *RouteSum) MemUsage() (uint, uint, uintptr, uintptr) { + ipv4NumInternalNodes, ipv4NumLeafNodes, ipv4InternalNodesTotalSize, ipv4LeafNodesTotalSize := rs.ipv4.MemUsage() + ipv6NumInternalNodes, ipv6NumLeafNodes, ipv6InternalNodesTotalSize, ipv6LeafNodesTotalSize := rs.ipv6.MemUsage() + return ipv4NumInternalNodes + ipv6NumInternalNodes, + ipv4NumLeafNodes + ipv6NumLeafNodes, + ipv4InternalNodesTotalSize + ipv6InternalNodesTotalSize, + ipv4LeafNodesTotalSize + ipv6LeafNodesTotalSize +} diff --git a/pkg/routesum/routesum_test.go b/pkg/routesum/routesum_test.go index 9753c90..5975085 100644 --- a/pkg/routesum/routesum_test.go +++ b/pkg/routesum/routesum_test.go @@ -455,3 +455,76 @@ func TestSummarize(t *testing.T) { //nolint: funlen }) } } + +func TestMemUsage(t *testing.T) { //nolint: funlen + tests := []struct { + name string + entries []string + expectedNumInternalNodes uint + expectedNumLeafNodes uint + }{ + { + name: "new trie", + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 0, + }, + { + name: "one item, IPv4", + entries: []string{ + "192.0.2.1", + }, + expectedNumInternalNodes: 32, + expectedNumLeafNodes: 1, + }, + { + name: "two items, IPv4, summarized", + entries: []string{ + "192.0.2.1", + "192.0.2.0", + }, + expectedNumInternalNodes: 31, + expectedNumLeafNodes: 1, + }, + { + name: "two items, IPv4, unsummarized", + entries: []string{ + "192.0.2.1", + "192.0.2.2", + }, + expectedNumInternalNodes: 33, + expectedNumLeafNodes: 2, + }, + { + name: "one item, IPv6", + entries: []string{ + "2001:db8::1", + }, + expectedNumInternalNodes: 128, + expectedNumLeafNodes: 1, + }, + { + name: "one IPv4 address, one IPv6 address", + entries: []string{ + "192.0.2.0", + "2001:db8::1", + }, + expectedNumInternalNodes: 160, + expectedNumLeafNodes: 2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rs := NewRouteSum() + + for _, entry := range test.entries { + err := rs.InsertFromString(entry) + require.NoError(t, err) + } + + numInternalNodes, numLeafNodes, _, _ := rs.MemUsage() + assert.Equal(t, test.expectedNumInternalNodes, numInternalNodes, "num internal nodes") + assert.Equal(t, test.expectedNumLeafNodes, numLeafNodes, "num leaf nodes") + }) + } +} diff --git a/pkg/routesum/rstree/rstree.go b/pkg/routesum/rstree/rstree.go new file mode 100644 index 0000000..fc1b7ce --- /dev/null +++ b/pkg/routesum/rstree/rstree.go @@ -0,0 +1,229 @@ +// Package rstree provides a datatype that supports building a space-efficient summary of networks +// and IPs. +package rstree + +import ( + "container/list" + "unsafe" + + "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" +) + +type node struct { + children *[2]*node +} + +func (n *node) isLeaf() bool { + return n.children == nil +} + +// RSTree is a binary tree that supports the storage and retrieval of networks and IPs for the +// purpose of route summarization. +type RSTree struct { + root *node +} + +// NewRSTree returns an initialized RSTree for use +func NewRSTree() *RSTree { + return &RSTree{ + root: nil, + } +} + +// InsertRoute inserts a new BitSlice into the tree. Each insert results in a space-optimized tree +// structure. If a route being inserted is already covered by an existing route, it's simply +// ignored. If a route being inserted covers one or more routes already stored, those routes are +// replaced. +func (t *RSTree) InsertRoute(routeBits bitslice.BitSlice) { + // If the tree has no root node, create one. + if t.root == nil { + t.root = &node{children: nil} + + if len(routeBits) > 0 { + t.root.children = new([2]*node) + } + } + + t.root.insertRoute(routeBits) +} + +// insertRoute returns whether or not a change was made to the tree that might require upwards +// optimization. +func (n *node) insertRoute(remainingRouteBits bitslice.BitSlice) bool { + // Does the current node cover the requested route? If so, we're done. + if n.isLeaf() { + return false + } + + // Does the requested route cover the current node? If so, update the current node. + remainingRouteBitsLen := len(remainingRouteBits) + if remainingRouteBitsLen == 0 { + n.children = nil + return true + } + + // Otherwise the requested route diverges from the current node. + nextBit := remainingRouteBits[0] + if n.children[nextBit] == nil { + // As an optimization, if we would create a node only to realize it is redundant, just + // trim the rundundant child now and return. + if remainingRouteBitsLen == 1 && + n.children[^nextBit&1] != nil && + n.children[^nextBit&1].isLeaf() { + n.children = nil + return true + } + + // Otherwise we add a node + n.children[nextBit] = new(node) + if remainingRouteBitsLen > 1 { + n.children[nextBit].children = new([2]*node) + } + } + + // Traverse to the new node + if n.children[nextBit].insertRoute(remainingRouteBits[1:]) { + return n.maybeRemoveRedundantChildren() + } + + return false +} + +// A node's children are redundant if they, taken together, represent a complete subtree from the +// node's perspective. This situation can be represented more simply as the node having a nil +// children pointer. +func (n *node) maybeRemoveRedundantChildren() bool { + if n.isLeaf() { + return false + } + + if n.children[0] == nil || + !n.children[0].isLeaf() || + n.children[1] == nil || + !n.children[1].isLeaf() { + return false + } + + n.children = nil + return true +} + +type traversalStep struct { + n *node + precedingRouteBits bitslice.BitSlice +} + +// Contents returns the BitSlices contained in the RSTree. +func (t *RSTree) Contents() []bitslice.BitSlice { + // If the tree is empty + if t.root == nil { + return []bitslice.BitSlice{} + } + + // Otherwise + remainingSteps := list.New() + remainingSteps.PushFront(traversalStep{ + n: t.root, + precedingRouteBits: bitslice.BitSlice{}, + }) + + contents := []bitslice.BitSlice{} + for remainingSteps.Len() > 0 { + step := remainingSteps.Remove(remainingSteps.Front()).(traversalStep) + + if step.n.isLeaf() { + contents = append(contents, step.precedingRouteBits) + } else { + lenPrecedingRouteBits := len(step.precedingRouteBits) + + if step.n.children[1] != nil { + highChildBits := make([]byte, lenPrecedingRouteBits+1) + copy(highChildBits, step.precedingRouteBits) + highChildBits[lenPrecedingRouteBits] = 1 + remainingSteps.PushFront(traversalStep{ + n: step.n.children[1], + precedingRouteBits: highChildBits, + }) + } + + if step.n.children[0] != nil { + lowChildBits := make([]byte, lenPrecedingRouteBits+1) + copy(lowChildBits, step.precedingRouteBits) + lowChildBits[lenPrecedingRouteBits] = 0 + remainingSteps.PushFront(traversalStep{ + n: step.n.children[0], + precedingRouteBits: lowChildBits, + }) + } + } + } + + return contents +} + +func (t *RSTree) visitAll(cb func(*node)) { + // If the trie is empty + if t.root == nil { + return + } + + // Otherwise + remainingSteps := list.New() + remainingSteps.PushFront(traversalStep{ + n: t.root, + precedingRouteBits: bitslice.BitSlice{}, + }) + + for remainingSteps.Len() > 0 { + curNode := remainingSteps.Remove(remainingSteps.Front()).(traversalStep) + + // Act on this node + cb(curNode.n) + + // Traverse the remainder of the nodes + if !curNode.n.isLeaf() { + lenPrecedingRouteBits := len(curNode.precedingRouteBits) + + if curNode.n.children[1] != nil { + highChildBits := make([]byte, lenPrecedingRouteBits+1) + copy(highChildBits, curNode.precedingRouteBits) + highChildBits[lenPrecedingRouteBits] = 1 + remainingSteps.PushFront(traversalStep{ + n: curNode.n.children[1], + precedingRouteBits: highChildBits, + }) + } + + if curNode.n.children[0] != nil { + lowChildBits := make([]byte, lenPrecedingRouteBits+1) + copy(lowChildBits, curNode.precedingRouteBits) + lowChildBits[lenPrecedingRouteBits] = 0 + remainingSteps.PushFront(traversalStep{ + n: curNode.n.children[0], + precedingRouteBits: lowChildBits, + }) + } + } + } +} + +// MemUsage returns information about an RSTrie's current size in memory. +func (t *RSTree) MemUsage() (uint, uint, uintptr, uintptr) { + var numInternalNodes, numLeafNodes uint + var internalNodesTotalSize, leafNodesTotalSize uintptr + + tallyNode := func(n *node) { + baseNodeSize := unsafe.Sizeof(node{}) //nolint: exhaustruct, gosec + if n.isLeaf() { + numLeafNodes++ + leafNodesTotalSize += baseNodeSize + return + } + + numInternalNodes++ + internalNodesTotalSize += baseNodeSize + unsafe.Sizeof([2]*node{}) //nolint: gosec + } + t.visitAll(tallyNode) + + return numInternalNodes, numLeafNodes, internalNodesTotalSize, leafNodesTotalSize +} diff --git a/pkg/routesum/rstree/rstree_test.go b/pkg/routesum/rstree/rstree_test.go new file mode 100644 index 0000000..8d13621 --- /dev/null +++ b/pkg/routesum/rstree/rstree_test.go @@ -0,0 +1,156 @@ +package rstree + +import ( + "testing" + + "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" + "github.com/stretchr/testify/assert" +) + +func TestRSTreeInsertRoute(t *testing.T) { //nolint: funlen + tests := []struct { + name string + routes []bitslice.BitSlice + expected *RSTree + }{ + { + name: "add one child", + routes: []bitslice.BitSlice{{0}}, + expected: &RSTree{ + root: &node{ + children: &[2]*node{ + 0: new(node), + }, + }, + }, + }, + { + name: "add two children, completing the root node's subtree", + routes: []bitslice.BitSlice{{0}, {1}}, + expected: &RSTree{ + root: &node{children: nil}, + }, + }, + { + name: "covered routes are ignored", + routes: []bitslice.BitSlice{{0}, {0, 0}}, + expected: &RSTree{ + root: &node{ + children: &[2]*node{ + 0: new(node), + }, + }, + }, + }, + { + name: "route covering node replaces it", + routes: []bitslice.BitSlice{{0, 0}, {0}}, + expected: &RSTree{ + root: &node{ + children: &[2]*node{ + 0: new(node), + }, + }, + }, + }, + { + name: "completed subtrees are simpliflied", + routes: []bitslice.BitSlice{ + {1}, + {0, 1}, + {0, 0, 1}, + {0, 0, 0}, + }, + expected: &RSTree{ + root: &node{children: nil}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tree := NewRSTree() + + for _, route := range test.routes { + tree.InsertRoute(route) + } + + assert.Equal(t, test.expected, tree, "got expected rstree") + }) + } +} + +func TestRSTreeContents(t *testing.T) { //nolint: funlen + tests := []struct { + name string + tree RSTree + expected []bitslice.BitSlice + }{ + { + name: "complete tree", + tree: RSTree{ + root: &node{children: nil}, + }, + expected: []bitslice.BitSlice{{}}, + }, + { + name: "empty tree", + tree: RSTree{ + root: nil, + }, + expected: []bitslice.BitSlice{}, + }, + { + name: "single one-child tree (0)", + tree: RSTree{ + root: &node{ + children: &[2]*node{ + 0: new(node), + }, + }, + }, + expected: []bitslice.BitSlice{{0}}, + }, + { + name: "single one-child tree (1)", + tree: RSTree{ + root: &node{ + children: &[2]*node{ + 1: new(node), + }, + }, + }, + expected: []bitslice.BitSlice{{1}}, + }, + { + name: "multi-level tree", + tree: RSTree{ + root: &node{ + children: &[2]*node{ + 0: { + children: &[2]*node{ + 0: { + children: &[2]*node{ + 0: new(node), + 1: { + children: &[2]*node{ + 0: new(node), + }, + }, + }, + }, + }, + }, + }, + }, + }, + expected: []bitslice.BitSlice{{0, 0, 0}, {0, 0, 1, 0}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, test.tree.Contents(), "got expected bits") + }) + } +} diff --git a/pkg/routesum/rstrie/rstrie.go b/pkg/routesum/rstrie/rstrie.go deleted file mode 100644 index f91870d..0000000 --- a/pkg/routesum/rstrie/rstrie.go +++ /dev/null @@ -1,196 +0,0 @@ -// Package rstrie provides a datatype that supports building a space-efficient summary of networks and IPs. -package rstrie - -import ( - "bytes" - "container/list" - - "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" -) - -// RSTrie is a radix-like trie of radix 2 whose stored "words" are the binary representations of networks and IPs. An -// optimization rstrie makes over a generic radix tree is that since routes covered by other routes don't need to be -// stored, each node in the trie will have either 0 or 2 children; never 1. -type RSTrie struct { - root *node -} - -type node struct { - children *[2]*node - bits bitslice.BitSlice -} - -// NewRSTrie returns an initialized RSTrie for use -func NewRSTrie() *RSTrie { - return &RSTrie{ - root: nil, - } -} - -// InsertRoute inserts a new BitSlice into the trie. Each insert results in a space-optimized trie structure -// representing its contents. If a route being inserted is already covered by an existing route, it's simply ignored. If -// a route being inserted covers one or more routes already in the trie, those nodes are removed and replaced by the new -// route. -func (t *RSTrie) InsertRoute(routeBits bitslice.BitSlice) { - // If the trie has no root node, simply create one to store the new route - if t.root == nil { - t.root = &node{ - bits: routeBits, - children: nil, - } - return - } - - t.root.insertRoute(&t.root, routeBits) -} - -func (n *node) isLeaf() bool { - return n.children == nil -} - -// parent is a **node so that we can change what the parent is pointing to if we need to! -func (n *node) insertRoute(parent **node, remainingRouteBits bitslice.BitSlice) bool { - remainingRouteBitsLen := len(remainingRouteBits) - curNodeBitsLen := len(n.bits) - - // Does the requested route cover the current node? If so, update the current node. - if remainingRouteBitsLen <= curNodeBitsLen && bytes.HasPrefix(n.bits, remainingRouteBits) { - n.bits = remainingRouteBits - n.children = nil - return true - } - - if curNodeBitsLen <= remainingRouteBitsLen && bytes.HasPrefix(remainingRouteBits, n.bits) { - // Does the current node cover the requested route? If so, we're done. - if n.isLeaf() { - return false - } - - // Otherwise, we traverse to the correct child. - whichChild := remainingRouteBits[curNodeBitsLen] - if n.children[whichChild].insertRoute(&n.children[whichChild], remainingRouteBits[curNodeBitsLen:]) { - return n.maybeRemoveRedundantChildren() - } - - return false - } - - // Otherwise the requested route diverges from the current node. We'll need to split the current node. - - // As an optimization, if the split would result in a new node whose children represent a complete subtrie, we - // just update the current node, instead of allocating new nodes and optimizing them away immediately after. - if n.isLeaf() && - curNodeBitsLen == remainingRouteBitsLen && - commonPrefixLen(n.bits, remainingRouteBits) == len(n.bits)-1 { - n.bits = n.bits[:len(n.bits)-1] - n.children = nil - return true - } - - *parent = splitNodeForRoute(n, remainingRouteBits) - return n.maybeRemoveRedundantChildren() -} - -func commonPrefixLen(a, b bitslice.BitSlice) int { - i := 0 - maxLen := min(len(a), len(b)) - for ; i < maxLen; i++ { - if a[i] != b[i] { - break - } - } - - return i -} - -func min(a, b int) int { - if a < b { - return a - } - - return b -} - -func splitNodeForRoute(oldNode *node, routeBits bitslice.BitSlice) *node { - commonBitsLen := commonPrefixLen(oldNode.bits, routeBits) - commonBits := oldNode.bits[:commonBitsLen] - - routeNode := &node{ - bits: routeBits[commonBitsLen:], - children: nil, - } - oldNode.bits = oldNode.bits[commonBitsLen:] - - newNode := &node{ - bits: commonBits, - children: &[2]*node{}, - } - newNode.children[routeNode.bits[0]] = routeNode - newNode.children[oldNode.bits[0]] = oldNode - - return newNode -} - -// A node's children are redundant if they, taken together, represent a complete subtrie from the -// node's perspective. This situation can be represented more simply as the node having a nil -// children pointer. -func (n *node) maybeRemoveRedundantChildren() bool { - if n.isLeaf() { - return false - } - - if !n.children[0].isLeaf() || !n.children[1].isLeaf() { - return false - } - - if len(n.children[0].bits) != 1 || len(n.children[1].bits) != 1 { - return false - } - - n.children = nil - return true -} - -type traversalStep struct { - n *node - precedingRouteBits bitslice.BitSlice -} - -// Contents returns the BitSlices contained in the RSTrie. -func (t *RSTrie) Contents() []bitslice.BitSlice { - // If the trie is empty - if t.root == nil { - return []bitslice.BitSlice{} - } - - // Otherwise - remainingSteps := list.New() - remainingSteps.PushFront(traversalStep{ - n: t.root, - precedingRouteBits: bitslice.BitSlice{}, - }) - - contents := []bitslice.BitSlice{} - for remainingSteps.Len() > 0 { - step := remainingSteps.Remove(remainingSteps.Front()).(traversalStep) - - stepRouteBits := bitslice.BitSlice{} - stepRouteBits = append(stepRouteBits, step.precedingRouteBits...) - stepRouteBits = append(stepRouteBits, step.n.bits...) - - if step.n.isLeaf() { - contents = append(contents, stepRouteBits) - } else { - remainingSteps.PushFront(traversalStep{ - n: step.n.children[1], - precedingRouteBits: stepRouteBits, - }) - remainingSteps.PushFront(traversalStep{ - n: step.n.children[0], - precedingRouteBits: stepRouteBits, - }) - } - } - - return contents -} diff --git a/pkg/routesum/rstrie/rstrie_test.go b/pkg/routesum/rstrie/rstrie_test.go deleted file mode 100644 index cf2be5b..0000000 --- a/pkg/routesum/rstrie/rstrie_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package rstrie - -import ( - "testing" - - "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" - "github.com/stretchr/testify/assert" -) - -func TestCommonPrefixLen(t *testing.T) { - tests := []struct { - name string - a, b bitslice.BitSlice - expected int - }{ - { - name: "differing first bit", - a: bitslice.BitSlice{0}, - b: bitslice.BitSlice{1}, - expected: 0, - }, - { - name: "differing second bit", - a: bitslice.BitSlice{0, 0}, - b: bitslice.BitSlice{0, 1}, - expected: 1, - }, - { - name: "nothing different", - a: bitslice.BitSlice{0, 0, 0, 1}, - b: bitslice.BitSlice{0, 0, 0, 1}, - expected: 4, - }, - } - - for _, test := range tests { - assert.Equal( - t, - test.expected, - commonPrefixLen(test.a, test.b), - test.name, - ) - } -} - -func TestRSTrieInsertRoute(t *testing.T) { //nolint: funlen - tests := []struct { - name string - routes []bitslice.BitSlice - expected *RSTrie - }{ - { - name: "add one child", - routes: []bitslice.BitSlice{{0}}, - expected: &RSTrie{ - root: &node{ - bits: bitslice.BitSlice{0}, - children: nil, - }, - }, - }, - { - name: "add two children, completing the root node's subtrie", - routes: []bitslice.BitSlice{{0}, {1}}, - expected: &RSTrie{root: &node{ - bits: bitslice.BitSlice{}, - children: nil, - }}, - }, - { - name: "split root, root is empty", - routes: []bitslice.BitSlice{{0, 0}, {1, 1}}, - expected: &RSTrie{ - root: &node{ - bits: bitslice.BitSlice{}, - children: &[2]*node{ - 0: {bits: bitslice.BitSlice{0, 0}}, - 1: {bits: bitslice.BitSlice{1, 1}}, - }, - }, - }, - }, - { - name: "split root, root is not empty", - routes: []bitslice.BitSlice{{0, 0}, {0, 1, 0}}, - expected: &RSTrie{ - root: &node{ - bits: bitslice.BitSlice{0}, - children: &[2]*node{ - 0: {bits: bitslice.BitSlice{0}}, - 1: {bits: bitslice.BitSlice{1, 0}}, - }, - }, - }, - }, - { - name: "split root, traverse, and split internal", - routes: []bitslice.BitSlice{{0}, {1, 0, 0}, {1, 1, 0}}, - expected: &RSTrie{ - root: &node{ - bits: bitslice.BitSlice{}, - children: &[2]*node{ - 0: {bits: bitslice.BitSlice{0}}, - 1: { - bits: bitslice.BitSlice{1}, - children: &[2]*node{ - 0: {bits: bitslice.BitSlice{0, 0}}, - 1: {bits: bitslice.BitSlice{1, 0}}, - }, - }, - }, - }, - }, - }, - { - name: "covered routes are ignored", - routes: []bitslice.BitSlice{{0}, {0, 0}}, - expected: &RSTrie{ - root: &node{ - bits: bitslice.BitSlice{0}, - children: nil, - }, - }, - }, - { - name: "route covering node replaces it", - routes: []bitslice.BitSlice{{0, 0}, {0}}, - expected: &RSTrie{ - root: &node{ - bits: bitslice.BitSlice{0}, - children: nil, - }, - }, - }, - { - name: "completed subtries are simpliflied", - routes: []bitslice.BitSlice{ - {1}, - {0, 1}, - {0, 0, 1}, - {0, 0, 0}, - }, - expected: &RSTrie{root: &node{ - bits: bitslice.BitSlice{}, - children: nil, - }}, - }, - { - name: "completed subtries are simplified when new route covers current", - routes: []bitslice.BitSlice{ - {0, 0}, - {0, 1, 1}, - {0, 1}, - }, - expected: &RSTrie{root: &node{ - bits: bitslice.BitSlice{0}, - children: nil, - }}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - trie := NewRSTrie() - - for _, route := range test.routes { - trie.InsertRoute(route) - } - - assert.Equal(t, test.expected, trie, "got expected rstrie") - }) - } -} - -func TestRSTrieContents(t *testing.T) { //nolint: funlen - tests := []struct { - name string - trie RSTrie - expected []bitslice.BitSlice - }{ - { - name: "complete trie", - trie: RSTrie{ - root: &node{ - bits: nil, - children: nil, - }, - }, - expected: []bitslice.BitSlice{{}}, - }, - { - name: "empty trie", - trie: RSTrie{ - root: nil, - }, - expected: []bitslice.BitSlice{}, - }, - { - name: "single zero-child trie", - trie: RSTrie{ - root: &node{ - bits: bitslice.BitSlice{0}, - children: nil, - }, - }, - expected: []bitslice.BitSlice{{0}}, - }, - { - name: "single one-child trie", - trie: RSTrie{ - root: &node{ - bits: bitslice.BitSlice{1}, - children: nil, - }, - }, - expected: []bitslice.BitSlice{{1}}, - }, - { - name: "two-level trie", - trie: RSTrie{ - root: &node{ - bits: bitslice.BitSlice{0, 0}, - children: &[2]*node{ - 0: {bits: bitslice.BitSlice{0}}, - 1: {bits: bitslice.BitSlice{1, 0}}, - }, - }, - }, - expected: []bitslice.BitSlice{{0, 0, 0}, {0, 0, 1, 0}}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.expected, test.trie.Contents(), "got expected bits") - }) - } -}