diff --git a/cmd/routesum/main.go b/cmd/routesum/main.go index dcd4259..5f50df9 100644 --- a/cmd/routesum/main.go +++ b/cmd/routesum/main.go @@ -17,6 +17,11 @@ import ( func main() { cpuProfile := flag.String("cpuprofile", "", "write cpu profile to file") memProfile := flag.String("memprofile", "", "write mem profile to file") + showMemStats := flag.Bool( + "show-mem-stats", + false, + "Whether or not to write memory usage stats to STDERR. (This functionity requires the use of `unsafe`, so may not be perfect.)", //nolint: lll + ) flag.Parse() var cpuProfOut io.Writer @@ -36,7 +41,12 @@ func main() { } } - if err := summarize(os.Stdin, os.Stdout, cpuProfOut, memProfOut); err != nil { + var memStatsOut io.Writer + if *showMemStats { + memStatsOut = os.Stderr + } + + if err := summarize(os.Stdin, os.Stdout, memStatsOut, cpuProfOut, memProfOut); err != nil { fmt.Fprintf(os.Stderr, "summarize: %s\n", err.Error()) os.Exit(1) } @@ -44,7 +54,7 @@ func main() { func summarize( in io.Reader, - out, cpuProfOut io.Writer, + out, memStatsOut, cpuProfOut io.Writer, memProfOut io.WriteCloser, ) error { if cpuProfOut != nil { @@ -76,6 +86,23 @@ func summarize( } } + if memStatsOut != nil { + numInternalNodes, numLeafNodes, internalNodesTotalSize, leafNodesTotalSize := rs.MemUsage() + fmt.Fprintf(memStatsOut, + `Num internal nodes: %d +Num leaf nodes: %d +Size of all internal nodes: %d +Size of all leaf nodes: %d +Total size of data structure: %d +`, + numInternalNodes, + numLeafNodes, + internalNodesTotalSize, + leafNodesTotalSize, + internalNodesTotalSize+leafNodesTotalSize, + ) + } + for _, s := range rs.SummaryStrings() { if _, err := out.Write([]byte(s + "\n")); err != nil { return fmt.Errorf("write output: %w", err) diff --git a/cmd/routesum/main_test.go b/cmd/routesum/main_test.go index c5aa894..2689f2b 100644 --- a/cmd/routesum/main_test.go +++ b/cmd/routesum/main_test.go @@ -1,6 +1,8 @@ package main import ( + "io" + "regexp" "strings" "testing" @@ -9,12 +11,42 @@ import ( ) func TestSummarize(t *testing.T) { + tests := []struct { + name string + showMemStats bool + expected *regexp.Regexp + }{ + { + name: "without memory statistics", + showMemStats: false, + expected: regexp.MustCompile(`^$`), + }, + { + name: "with memory statistics", + showMemStats: true, + expected: regexp.MustCompile(`Num internal nodes`), + }, + } + inStr := "\n192.0.2.0\n192.0.2.1\n" - in := strings.NewReader(inStr) - var out strings.Builder - err := summarize(in, &out, nil, nil) - require.NoError(t, err, "summarize does not throw an error") + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + in := strings.NewReader(inStr) + var out strings.Builder + + var memStatsBuilder strings.Builder + var memStatsOut io.Writer + + if test.showMemStats { + memStatsOut = &memStatsBuilder + } + + err := summarize(in, &out, memStatsOut, nil, nil) + require.NoError(t, err, "summarize does not throw an error") - assert.Equal(t, "192.0.2.0/31\n", out.String(), "read expected output") + assert.Equal(t, "192.0.2.0/31\n", out.String(), "read expected output") + assert.Regexp(t, test.expected, memStatsBuilder.String(), "read expected memory stats") + }) + } } diff --git a/pkg/routesum/routesum.go b/pkg/routesum/routesum.go index 45c68ab..5502cf0 100644 --- a/pkg/routesum/routesum.go +++ b/pkg/routesum/routesum.go @@ -142,3 +142,13 @@ func ipv6FromBits(bits bitslice.BitSlice) netaddr.IP { copy(byteArray[:], bytes[0:16]) return netaddr.IPv6Raw(byteArray) } + +// MemUsage provides information about memory usage. +func (rs *RouteSum) MemUsage() (uint, uint, uintptr, uintptr) { + ipv4NumInternalNodes, ipv4NumLeafNodes, ipv4InternalNodesTotalSize, ipv4LeafNodesTotalSize := rs.ipv4.MemUsage() + ipv6NumInternalNodes, ipv6NumLeafNodes, ipv6InternalNodesTotalSize, ipv6LeafNodesTotalSize := rs.ipv6.MemUsage() + return ipv4NumInternalNodes + ipv6NumInternalNodes, + ipv4NumLeafNodes + ipv6NumLeafNodes, + ipv4InternalNodesTotalSize + ipv6InternalNodesTotalSize, + ipv4LeafNodesTotalSize + ipv6LeafNodesTotalSize +} diff --git a/pkg/routesum/routesum_test.go b/pkg/routesum/routesum_test.go index 9753c90..65d97ca 100644 --- a/pkg/routesum/routesum_test.go +++ b/pkg/routesum/routesum_test.go @@ -455,3 +455,76 @@ func TestSummarize(t *testing.T) { //nolint: funlen }) } } + +func TestRSTrieMemUsage(t *testing.T) { //nolint: funlen + tests := []struct { + name string + entries []string + expectedNumInternalNodes uint + expectedNumLeafNodes uint + }{ + { + name: "new trie", + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 0, + }, + { + name: "one item, IPv4", + entries: []string{ + "192.0.2.1", + }, + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 1, + }, + { + name: "two items, IPv4, summarized", + entries: []string{ + "192.0.2.1", + "192.0.2.0", + }, + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 1, + }, + { + name: "two items, IPv4, unsummarized", + entries: []string{ + "192.0.2.1", + "192.0.2.2", + }, + expectedNumInternalNodes: 1, + expectedNumLeafNodes: 2, + }, + { + name: "one item, IPv6", + entries: []string{ + "2001:db8::1", + }, + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 1, + }, + { + name: "one IPv4 address, one IPv6 address", + entries: []string{ + "192.0.2.0", + "2001:db8::1", + }, + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rs := NewRouteSum() + + for _, entry := range test.entries { + err := rs.InsertFromString(entry) + require.NoError(t, err) + } + + numInternalNodes, numLeafNodes, _, _ := rs.MemUsage() + assert.Equal(t, test.expectedNumInternalNodes, numInternalNodes, "num internal nodes") + assert.Equal(t, test.expectedNumLeafNodes, numLeafNodes, "num leaf nodes") + }) + } +} diff --git a/pkg/routesum/rstrie/rstrie.go b/pkg/routesum/rstrie/rstrie.go index 9f84f2d..4ea8a01 100644 --- a/pkg/routesum/rstrie/rstrie.go +++ b/pkg/routesum/rstrie/rstrie.go @@ -4,6 +4,7 @@ package rstrie import ( "bytes" "container/list" + "unsafe" "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" ) @@ -216,3 +217,61 @@ func (t *RSTrie) Contents() []bitslice.BitSlice { return contents } + +func (t *RSTrie) visitAll(cb func(*node)) { + // If the trie is empty + if t.root == nil { + return + } + + // Otherwise + remainingSteps := list.New() + remainingSteps.PushFront(traversalStep{ + n: t.root, + precedingRouteBits: bitslice.BitSlice{}, + }) + + for remainingSteps.Len() > 0 { + curNode := remainingSteps.Remove(remainingSteps.Front()).(traversalStep) + + // Act on this node + cb(curNode.n) + + // Traverse the remainder of the nodes + if !curNode.n.isLeaf() { + curNodeRouteBits := bitslice.BitSlice{} + curNodeRouteBits = append(curNodeRouteBits, curNode.precedingRouteBits...) + curNodeRouteBits = append(curNodeRouteBits, curNode.n.bits...) + + remainingSteps.PushFront(traversalStep{ + n: curNode.n.children[1], + precedingRouteBits: curNodeRouteBits, + }) + remainingSteps.PushFront(traversalStep{ + n: curNode.n.children[0], + precedingRouteBits: curNodeRouteBits, + }) + } + } +} + +// MemUsage returns information about an RSTrie's current size in memory. +func (t *RSTrie) MemUsage() (uint, uint, uintptr, uintptr) { + var numInternalNodes, numLeafNodes uint + var internalNodesTotalSize, leafNodesTotalSize uintptr + + tallyNode := func(n *node) { + baseNodeSize := unsafe.Sizeof(node{}) + uintptr(cap(n.bits))*unsafe.Sizeof([1]byte{}) //nolint: exhaustruct, gosec, lll + if n.isLeaf() { + numLeafNodes++ + leafNodesTotalSize += baseNodeSize + return + } + + numInternalNodes++ + internalNodesTotalSize += baseNodeSize + unsafe.Sizeof([2]*node{}) //nolint: gosec + } + t.visitAll(tallyNode) + + return numInternalNodes, numLeafNodes, internalNodesTotalSize, leafNodesTotalSize +} diff --git a/pkg/routesum/rstrie/rstrie_test.go b/pkg/routesum/rstrie/rstrie_test.go index 9d2f0df..93f1988 100644 --- a/pkg/routesum/rstrie/rstrie_test.go +++ b/pkg/routesum/rstrie/rstrie_test.go @@ -1,10 +1,12 @@ package rstrie import ( + "net/netip" "testing" "github.com/PatrickCronin/routesum/pkg/routesum/bitslice" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCommonPrefixLen(t *testing.T) { @@ -224,3 +226,63 @@ func TestRSTrieContents(t *testing.T) { //nolint: funlen }) } } + +func TestRSTrieMemUsage(t *testing.T) { + tests := []struct { + name string + entries []string + expectedNumInternalNodes uint + expectedNumLeafNodes uint + }{ + { + name: "new trie", + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 0, + }, + { + name: "one item", + entries: []string{ + "192.0.2.1", + }, + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 1, + }, + { + name: "two items, summarized", + entries: []string{ + "192.0.2.1", + "192.0.2.0", + }, + expectedNumInternalNodes: 0, + expectedNumLeafNodes: 1, + }, + { + name: "two items, unsummarized", + entries: []string{ + "192.0.2.1", + "192.0.2.2", + }, + expectedNumInternalNodes: 1, + expectedNumLeafNodes: 2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + trie := NewRSTrie() + + for _, entry := range test.entries { + ip := netip.MustParseAddr(entry) + ipBytes, err := ip.MarshalBinary() + require.NoError(t, err) + ipBits, err := bitslice.NewFromBytes(ipBytes) + require.NoError(t, err) + trie.InsertRoute(ipBits) + } + + numInternalNodes, numLeafNodes, _, _ := trie.MemUsage() + assert.Equal(t, test.expectedNumInternalNodes, numInternalNodes, "num internal nodes") + assert.Equal(t, test.expectedNumLeafNodes, numLeafNodes, "num leaf nodes") + }) + } +}