diff --git a/.chloggen/tail-sampling-rare-spans-policy.yaml b/.chloggen/tail-sampling-rare-spans-policy.yaml new file mode 100644 index 000000000000..845566afce1e --- /dev/null +++ b/.chloggen/tail-sampling-rare-spans-policy.yaml @@ -0,0 +1,27 @@ +# Use this changelog template to create an entry for release notes. + +# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix' +change_type: enhancement + +# The name of the component, or a single word describing the area of concern, (e.g. filelogreceiver) +component: processor/tail_sampling + +# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`). +note: Add new tail sampling policy for sampling low-frequency spans. + +# Mandatory: One or more tracking issues related to the change. You can use the PR number here if no issue exists. +issues: [36487] + +# (Optional) One or more lines of additional information to render under the primary note. +# These lines will be padded with 2 spaces and then inserted directly into the document. +# Use pipe (|) for multiline entries. +subtext: A new policy for sampling rare spans based on the use of count-min sketch. + +# If your change doesn't affect end users or the exported elements of any package, +# you should instead start your pull request title with [chore] or use the "Skip Changelog" label. +# Optional: The change log or logs in which this entry should be included. +# e.g. '[user]' or '[user, api]' +# Include 'user' if the change is relevant to end users. +# Include 'api' if there is a change to a library API. +# Default: '[user]' +change_logs: [] diff --git a/processor/tailsamplingprocessor/README.md b/processor/tailsamplingprocessor/README.md index 358450d3f458..6caba1a10995 100644 --- a/processor/tailsamplingprocessor/README.md +++ b/processor/tailsamplingprocessor/README.md @@ -35,7 +35,8 @@ Multiple policies exist today and it is straight forward to add more. These incl - `span_count`: Sample based on the minimum and/or maximum number of spans, inclusive. If the sum of all spans in the trace is outside the range threshold, the trace will not be sampled. - `boolean_attribute`: Sample based on boolean attribute (resource and record). - `ottl_condition`: Sample based on given boolean OTTL condition (span and span event). -- `and`: Sample based on multiple policies, creates an AND policy +- `and`: Sample based on multiple policies, creates an AND policy +- `rare_spans`: Sample low-frequency spans based on counting unique spans - `composite`: Sample based on a combination of above samplers, with ordering and rate allocation per sampler. Rate allocation allocates certain percentages of spans per policy order. For example if we have set max_total_spans_per_second as 100 then we can set rate_allocation as follows 1. 
test-composite-policy-1 = 50 % of max_total_spans_per_second = 50 spans_per_second
@@ -166,6 +167,20 @@ processors:
               ]
             }
           },
+          {
+            name: rare-spans-policy-1,
+            type: rare_spans,
+            rare_spans: {
+              error_probability: 0.01,
+              total_frequency: 1000,
+              max_error_value: 1,
+              observation_interval: 60m,
+              buckets_num: 4,
+              rare_span_frequency: 2,
+              sampled_spans_per_second: 500,
+              processed_spans_per_second: 1000,
+            }
+          },
           {
             name: composite-policy-1,
             type: composite,
diff --git a/processor/tailsamplingprocessor/config.go b/processor/tailsamplingprocessor/config.go
index 4185e7b9b0b2..6db83b6ea589 100644
--- a/processor/tailsamplingprocessor/config.go
+++ b/processor/tailsamplingprocessor/config.go
@@ -43,6 +43,8 @@ const (
 	// OTTLCondition sample traces which match user provided OpenTelemetry Transformation Language
 	// conditions.
 	OTTLCondition PolicyType = "ottl_condition"
+	// RareSpans sample traces that contain rare (low-frequency) spans.
+	RareSpans PolicyType = "rare_spans"
 )
 
 // sharedPolicyCfg holds the common configuration to all policies that are used in derivative policy configurations
@@ -72,6 +74,50 @@ type sharedPolicyCfg struct {
 	BooleanAttributeCfg BooleanAttributeCfg `mapstructure:"boolean_attribute"`
 	// Configs for OTTL condition filter sampling policy evaluator
 	OTTLConditionCfg OTTLConditionCfg `mapstructure:"ottl_condition"`
+	// Configs for rare_spans policy
+	RareSpansCfg RareSpansCfg `mapstructure:"rare_spans"`
+}
+
+// RareSpansCfg configuration for the rare spans sampler.
+type RareSpansCfg struct {
+	// ErrorProbability error probability (δ, delta in terms of the Count-Min sketch)
+	// defines the probability of error or failure rate of the estimation. If the
+	// probability is small, it means there is a low chance that the CMS will produce
+	// a count estimation that is too far from the true count. On the other hand,
+	// the smaller the value, the more times the hash will need to be calculated for
+	// each span. This in turn can negatively affect performance.
+	ErrorProbability float64 `mapstructure:"error_probability"`
+	// TotalFreq total number of spans that will need to be processed within the
+	// ObservationInterval time interval. This parameter affects the accuracy
+	// of the span frequency calculation (`epsilon` in terms of the
+	// Count-Min sketch):
+	//   - the closer the value is to the actual number of spans, the more
+	//     accurate the estimate will be;
+	//   - if the value is higher than the real one, this will lead to a very
+	//     accurate estimate;
+	//   - the larger this value, the more memory will be needed to calculate
+	//     the estimate.
+	TotalFreq float64 `mapstructure:"total_frequency"`
+	// MaxErrValue the maximum allowed overestimation in the span frequency
+	// calculation. Together with the TotalFreq option, it is used to calculate
+	// the `epsilon` (ε) parameter for the Count-Min sketch data structure.
+	// The lower the value of MaxErrValue, the more accurate the estimate of the
+	// frequency of each unique span will be. On the other hand, the smaller the
+	// value, the more memory will be allocated for the CMS data structure.
+	MaxErrValue float64 `mapstructure:"max_error_value"`
+	// SpsSampledLimit maximum number of spans that can be sampled per second.
+	SpsSampledLimit int64 `mapstructure:"sampled_spans_per_second"`
+	// SpsProcessedLimit maximum number of spans that can be processed per second.
+ SpsProcessedLimit int64 `mapstructure:"processed_spans_per_second"` + // ObservationInterval the time interval of the sliding window within which + // rare spans will be taken into account. + ObservationInterval time.Duration `mapstructure:"observation_interval"` + // Buckets number of segments in a sliding window. + Buckets uint8 `mapstructure:"buckets_num"` + // RareSpanFrequency frequency of occurrence of a span in the ObservationInterval + // at which the span will be sampled. For example, if the value is 1, then the + // span will be sampled only at its first occurrence in the ObservationInterval. + RareSpanFrequency uint32 `mapstructure:"rare_span_frequency"` } // CompositeSubPolicyCfg holds the common configuration to all policies under composite policy. diff --git a/processor/tailsamplingprocessor/go.mod b/processor/tailsamplingprocessor/go.mod index aae6f36991d9..1cc1f339b010 100644 --- a/processor/tailsamplingprocessor/go.mod +++ b/processor/tailsamplingprocessor/go.mod @@ -17,7 +17,7 @@ require ( go.opentelemetry.io/collector/featuregate v1.20.0 go.opentelemetry.io/collector/pdata v1.20.0 go.opentelemetry.io/collector/processor v0.114.0 - go.opentelemetry.io/collector/semconv v0.114.0 // indirect + go.opentelemetry.io/collector/semconv v0.114.0 go.opentelemetry.io/otel v1.32.0 go.opentelemetry.io/otel/metric v1.32.0 go.opentelemetry.io/otel/sdk/metric v1.32.0 @@ -27,6 +27,7 @@ require ( ) require ( + github.com/cespare/xxhash/v2 v2.3.0 go.opentelemetry.io/collector/component/componenttest v0.114.0 go.opentelemetry.io/collector/consumer/consumertest v0.114.0 go.opentelemetry.io/collector/processor/processortest v0.114.0 diff --git a/processor/tailsamplingprocessor/go.sum b/processor/tailsamplingprocessor/go.sum index c2672807f138..a8c37c20738e 100644 --- a/processor/tailsamplingprocessor/go.sum +++ b/processor/tailsamplingprocessor/go.sum @@ -8,6 +8,8 @@ github.com/antchfx/xmlquery v1.4.2 h1:MZKd9+wblwxfQ1zd1AdrTsqVaMjMCwow3IqkCSe00K github.com/antchfx/xmlquery v1.4.2/go.mod h1:QXhvf5ldTuGqhd1SHNvvtlhhdQLks4dD0awIVhXIDTA= github.com/antchfx/xpath v1.3.2 h1:LNjzlsSjinu3bQpw9hWMY9ocB80oLOWuQqFvO6xt51U= github.com/antchfx/xpath v1.3.2/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/processor/tailsamplingprocessor/internal/cms/cms.go b/processor/tailsamplingprocessor/internal/cms/cms.go new file mode 100644 index 000000000000..c0ee6b36fbee --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/cms.go @@ -0,0 +1,110 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms // import "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" + +import ( + "math" +) + +// CountMinSketch interface for Count-Min Sketch data structure. CMS is a +// probabilistic data structure that provides an estimate of the frequency +// of elements in a data stream. +type CountMinSketch interface { + // Count returns the estimated frequency of the given key in the data + // stream. 
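The `RareSpansCfg` options above map almost one-to-one onto the `cms` primitives added later in this diff (`BucketsCfg`, `CountMinSketchCfg`, `NewSlidingCMSWithStartPoint`, `SlidingCms.InsertWithCount`). The rare_spans policy evaluator itself is not included in this excerpt, so the following is only a sketch of how the settings could be wired together: the package and helper names are hypothetical, the soft-limit choice is an assumption, and the literal values mirror the README example.

```go
// Package rarespanssketch is a hypothetical illustration, not part of this change.
// It assumes it lives inside the tailsamplingprocessor module, since internal/cms
// is an internal package.
package rarespanssketch

import (
	"time"

	"github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms"
)

// newRareSpanWindow turns the rare_spans settings from the README example into
// a sliding Count-Min sketch window.
func newRareSpanWindow() (*cms.SlidingCms, error) {
	bucketsCfg := cms.BucketsCfg{
		ObservationInterval: 60 * time.Minute, // observation_interval: 60m
		Buckets:             4,                // buckets_num: 4 -> 15m per bucket
		EstimationSoftLimit: 3,                // assumption: rare_span_frequency+1, counting past the threshold is not needed
	}
	cmsCfg := cms.CountMinSketchCfg{
		MaxErr:           1,    // max_error_value: 1
		ErrorProbability: 0.01, // error_probability: 0.01
		TotalFreq:        1000, // total_frequency: 1000
	}
	return cms.NewSlidingCMSWithStartPoint(bucketsCfg, cmsCfg, time.Now())
}

// isRare reports whether the span identified by key is still rare within the
// current observation window (hypothetical decision rule: the estimated
// frequency must not exceed rare_span_frequency).
func isRare(window *cms.SlidingCms, key []byte, rareSpanFrequency uint32) bool {
	return window.InsertWithCount(key) <= rareSpanFrequency
}
```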
+	Count(key []byte) uint32
+	// Insert increments the counters in the Count-Min Sketch for the given
+	// key.
+	Insert(element []byte)
+	// Clear resets the internal state of the Count-Min Sketch.
+	Clear()
+	// InsertWithCount increases the count of the specified key and returns the
+	// new estimated frequency of the key.
+	InsertWithCount(element []byte) uint32
+}
+
+type hasher interface {
+	Hash([]byte) uint32
+}
+
+// CountMinSketchCfg describes the main configuration options for the CMS data
+// structure.
+type CountMinSketchCfg struct {
+	// MaxErr approximation error, determines the error bound of the frequency
+	// estimates. Used together with the TotalFreq option to calculate the
+	// epsilon (ε) parameter for the Count-Min Sketch.
+	MaxErr float64
+	// ErrorProbability error probability (δ, delta), defines the probability that
+	// the error exceeds the bound.
+	ErrorProbability float64
+	// TotalFreq total number of elements (keys) in the data stream. Used in the
+	// calculation of the epsilon (ε) parameter for the Count-Min Sketch.
+	TotalFreq float64
+}
+
+type CMS struct {
+	data [][]uint32
+	hs   []hasher
+}
+
+// NewCMS creates a new CMS structure with the given width and depth.
+func NewCMS(w, h int) *CMS {
+	data := make([][]uint32, h)
+	hs := make([]hasher, h)
+	for i := 0; i < h; i++ {
+		hs[i] = NewHWHasher(uint32(w), i)
+		data[i] = make([]uint32, w)
+	}
+	return &CMS{
+		data: data,
+		hs:   hs,
+	}
+}
+
+// NewCMSWithErrorParams creates a new CMS structure based on the given config,
+// where CMS width = ⌈e/ε⌉ with ε = MaxErr/TotalFreq, and depth = ⌈log2(1/δ)⌉.
+func NewCMSWithErrorParams(cfg *CountMinSketchCfg) *CMS {
+	d := math.Ceil(math.Log2(1 / cfg.ErrorProbability))
+	w := math.Ceil(math.E / (cfg.MaxErr / cfg.TotalFreq))
+	return NewCMS(int(w), int(d))
+}
+
+// Insert inserts a new element into the CMS.
+func (c *CMS) Insert(element []byte) {
+	for i, h := range c.hs {
+		c.data[i][h.Hash(element)]++
+	}
+}
+
+// Clear resets the CMS state.
+func (c *CMS) Clear() {
+	for i := range c.hs {
+		for k := range c.data[i] {
+			c.data[i][k] = 0
+		}
+	}
+}
+
+// Count estimates the frequency of a given element.
+func (c *CMS) Count(element []byte) uint32 {
+	var m uint32 = math.MaxUint32
+	for i, h := range c.hs {
+		m = min(m, c.data[i][h.Hash(element)])
+	}
+	return m
+}
+
+// InsertWithCount inserts the element into the CMS and returns the element's
+// frequency estimation. This method is equivalent to sequential calls to
+// Insert(element) and Count(element), but performs half as many hash
+// calculations.
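To make the sizing in `NewCMSWithErrorParams` and the single-pass `InsertWithCount` concrete, here is a small standalone example (not part of the change). The key string is made up, and the program assumes it is compiled inside the tailsamplingprocessor module because `internal/cms` is an internal package.

```go
package main

import (
	"fmt"

	"github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms"
)

func main() {
	// README example values: error_probability=0.01, total_frequency=1000, max_error_value=1.
	// depth = ⌈log2(1/0.01)⌉   = 7 hash rows
	// width = ⌈e / (1/1000)⌉   = 2719 counters per row
	// i.e. 7*2719 = 19033 uint32 counters (roughly 76 kB) per sketch.
	sketch := cms.NewCMSWithErrorParams(&cms.CountMinSketchCfg{
		ErrorProbability: 0.01,
		TotalFreq:        1000,
		MaxErr:           1,
	})

	key := []byte("checkout-service|GET /cart") // made-up span key
	sketch.Insert(key)
	fmt.Println(sketch.Count(key))           // >= 1: a CMS never under-estimates
	fmt.Println(sketch.InsertWithCount(key)) // >= 2: insert and estimate with a single round of hashing
}
```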
+func (c *CMS) InsertWithCount(element []byte) uint32 { + var m uint32 = math.MaxUint32 + for i, h := range c.hs { + position := h.Hash(element) + c.data[i][position]++ + m = min(m, c.data[i][position]) + } + return m +} diff --git a/processor/tailsamplingprocessor/internal/cms/cms_test.go b/processor/tailsamplingprocessor/internal/cms/cms_test.go new file mode 100644 index 000000000000..a4bbaaa13ae2 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/cms_test.go @@ -0,0 +1,414 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms + +import ( + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCmsCountAfterInsertNoUnderEstimate(t *testing.T) { + maxValue := 1000 + + testCases := []struct { + ErrorProbability float64 + ErrorBound float64 + }{ + { + ErrorProbability: .01, + ErrorBound: 1., + }, + + { + ErrorProbability: .05, + ErrorBound: 1., + }, + + { + ErrorProbability: .10, + ErrorBound: 2., + }, + + { + ErrorProbability: .20, + ErrorBound: 2., + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf("error_prob_%.2f_error_bound_%.0f", c.ErrorProbability, c.ErrorBound) + t.Run(caseName, func(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: c.ErrorProbability, + TotalFreq: float64(countOfInsertions(maxValue)), + MaxErr: c.ErrorBound, + }) + + for i := 1; i < maxValue+1; i++ { + bytesData := []byte(strconv.Itoa(i)) + for k := 0; k < i; k++ { + cms.Insert(bytesData) + } + } + + for i := 1; i < maxValue+1; i++ { + bytesData := []byte(strconv.Itoa(i)) + cnt := int(cms.Count(bytesData)) + assert.GreaterOrEqual(t, cnt, i, "estimated cnt (%d) is less than actual (%d)", cnt, i) + } + }) + } +} + +func TestCmsCountAfterInsertErrorBound(t *testing.T) { + maxValue := 1000 + + testCases := []struct { + ErrorProbability float64 + ErrorBound float64 + }{ + { + ErrorProbability: .01, + ErrorBound: 1., + }, + + { + ErrorProbability: .05, + ErrorBound: 1., + }, + + { + ErrorProbability: .10, + ErrorBound: 2., + }, + + { + ErrorProbability: .20, + ErrorBound: 2., + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf("error_prob_%.2f_error_bound_%.0f", c.ErrorProbability, c.ErrorBound) + t.Run(caseName, func(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: c.ErrorProbability, + TotalFreq: float64(countOfInsertions(maxValue)), + MaxErr: c.ErrorBound, + }) + + for i := 1; i < maxValue+1; i++ { + bytesData := makeTestCMSKey(i) + for k := 0; k < i; k++ { + cms.Insert(bytesData) + } + } + + for i := 1; i < maxValue+1; i++ { + bytesData := makeTestCMSKey(i) + cnt := int(cms.Count(bytesData)) + errValue := cnt - i + assert.LessOrEqualf(t, errValue, int(c.ErrorBound), + "error(%d) is greater than defined(%d); i = %d", + errValue, int(c.ErrorBound), i) + } + }) + } +} + +func TestCmsCountAfterInsertErrorProbability(t *testing.T) { + maxValue := 1000 + + testCases := []struct { + ErrorProbability float64 + ErrorBound float64 + }{ + { + ErrorProbability: .01, + ErrorBound: 1., + }, + + { + ErrorProbability: .05, + ErrorBound: 1., + }, + + { + ErrorProbability: .10, + ErrorBound: 2., + }, + + { + ErrorProbability: .20, + ErrorBound: 2., + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf("error_prob_%.2f_error_bound_%.0f", c.ErrorProbability, c.ErrorBound) + t.Run(caseName, func(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: c.ErrorProbability, + TotalFreq: 
float64(countOfInsertions(maxValue)), + MaxErr: c.ErrorBound, + }) + + for i := 1; i < maxValue+1; i++ { + bytesData := makeTestCMSKey(i) + for k := 0; k < i; k++ { + cms.Insert(bytesData) + } + } + + overestimated := .0 + for i := 1; i < maxValue+1; i++ { + bytesData := makeTestCMSKey(i) + cnt := int(cms.Count(bytesData)) + if cnt != i { + overestimated++ + } + } + errProb := overestimated / float64(countOfInsertions(maxValue)) + assert.LessOrEqualf(t, errProb, c.ErrorProbability, + "%s: error(%.4f) is greater than defined(%.2f)", + caseName, errProb, c.ErrorProbability) + }) + } +} + +func TestCmsInsertWithCountNoUnderEstimate(t *testing.T) { + maxValue := 1000 + + testCases := []struct { + ErrorProbability float64 + ErrorBound float64 + }{ + { + ErrorProbability: .01, + ErrorBound: 1., + }, + + { + ErrorProbability: .05, + ErrorBound: 1., + }, + + { + ErrorProbability: .10, + ErrorBound: 2., + }, + + { + ErrorProbability: .20, + ErrorBound: 2., + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf("error_prob_%.2f_error_bound_%.0f", c.ErrorProbability, c.ErrorBound) + t.Run(caseName, func(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: c.ErrorProbability, + TotalFreq: float64(countOfInsertions(maxValue)), + MaxErr: c.ErrorBound, + }) + + for i := 1; i < maxValue+1; i++ { + bytesData := []byte(strconv.Itoa(i)) + for k := 0; k < i-1; k++ { + cms.InsertWithCount(bytesData) + } + cnt := int(cms.InsertWithCount(bytesData)) + assert.GreaterOrEqualf(t, cnt, i, "estimated cnt (%d) is less than actual (%d)", cnt, i) + } + }) + } +} + +func TestCmsInsertWithCountErrorBound(t *testing.T) { + maxValue := 1000 + + testCases := []struct { + ErrorProbability float64 + ErrorBound float64 + }{ + { + ErrorProbability: .01, + ErrorBound: 1., + }, + + { + ErrorProbability: .05, + ErrorBound: 1., + }, + + { + ErrorProbability: .10, + ErrorBound: 2., + }, + + { + ErrorProbability: .20, + ErrorBound: 2., + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf("error_prob_%.2f_error_bound_%.0f", c.ErrorProbability, c.ErrorBound) + t.Run(caseName, func(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: c.ErrorProbability, + TotalFreq: float64(countOfInsertions(maxValue)), + MaxErr: c.ErrorBound, + }) + + for i := 1; i < maxValue+1; i++ { + bytesData := makeTestCMSKey(i) + for k := 0; k < i-1; k++ { + cms.InsertWithCount(bytesData) + } + cnt := cms.InsertWithCount(bytesData) + errValue := int(cnt) - i + assert.LessOrEqualf(t, errValue, int(c.ErrorBound), + "error(%d) is greater than defined(%d); i = %d", + errValue, int(c.ErrorBound), i) + } + }) + } +} + +func TestCmsInsertWithCountErrorProbability(t *testing.T) { + maxValue := 1000 + + testCases := []struct { + ErrorProbability float64 + ErrorBound float64 + }{ + { + ErrorProbability: .01, + ErrorBound: 1., + }, + + { + ErrorProbability: .05, + ErrorBound: 1., + }, + + { + ErrorProbability: .10, + ErrorBound: 2., + }, + + { + ErrorProbability: .20, + ErrorBound: 2., + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf("error_prob_%.2f_error_bound_%.0f", c.ErrorProbability, c.ErrorBound) + t.Run(caseName, func(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: c.ErrorProbability, + TotalFreq: float64(countOfInsertions(maxValue)), + MaxErr: c.ErrorBound, + }) + + overEstimated := 0. 
+ for i := 1; i < maxValue+1; i++ { + bytesData := makeTestCMSKey(i) + + for k := 0; k < (i - 1); k++ { + cms.InsertWithCount(bytesData) + } + + cnt := cms.InsertWithCount(bytesData) + if cnt != uint32(i) { + overEstimated++ + } + } + + errProb := overEstimated / float64(countOfInsertions(maxValue)) + if overEstimated == 0 { + errProb = 0. + } + assert.LessOrEqualf(t, errProb, c.ErrorProbability, + "%s: error(%.4f) is greater than defined(%.2f)", + caseName, errProb, c.ErrorProbability) + }) + } +} + +func TestCmsEmpty(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: .99, + TotalFreq: 100, + MaxErr: 1, + }) + + assert.Zero(t, cms.Count([]byte{1})) + assert.Zero(t, cms.Count([]byte{2})) + assert.Zero(t, cms.Count([]byte{3})) +} + +func TestCmsNonExist(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: .99, + TotalFreq: 100, + MaxErr: 1, + }) + + cms.Insert([]byte{1}) + cms.Insert([]byte{2}) + cms.Insert([]byte{3}) + + assert.Zero(t, cms.Count([]byte{4})) + assert.Zero(t, cms.Count([]byte{5})) + assert.Zero(t, cms.Count([]byte{6})) +} + +func TestCmsInsertSimple(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: .99, + TotalFreq: 100, + MaxErr: 1, + }) + + assert.Equal(t, uint32(1), cms.InsertWithCount([]byte{1})) + assert.Equal(t, uint32(1), cms.Count([]byte{1})) + + assert.Equal(t, uint32(2), cms.InsertWithCount([]byte{1})) + assert.Equal(t, uint32(2), cms.Count([]byte{1})) + + assert.Equal(t, uint32(3), cms.InsertWithCount([]byte{1})) + assert.Equal(t, uint32(3), cms.Count([]byte{1})) +} + +func TestCmsClear(t *testing.T) { + cms := NewCMSWithErrorParams(&CountMinSketchCfg{ + ErrorProbability: 0.001, + TotalFreq: 100, + MaxErr: 1, + }) + + cms.Insert([]byte{1}) + assert.Equal(t, uint32(1), cms.Count([]byte{1})) + + cms.Insert([]byte{2}) + cms.Insert([]byte{2}) + assert.Equal(t, uint32(2), cms.Count([]byte{2})) + + cms.Insert([]byte{3}) + cms.Insert([]byte{3}) + cms.Insert([]byte{3}) + assert.Equal(t, uint32(3), cms.Count([]byte{3})) + + cms.Clear() + assert.Equal(t, uint32(0), cms.Count([]byte{1})) + assert.Equal(t, uint32(0), cms.Count([]byte{2})) + assert.Equal(t, uint32(0), cms.Count([]byte{3})) +} diff --git a/processor/tailsamplingprocessor/internal/cms/hash.go b/processor/tailsamplingprocessor/internal/cms/hash.go new file mode 100644 index 000000000000..f7d5d2d6261c --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/hash.go @@ -0,0 +1,69 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms // import "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" + +import ( + "github.com/cespare/xxhash/v2" +) + +// Random generated seeds for xxHash +var hashSeeds = []uint64{ + 11595015838869646110, + 6812085131706188016, + 1803417234240274060, + 9910387706731171320, + 6192372884267156726, + 9359297034374090798, + 10699310435244823924, + 2694511717751745834, + 381360183478099046, + 3813652365689143046, + 2745863731201646884, + 9338067737075496131, + 2310364619108435937, + 2033119307453415722, + 18154805337904800280, + 16153398036464115640, + 11576467370134829350, + 16139083825458188821, + 14613404529444667025, + 8229605496796251508, + 14043697971178212370, + 5104099633310233611, + 8840567979630932215, + 3619854489427682144, + 2922888160084146262, + 1066417268237148873, + 10391653214809458763, + 5008947111593455631, + 2544378244597710161, + 1282165131157204414, + 
15346189051374937777, + 8983218487838504684, +} + +type XXHasher struct { + // h xxHash implementation of hash.Hash64 + h *xxhash.Digest + // seedIdx seed serial index + seedIdx int + // maxVal max value that hasher could produce + maxVal uint32 +} + +func NewHWHasher(length uint32, idx int) *XXHasher { + h := xxhash.NewWithSeed(hashSeeds[idx]) + return &XXHasher{ + h: h, + maxVal: length, + seedIdx: idx, + } +} + +func (hw *XXHasher) Hash(data []byte) uint32 { + hw.h.ResetWithSeed(hashSeeds[hw.seedIdx]) + _, _ = hw.h.Write(data) + + return uint32(hw.h.Sum64() % uint64(hw.maxVal)) +} diff --git a/processor/tailsamplingprocessor/internal/cms/ring_buffer.go b/processor/tailsamplingprocessor/internal/cms/ring_buffer.go new file mode 100644 index 000000000000..8be537acda2d --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/ring_buffer.go @@ -0,0 +1,160 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms // import "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" + +import ( + "errors" +) + +var ( + errFullBuffer = errors.New("buffer is full") + errEmptyBuffer = errors.New("buffer is empty") + errIncorrectIndex = errors.New("element index is incorrect") + errEmptyData = errors.New("data is empty") +) + +// RingBufferQueue is the ring buffer data structure implementation +type RingBufferQueue[T any] struct { + capacity uint32 + head uint32 + tail uint32 + full bool + data []T +} + +func (r *RingBufferQueue[T]) Capacity() uint32 { + return r.capacity +} + +func (r *RingBufferQueue[T]) IsEmpty() bool { + return r.head == r.tail && !r.full +} + +func (r *RingBufferQueue[T]) IsFull() bool { + return r.full +} + +// Size returns the number of elements that are currently in the buffer. +func (r *RingBufferQueue[T]) Size() int { + if r.IsFull() { + return int(r.capacity) + } + + if r.IsEmpty() { + return 0 + } + + if r.tail > r.head { + return int(r.tail) - int(r.head) + } + + return int(r.capacity) - (int(r.head) - int(r.tail)) +} + +// Enqueue adds element in the buffer. +func (r *RingBufferQueue[T]) Enqueue(val T) error { + if r.IsFull() { + return errFullBuffer + } + + r.data[r.tail] = val + r.tail = (r.tail + 1) % r.capacity + r.full = r.head == r.tail + + return nil +} + +// TailMoveForward moves the tail pointer to the next position in the buffer. +// If the buffer is full an error is returned. +func (r *RingBufferQueue[T]) TailMoveForward() error { + if r.IsFull() { + return errFullBuffer + } + + r.tail = (r.tail + 1) % r.capacity + r.full = r.head == r.tail + + return nil +} + +// Visit iterates through the buffer, starting with the head element and ending +// with the tail element. For each of the iterable elements, the visit function +// is called. The iteration ends either when the tail is reached or when the +// visitor returns false. +func (r *RingBufferQueue[T]) Visit(visitor func(T) bool) { + start := int(r.head) + for i := 0; i < r.Size(); i++ { + if !visitor(r.data[(start+i)%int(r.capacity)]) { + return + } + } +} + +// At returns an element by its serial (head-first) index. If the buffer is +// empty or the index is out of range [0:len(buffer)-1] then an error is +// returned. +func (r *RingBufferQueue[T]) At(idx int) (T, error) { + if r.IsEmpty() || idx < 0 || idx >= r.Size() { + var t T + return t, errIncorrectIndex + } + + start := int(r.head) + return r.data[(start+idx)%int(r.capacity)], nil +} + +// Tail returns the last element from the ring buffer. 
If the buffer is empty +// the method returns the `false` flag. +func (r *RingBufferQueue[T]) Tail() (T, bool) { + if r.IsEmpty() { + var t T + return t, false + } + + if r.tail == 0 { + return r.data[r.capacity-1], true + } + return r.data[r.tail-1], true +} + +// Dequeue removes and returns the first (head) element from the buffer. If the +// buffer is empty an error is returned. +func (r *RingBufferQueue[T]) Dequeue() (T, error) { + if r.IsEmpty() { + var t T + return t, errEmptyBuffer + } + + element := r.data[r.head] + r.full = false + r.head = (r.head + 1) % r.capacity + + return element, nil +} + +func NewEmptyRingBufferQueue[T any](capacity uint32) (*RingBufferQueue[T], error) { + if capacity == 0 { + return nil, errors.New("empty capacity value passed") + } + + return &RingBufferQueue[T]{ + capacity: capacity, + head: 0, + tail: 0, + full: false, + data: make([]T, capacity), + }, nil +} + +func NewRingBufferQueue[T any](initData []T) (*RingBufferQueue[T], error) { + if len(initData) == 0 { + return nil, errEmptyData + } + + return &RingBufferQueue[T]{ + capacity: uint32(len(initData)), + full: false, + data: initData, + }, nil +} diff --git a/processor/tailsamplingprocessor/internal/cms/ring_buffer_test.go b/processor/tailsamplingprocessor/internal/cms/ring_buffer_test.go new file mode 100644 index 000000000000..312f37b3c825 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/ring_buffer_test.go @@ -0,0 +1,490 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRingbufferEmptyCapacity(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](0) + assert.Nil(t, rb) + assert.Error(t, err) +} + +func TestRingbufferEnqueueFull(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](1) + assert.NoError(t, err) + assert.NotNil(t, rb) + + err = rb.Enqueue(1) + assert.NoError(t, err) + assert.Equal(t, 1, rb.Size()) + assert.True(t, rb.IsFull()) + + err = rb.Enqueue(2) + assert.Error(t, err) + assert.ErrorIs(t, err, errFullBuffer) + assert.Equal(t, 1, rb.Size()) + assert.True(t, rb.IsFull()) +} + +func TestRingbufferEnqueueSimple(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](3) + assert.NoError(t, err) + + var ( + valTail int + val0 int + val1 int + val2 int + ) + + err = rb.Enqueue(3) + assert.NoError(t, err) + valTail, _ = rb.Tail() + val0, _ = rb.At(0) + assert.False(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 1, rb.Size()) + assert.Equal(t, uint32(3), rb.Capacity()) + assert.Equal(t, 3, valTail) + assert.Equal(t, 3, val0) + + err = rb.Enqueue(2) + assert.NoError(t, err) + + err = rb.Enqueue(1) + assert.NoError(t, err) + + valTail, _ = rb.Tail() + val0, _ = rb.At(0) + val1, _ = rb.At(1) + val2, _ = rb.At(2) + assert.False(t, rb.IsEmpty()) + assert.True(t, rb.IsFull()) + assert.Equal(t, 3, rb.Size()) + assert.Equal(t, uint32(3), rb.Capacity()) + assert.Equal(t, 1, valTail) + assert.Equal(t, 3, val0) + assert.Equal(t, 2, val1) + assert.Equal(t, 1, val2) +} + +func TestRingbufferDequeueSequential(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](3) + assert.NoError(t, err) + + var ( + valTail int + val0 int + val1 int + val int + ) + + err = rb.Enqueue(3) + assert.NoError(t, err) + + err = rb.Enqueue(2) + assert.NoError(t, err) + + err = rb.Enqueue(1) + assert.NoError(t, err) + + val, err = rb.Dequeue() + valTail, _ = rb.Tail() + val0, _ = rb.At(0) + val1, _ = rb.At(1) + + assert.NoError(t, err) + 
assert.False(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 3, val) + assert.Equal(t, 1, valTail) + assert.Equal(t, 2, val0) + assert.Equal(t, 1, val1) + assert.Equal(t, 2, rb.Size()) + + val, err = rb.Dequeue() + valTail, _ = rb.Tail() + val0, _ = rb.At(0) + + assert.NoError(t, err) + assert.False(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 2, val) + assert.Equal(t, 1, valTail) + assert.Equal(t, 1, val0) + assert.Equal(t, 1, rb.Size()) + + val, err = rb.Dequeue() + assert.NoError(t, err) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 1, val) + assert.Equal(t, 0, rb.Size()) +} + +func TestRingbufferDequeueSimple(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](3) + assert.NoError(t, err) + + var ( + valTail int + val0 int + val int + ) + + err = rb.Enqueue(3) + assert.NoError(t, err) + val, err = rb.Dequeue() + assert.NoError(t, err) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 3, val) + assert.Equal(t, 0, rb.Size()) + + err = rb.Enqueue(2) + assert.NoError(t, err) + val, err = rb.Dequeue() + assert.NoError(t, err) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 2, val) + assert.Equal(t, 0, rb.Size()) + + err = rb.Enqueue(1) + assert.NoError(t, err) + val, err = rb.Dequeue() + valTail, _ = rb.Tail() + val0, _ = rb.At(0) + assert.NoError(t, err) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 1, val) + assert.Equal(t, 0, valTail) + assert.Equal(t, 0, val0) + assert.Equal(t, 0, rb.Size()) +} + +func TestRingbufferDequeueEmpty(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](3) + assert.NoError(t, err) + + _, err = rb.Dequeue() + assert.ErrorIs(t, err, errEmptyBuffer) + + err = rb.Enqueue(1) + assert.NoError(t, err) + + _, err = rb.Dequeue() + assert.NoError(t, err) + + _, err = rb.Dequeue() + assert.ErrorIs(t, err, errEmptyBuffer) +} + +func TestRingbufferTailEmpty(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](3) + assert.NoError(t, err) + + val, ok := rb.Tail() + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.False(t, ok) + assert.Equal(t, 0, val) + + _ = rb.Enqueue(1) + _, _ = rb.Dequeue() + val, ok = rb.Tail() + assert.False(t, ok) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 0, val) +} + +func TestRingbufferAccessByIndexEmpty(t *testing.T) { + rb, err := NewEmptyRingBufferQueue[int](3) + assert.NoError(t, err) + + var ( + val0 int + val1 int + ) + + val0, err = rb.At(0) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.ErrorIs(t, err, errIncorrectIndex) + assert.Equal(t, 0, val0) + + val1, err = rb.At(1) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.ErrorIs(t, err, errIncorrectIndex) + assert.Equal(t, 0, val1) + + _ = rb.Enqueue(1) + _, _ = rb.Dequeue() + + val0, err = rb.At(0) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.ErrorIs(t, err, errIncorrectIndex) + assert.Equal(t, 0, val0) + + val1, err = rb.At(1) + assert.True(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.ErrorIs(t, err, errIncorrectIndex) + assert.Equal(t, 0, val1) +} + +func TestRingbufferAccessByIndexSimple(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + + val0, err0 := rb.At(0) + val1, err1 := rb.At(1) + val2, err2 := rb.At(2) + + assert.False(t, rb.IsEmpty()) + assert.True(t, rb.IsFull()) + assert.NoError(t, err0) 
+ assert.NoError(t, err1) + assert.NoError(t, err2) + assert.Equal(t, 1, val0) + assert.Equal(t, 2, val1) + assert.Equal(t, 3, val2) +} + +func TestRingbufferAccessByIndexWithShiftSimple(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + _, _ = rb.Dequeue() + _ = rb.Enqueue(4) + + val0, err0 := rb.At(0) + val1, err1 := rb.At(1) + val2, err2 := rb.At(2) + + assert.False(t, rb.IsEmpty()) + assert.True(t, rb.IsFull()) + assert.NoError(t, err0) + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.Equal(t, 2, val0) + assert.Equal(t, 3, val1) + assert.Equal(t, 4, val2) +} + +func TestRingbufferAccessByIndexOverflow(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + + val2, err2 := rb.At(2) + assert.Equal(t, 3, val2) + assert.NoError(t, err2) + + val3, err3 := rb.At(3) + assert.Equal(t, 0, val3) + assert.ErrorIs(t, err3, errIncorrectIndex) + + val6, err6 := rb.At(6) + assert.Equal(t, 0, val6) + assert.ErrorIs(t, err6, errIncorrectIndex) +} + +func TestRingbufferAccessByNegativeIndex(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + val, err := rb.At(-1) + assert.Equal(t, 0, val) + assert.ErrorIs(t, err, errIncorrectIndex) + + _ = rb.Enqueue(1) + + val, err = rb.At(-1) + assert.Equal(t, 0, val) + assert.ErrorIs(t, err, errIncorrectIndex) +} + +func TestRingbufferSizeSimple(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + assert.Equal(t, 0, rb.Size()) + + _ = rb.Enqueue(1) + assert.Equal(t, 1, rb.Size()) + + _ = rb.Enqueue(2) + assert.Equal(t, 2, rb.Size()) + + _ = rb.Enqueue(3) + assert.Equal(t, 3, rb.Size()) + + _, _ = rb.Dequeue() + assert.Equal(t, 2, rb.Size()) +} + +func TestRingbufferVisitAllSimple(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + + arr := make([]int, 0) + visitor := func(val int) bool { + arr = append(arr, val) + return true + } + + rb.Visit(visitor) + assert.Empty(t, arr) + + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + rb.Visit(visitor) + assert.Equal(t, []int{1, 2, 3}, arr) + arr = arr[:0] + + _, _ = rb.Dequeue() + rb.Visit(visitor) + assert.Equal(t, []int{2, 3}, arr) + arr = arr[:0] + + _ = rb.Enqueue(4) + rb.Visit(visitor) + assert.Equal(t, []int{2, 3, 4}, arr) +} + +func TestRingbufferPartialVisitAllSimple(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + + v1 := newPVisitor[int](1) + rb.Visit(v1.visit) + assert.Len(t, v1.visited, 1) + assert.Equal(t, []int{1}, v1.visited) + + v2 := newPVisitor[int](2) + rb.Visit(v2.visit) + assert.Len(t, v2.visited, 2) + assert.Equal(t, []int{1, 2}, v2.visited) +} + +func TestRingbufferTailMoveForwardFull(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + + assert.False(t, rb.IsEmpty()) + assert.True(t, rb.IsFull()) + + err := rb.TailMoveForward() + + assert.False(t, rb.IsEmpty()) + assert.True(t, rb.IsFull()) + assert.ErrorIs(t, err, errFullBuffer) + assert.Equal(t, 3, rb.Size()) +} + +func TestRingbufferTailMoveForwardFromEmptyToFull(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + var err error + + err = rb.TailMoveForward() + assert.False(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 1, rb.Size()) + assert.NoError(t, err) + + err = rb.TailMoveForward() + assert.False(t, rb.IsEmpty()) + assert.False(t, rb.IsFull()) + assert.Equal(t, 2, rb.Size()) + assert.NoError(t, 
err) + + err = rb.TailMoveForward() + assert.False(t, rb.IsEmpty()) + assert.True(t, rb.IsFull()) + assert.Equal(t, 3, rb.Size()) + assert.NoError(t, err) +} + +func TestRingbufferTailMoveForwardAfterDequeue(t *testing.T) { + rb, _ := NewEmptyRingBufferQueue[int](3) + var ( + err error + val int + ok bool + ) + + _ = rb.Enqueue(1) + _ = rb.Enqueue(2) + _ = rb.Enqueue(3) + + _, _ = rb.Dequeue() + err = rb.TailMoveForward() + assert.NoError(t, err) + assert.True(t, rb.IsFull()) + + val, ok = rb.Tail() + assert.True(t, ok) + assert.Equal(t, 1, val) +} + +func TestRingbufferNewRingBufferQueue(t *testing.T) { + var ( + val int + ok bool + ) + + rb, err := NewRingBufferQueue[int]([]int{3, 2, 1}) + + assert.NoError(t, err) + assert.Equal(t, 0, rb.Size()) + assert.Equal(t, uint32(3), rb.Capacity()) + assert.False(t, rb.IsFull()) + assert.True(t, rb.IsEmpty()) + + err = rb.TailMoveForward() + assert.NoError(t, err) + assert.Equal(t, 1, rb.Size()) + + val, ok = rb.Tail() + assert.Equal(t, 3, val) + assert.True(t, ok) + + err = rb.TailMoveForward() + assert.NoError(t, err) + assert.Equal(t, 2, rb.Size()) + + val, ok = rb.Tail() + assert.Equal(t, 2, val) + assert.True(t, ok) + + err = rb.TailMoveForward() + assert.NoError(t, err) + assert.Equal(t, 3, rb.Size()) + + val, ok = rb.Tail() + assert.Equal(t, 1, val) + assert.True(t, ok) +} + +func TestRingbufferNewRingBufferQueueEmptyInitData(t *testing.T) { + rb, err := NewRingBufferQueue[int]([]int{}) + + assert.Error(t, err) + assert.ErrorIs(t, err, errEmptyData) + assert.Nil(t, rb) +} diff --git a/processor/tailsamplingprocessor/internal/cms/sliding_cms.go b/processor/tailsamplingprocessor/internal/cms/sliding_cms.go new file mode 100644 index 000000000000..a5e65e29a8e0 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/sliding_cms.go @@ -0,0 +1,219 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms // import "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" + +import ( + "errors" + "time" +) + +var ( + errIncorrectBucketsCount = errors.New("incorrect buckets count") + errEmptyObservInterval = errors.New("empty observ interval") + errNonMultiplyObservInterval = errors.New("bucket interval must be a multiple of a second") +) + +// BucketsCfg configuration options for the buckets that make up the CMS +// sliding window. +type BucketsCfg struct { + // ObservationInterval time interval (duration) of a sliding window. + ObservationInterval time.Duration + // Buckets nuber of segments (buckets) in sliding window. + Buckets uint8 + // EstimationSoftLimit the soft limit for each estimation + EstimationSoftLimit uint32 +} + +// BucketInterval calculate bucket interval. If the interval is not multiple of +// a second, then an error is returned. +func (cfg BucketsCfg) BucketInterval() (time.Duration, error) { + bucketInterval := cfg.ObservationInterval / time.Duration(cfg.Buckets) + + if ((cfg.ObservationInterval % time.Duration(cfg.Buckets)) != time.Duration(0)) || + (bucketInterval < time.Second) || + (bucketInterval%time.Second != 0) { + return 0, errNonMultiplyObservInterval + } + + return bucketInterval, nil +} + +// Validate validates buckets cfg. +func (cfg BucketsCfg) Validate() error { + if cfg.ObservationInterval == 0 { + return errEmptyObservInterval + } + if cfg.Buckets == 0 { + return errIncorrectBucketsCount + } + + return nil +} + +// SlidingCms Sliding window of CMS segments (buckets). 
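The `BucketInterval` check above only accepts observation windows that split into whole-second buckets. A quick illustration (standalone sketch, not part of the change; it assumes it is built inside the tailsamplingprocessor module):

```go
package main

import (
	"fmt"
	"time"

	"github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms"
)

func main() {
	// README example: observation_interval=60m split into buckets_num=4 buckets.
	ok := cms.BucketsCfg{ObservationInterval: 60 * time.Minute, Buckets: 4}
	iv, err := ok.BucketInterval()
	fmt.Println(iv, err) // 15m0s <nil>

	// 10s split into 3 buckets is rejected: 10s/3 is not a whole number of seconds.
	bad := cms.BucketsCfg{ObservationInterval: 10 * time.Second, Buckets: 3}
	if _, err := bad.BucketInterval(); err != nil {
		fmt.Println(err) // bucket interval must be a multiple of a second
	}
}
```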
+type SlidingCms struct { + // cmsBuckets ring buffer containing CMS segments (buckets). Each segment + // contains data for a time period equal to bucketInterval. + cmsBuckets *RingBufferQueue[CountMinSketch] + // startPoint start time of the sliding time window. This time point is + // shifted every bucketInterval. + startPoint time.Time + // bucketInterval time interval of each bucket. + bucketInterval time.Duration + // softLimit + softLimit uint32 +} + +func (h *SlidingCms) shouldUpdateIntervals(tp time.Time) bool { + return h.startPoint.Add(time.Duration(h.cmsBuckets.Size()) * h.bucketInterval).Before(tp) +} + +// CurrentObservationIntervalStartTm returns the current start time of sliding +// window. +func (h *SlidingCms) CurrentObservationIntervalStartTm() time.Time { + return h.startPoint +} + +func (h *SlidingCms) updateBuckets(tp time.Time) { + if !h.shouldUpdateIntervals(tp) { + return + } + + if h.cmsBuckets.IsFull() { + firstBucket, _ := h.cmsBuckets.Dequeue() + firstBucket.Clear() + + h.startPoint = h.startPoint.Add(h.bucketInterval * (tp.Sub(h.startPoint) / h.bucketInterval)) + } + + _ = h.cmsBuckets.TailMoveForward() +} + +func (h *SlidingCms) insertWithTime(data []byte, tm time.Time) { + h.updateBuckets(tm) + t, _ := h.cmsBuckets.Tail() + t.Insert(data) +} + +func (h *SlidingCms) insertCountWithTime(data []byte, tm time.Time) uint32 { + h.updateBuckets(tm) + t, _ := h.cmsBuckets.Tail() + return t.InsertWithCount(data) +} + +// InsertWithCount inserts the element into sliding window and returns the +// frequency estimation. +func (h *SlidingCms) InsertWithCount(element []byte) uint32 { + return h.timedInsertWithCount(element, time.Now()) +} + +func (h *SlidingCms) timedInsertWithCount(element []byte, tm time.Time) uint32 { + val := h.insertCountWithTime(element, tm) + if !h.isUnderLimit(val) { + return val + } + + for i := 0; i < h.cmsBuckets.Size()-1; i++ { + cmsBucket, _ := h.cmsBuckets.At(i) + val += cmsBucket.Count(element) + if !h.isUnderLimit(val) { + return val + } + } + return val +} + +func (h *SlidingCms) isUnderLimit(val uint32) bool { + return h.softLimit == 0 || h.softLimit > val +} + +// Count estimates the frequency of occurrences of an element in a sliding window. +func (h *SlidingCms) Count(data []byte) uint32 { + var val uint32 + h.cmsBuckets.Visit(func(cms CountMinSketch) bool { + val += cms.Count(data) + return h.isUnderLimit(val) + }) + + return val +} + +// Insert inserts the element into the sliding window. +func (h *SlidingCms) Insert(data []byte) { + h.insertWithTime(data, time.Now()) +} + +// Buckets returns the ring buffer of sliding window buckets. +func (h *SlidingCms) Buckets() *RingBufferQueue[CountMinSketch] { + return h.cmsBuckets +} + +// Clear resets cms data for each bucket. 
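Putting the pieces of `SlidingCms` together: `InsertWithCount` adds the key to the newest bucket and only consults the older buckets while the running estimate stays below `EstimationSoftLimit`. An illustrative sequence (not part of the change; the constructor `NewSlidingCMSWithStartPoint` is defined just below, the key string is made up, and the soft-limit choice is an assumption):

```go
package main

import (
	"fmt"
	"time"

	"github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms"
)

func main() {
	window, err := cms.NewSlidingCMSWithStartPoint(
		cms.BucketsCfg{
			ObservationInterval: 60 * time.Minute,
			Buckets:             4,
			// Estimates are only needed up to the rarity threshold (2 in the
			// README example), so counting may stop once it is exceeded.
			EstimationSoftLimit: 3,
		},
		cms.CountMinSketchCfg{ErrorProbability: 0.01, TotalFreq: 1000, MaxErr: 1},
		time.Now(),
	)
	if err != nil {
		panic(err)
	}

	key := []byte("checkout-service|GET /cart") // made-up span key
	fmt.Println(window.InsertWithCount(key))    // 1: first sighting in the window
	fmt.Println(window.InsertWithCount(key))    // 2: still within rare_span_frequency: 2
	fmt.Println(window.InsertWithCount(key))    // 3: exceeds rare_span_frequency: 2
}
```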
+func (h *SlidingCms) Clear() { + h.cmsBuckets.Visit(func(c CountMinSketch) bool { + c.Clear() + return true + }) +} + +func NewSlidingCMSWithStartPoint(bucketsCfg BucketsCfg, cmsCfg CountMinSketchCfg, startTm time.Time) (*SlidingCms, error) { + err := bucketsCfg.Validate() + if err != nil { + return nil, err + } + + bucketInterval, err := bucketsCfg.BucketInterval() + if err != nil { + return nil, err + } + + emptyCmsBuckets := make([]CountMinSketch, 0) + for i := uint8(0); i < bucketsCfg.Buckets; i++ { + emptyCmsBuckets = append(emptyCmsBuckets, NewCMSWithErrorParams(&cmsCfg)) + } + + buckets, err := NewRingBufferQueue[CountMinSketch](emptyCmsBuckets) + _ = buckets.TailMoveForward() + + if err != nil { + return nil, err + } + + return &SlidingCms{ + cmsBuckets: buckets, + startPoint: startTm, + bucketInterval: bucketInterval, + softLimit: bucketsCfg.EstimationSoftLimit, + }, nil +} + +func NewSlidingPredefinedCMSWithStartPoint(bucketsCfg BucketsCfg, cmsBuckets []CountMinSketch, startTm time.Time) (*SlidingCms, error) { + err := bucketsCfg.Validate() + if err != nil { + return nil, err + } + + bucketInterval, err := bucketsCfg.BucketInterval() + if err != nil { + return nil, err + } + + if bucketsCfg.Buckets != uint8(len(cmsBuckets)) { + return nil, errIncorrectBucketsCount + } + + buckets, err := NewRingBufferQueue[CountMinSketch](cmsBuckets) + _ = buckets.TailMoveForward() + + if err != nil { + return nil, err + } + + return &SlidingCms{ + cmsBuckets: buckets, + startPoint: startTm, + bucketInterval: bucketInterval, + softLimit: bucketsCfg.EstimationSoftLimit, + }, nil +} diff --git a/processor/tailsamplingprocessor/internal/cms/sliding_cms_test.go b/processor/tailsamplingprocessor/internal/cms/sliding_cms_test.go new file mode 100644 index 000000000000..fa6a1e7eab13 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/sliding_cms_test.go @@ -0,0 +1,698 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms + +import ( + "fmt" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewSlidingCMSInputParamsEmptyBuckets(t *testing.T) { + sCms, err := NewSlidingCMSWithStartPoint( + BucketsCfg{ + ObservationInterval: time.Second * 2, + Buckets: 1, + }, + CountMinSketchCfg{ + ErrorProbability: .1, + TotalFreq: 1, + MaxErr: 1, + }, + time.Now(), + ) + + assert.NoError(t, err) + assert.NotNil(t, sCms) + + bCfg := BucketsCfg{ + ObservationInterval: time.Second * 2, + Buckets: 0, + } + cmsCfg := CountMinSketchCfg{ + ErrorProbability: .1, + TotalFreq: 1, + MaxErr: 1, + } + tm := time.Now() + + sCms, err = NewSlidingCMSWithStartPoint(bCfg, cmsCfg, tm) + assert.ErrorIs(t, err, errIncorrectBucketsCount) + assert.Nil(t, sCms) + + sCms, err = NewSlidingPredefinedCMSWithStartPoint(bCfg, make([]CountMinSketch, 1), tm) + assert.ErrorIs(t, err, errIncorrectBucketsCount) + assert.Nil(t, sCms) +} + +func TestNewSlidingCMSInputParamsNonMultiplyInterval(t *testing.T) { + cmsCfg := CountMinSketchCfg{ + ErrorProbability: .1, + TotalFreq: 1, + MaxErr: 1, + } + tm := time.Now() + + testCases := []struct { + ObservationInterval time.Duration + Buckets uint8 + }{ + { + ObservationInterval: time.Millisecond * 1, + Buckets: 1, + }, + { + ObservationInterval: 1001 * time.Millisecond, + Buckets: 1, + }, + { + ObservationInterval: 1001 * time.Millisecond, + Buckets: 2, + }, + { + ObservationInterval: 100 * time.Millisecond, + Buckets: 3, + }, + } + + for _, c := range testCases { + caseName := fmt.Sprintf( + 
"observation_interval_%s_buckets_%d", + c.ObservationInterval, + c.Buckets, + ) + + t.Run(caseName, func(t *testing.T) { + bCfg := BucketsCfg{ + ObservationInterval: c.ObservationInterval, + Buckets: c.Buckets, + } + + sCms, err := NewSlidingCMSWithStartPoint(bCfg, cmsCfg, tm) + assert.ErrorIs(t, err, errNonMultiplyObservInterval) + assert.Nil(t, sCms) + + cmsStubs := make([]CountMinSketch, c.Buckets) + sCms, err = NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsStubs, tm) + assert.ErrorIs(t, err, errNonMultiplyObservInterval) + assert.Nil(t, sCms) + }) + } +} + +func TestNewSlidingCMSInputParamsEmptyInterval(t *testing.T) { + bCfg := BucketsCfg{ + ObservationInterval: 0, + Buckets: 1, + } + + cmsCfg := CountMinSketchCfg{ + ErrorProbability: .1, + TotalFreq: 1, + MaxErr: 1, + } + + cmsStubs := []CountMinSketch{&StubCms{}} + + sCms, err := NewSlidingCMSWithStartPoint(bCfg, cmsCfg, time.Now()) + assert.ErrorIs(t, err, errEmptyObservInterval) + assert.Nil(t, sCms) + + sCms, err = NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsStubs, time.Now()) + assert.ErrorIs(t, err, errEmptyObservInterval) + assert.Nil(t, sCms) +} + +func TestNewPredefinedSlidingCMSIncorrectCmsNum(t *testing.T) { + bCfg := BucketsCfg{ + ObservationInterval: 2 * time.Second, + Buckets: 2, + } + + cmsStubs := []CountMinSketch{&StubCms{}} + + sCms, err := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsStubs, time.Now()) + assert.ErrorIs(t, err, errIncorrectBucketsCount) + assert.Nil(t, sCms) +} + +func TestNewSlidingCMSUpdateBucketsSimple(t *testing.T) { + tm := time.Unix(1, 0) + + cmsData := []CountMinSketch{ + NewEmptyCmsStub(1), + NewEmptyCmsStub(2), + NewEmptyCmsStub(3), + } + + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + } + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + buckets := sCms.Buckets() + tail, _ := buckets.Tail() + assert.Equal(t, 1, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 1, tail.(*StubCms).id) + + sCms.updateBuckets(time.Unix(1, 0)) + buckets = sCms.Buckets() + tail, _ = buckets.Tail() + assert.Equal(t, 1, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 1, tail.(*StubCms).id) + + sCms.updateBuckets(time.Unix(1, 1)) + buckets = sCms.Buckets() + tail, _ = buckets.Tail() + assert.Equal(t, 1, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 1, tail.(*StubCms).id) + + sCms.updateBuckets(time.Unix(2, 1)) + buckets = sCms.Buckets() + tail, _ = buckets.Tail() + assert.Equal(t, 2, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 2, tail.(*StubCms).id) + + sCms.updateBuckets(time.Unix(3, 1)) + buckets = sCms.Buckets() + tail, _ = buckets.Tail() + assert.Equal(t, 3, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 3, tail.(*StubCms).id) +} + +func TestNewSlidingCMSUpdateBucketsOverlap(t *testing.T) { + tm := time.Unix(1, 0) + + cmsData := []CountMinSketch{ + NewEmptyCmsStub(1), + NewEmptyCmsStub(2), + NewEmptyCmsStub(3), + } + + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + } + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + sCms.updateBuckets(time.Unix(2, 1)) + sCms.updateBuckets(time.Unix(3, 1)) + sCms.updateBuckets(time.Unix(4, 1)) + + buckets := sCms.Buckets() + tail, _ := buckets.Tail() + + assert.Equal(t, 3, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 1, tail.(*StubCms).id) + + 
assert.Equal(t, 1, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[2].(*StubCms).ClearCnt) + + assert.Equal(t, time.Unix(4, 0), sCms.CurrentObservationIntervalStartTm()) +} + +func TestNewSlidingCMSUpdateBucketsTimeGap(t *testing.T) { + tm := time.Unix(1, 0) + + cmsData := []CountMinSketch{ + NewEmptyCmsStub(1), + NewEmptyCmsStub(2), + NewEmptyCmsStub(3), + } + + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + } + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + sCms.updateBuckets(time.Unix(4, 1)) + buckets := sCms.Buckets() + tail, _ := buckets.Tail() + + assert.Equal(t, 2, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 2, tail.(*StubCms).id) + assert.Equal(t, 0, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[2].(*StubCms).ClearCnt) + assert.Equal(t, time.Unix(1, 0), sCms.CurrentObservationIntervalStartTm()) + + sCms.updateBuckets(time.Unix(5, 1)) + sCms.updateBuckets(time.Unix(10, 1)) + buckets = sCms.Buckets() + tail, _ = buckets.Tail() + + assert.Equal(t, 3, buckets.Size()) + assert.Equal(t, uint32(3), buckets.Capacity()) + assert.Equal(t, 1, tail.(*StubCms).id) + assert.Equal(t, 1, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[2].(*StubCms).ClearCnt) + assert.Equal(t, time.Unix(10, 0), sCms.CurrentObservationIntervalStartTm()) +} + +func TestNewSlidingCMSInsertOverlap(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 10, + } + cmsData := []CountMinSketch{ + NewEmptyCmsStub(1), + NewEmptyCmsStub(2), + NewEmptyCmsStub(3), + } + cmsKey := []byte(strconv.Itoa(1)) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + sCms.insertWithTime(cmsKey, time.Unix(1, 1)) + sCms.insertWithTime(cmsKey, time.Unix(2, 1)) + sCms.insertWithTime(cmsKey, time.Unix(3, 1)) + sCms.insertWithTime(cmsKey, time.Unix(4, 1)) + + assert.Equal(t, 2, cmsData[0].(*StubCms).InsertionsReq) + assert.Equal(t, 1, cmsData[1].(*StubCms).InsertionsReq) + assert.Equal(t, 1, cmsData[2].(*StubCms).InsertionsReq) +} + +func TestNewSlidingCMSCountOverlapWithEmptySoftLimit(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 0, + } + cmsData := []CountMinSketch{ + NewCmsStubWithCounts(1, CntMap{"1": 1}), + NewCmsStubWithCounts(2, CntMap{"1": 2}), + NewCmsStubWithCounts(3, CntMap{"1": 3}), + } + cmsKey := []byte(strconv.Itoa(1)) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + sCms.insertWithTime(cmsKey, time.Unix(1, 1)) + assert.Equal(t, uint32(1), sCms.Count(cmsKey)) + assert.Equal(t, 1, cmsData[0].(*StubCms).CountReq) + + sCms.insertWithTime(cmsKey, time.Unix(2, 1)) + assert.Equal(t, uint32(3), sCms.Count(cmsKey)) + assert.Equal(t, 2, cmsData[0].(*StubCms).CountReq) + assert.Equal(t, 1, cmsData[1].(*StubCms).CountReq) + + sCms.insertWithTime(cmsKey, time.Unix(3, 1)) + assert.Equal(t, uint32(6), sCms.Count(cmsKey)) + assert.Equal(t, 3, cmsData[0].(*StubCms).CountReq) + assert.Equal(t, 2, cmsData[1].(*StubCms).CountReq) + assert.Equal(t, 1, cmsData[2].(*StubCms).CountReq) + + sCms.insertWithTime(cmsKey, time.Unix(4, 1)) + assert.Equal(t, uint32(6), sCms.Count(cmsKey)) + assert.Equal(t, 4, cmsData[0].(*StubCms).CountReq) + 
assert.Equal(t, 3, cmsData[1].(*StubCms).CountReq) + assert.Equal(t, 2, cmsData[2].(*StubCms).CountReq) +} + +func TestNewSlidingCMSInsertWithCountOverlap(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 10, + } + cmsData := []CountMinSketch{ + NewCmsStubWithCounts(1, CntMap{"1": 1}), + NewCmsStubWithCounts(2, CntMap{"1": 2}), + NewCmsStubWithCounts(3, CntMap{"1": 3}), + } + cmsKey := []byte(strconv.Itoa(1)) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + assert.Equal(t, uint32(1), sCms.timedInsertWithCount(cmsKey, time.Unix(1, 1))) + assert.Equal(t, 0, cmsData[0].(*StubCms).CountReq) + assert.Equal(t, 1, cmsData[0].(*StubCms).InsertionsWithCnt) + + assert.Equal(t, uint32(3), sCms.timedInsertWithCount(cmsKey, time.Unix(2, 1))) + assert.Equal(t, 1, cmsData[0].(*StubCms).CountReq) + assert.Equal(t, 0, cmsData[1].(*StubCms).CountReq) + + assert.Equal(t, 1, cmsData[0].(*StubCms).InsertionsWithCnt) + assert.Equal(t, 1, cmsData[1].(*StubCms).InsertionsWithCnt) + + assert.Equal(t, uint32(6), sCms.timedInsertWithCount(cmsKey, time.Unix(3, 1))) + assert.Equal(t, 2, cmsData[0].(*StubCms).CountReq) + assert.Equal(t, 1, cmsData[1].(*StubCms).CountReq) + assert.Equal(t, 0, cmsData[2].(*StubCms).CountReq) + + assert.Equal(t, 1, cmsData[0].(*StubCms).InsertionsWithCnt) + assert.Equal(t, 1, cmsData[1].(*StubCms).InsertionsWithCnt) + assert.Equal(t, 1, cmsData[2].(*StubCms).InsertionsWithCnt) + + assert.Equal(t, uint32(6), sCms.timedInsertWithCount(cmsKey, time.Unix(4, 1))) + assert.Equal(t, 2, cmsData[0].(*StubCms).CountReq) + assert.Equal(t, 2, cmsData[1].(*StubCms).CountReq) + assert.Equal(t, 1, cmsData[2].(*StubCms).CountReq) + + assert.Equal(t, 2, cmsData[0].(*StubCms).InsertionsWithCnt) + assert.Equal(t, 1, cmsData[1].(*StubCms).InsertionsWithCnt) + assert.Equal(t, 1, cmsData[2].(*StubCms).InsertionsWithCnt) +} + +func TestNewSlidingCMSCountWithExactSoftLimit(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 2, + } + cmsData := []CountMinSketch{ + NewCmsStubWithCounts(1, CntMap{"1": 1}), + NewCmsStubWithCounts(2, CntMap{"1": 2}), + NewCmsStubWithCounts(3, CntMap{"1": 3}), + } + cmsKey := []byte(strconv.Itoa(1)) + + testCases := []struct { + Name string + SoftLimit uint32 + NumOfProbes int + }{ + { + Name: "exact_soft_limit_1", + SoftLimit: 1, + NumOfProbes: 1, + }, + + { + Name: "exact_soft_limit_2", + SoftLimit: 3, + NumOfProbes: 2, + }, + + { + Name: "exact_soft_limit_3", + SoftLimit: 6, + NumOfProbes: 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + bCfg.EstimationSoftLimit = tc.SoftLimit + testCms := CopyCmsStubSlice(cmsData) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, testCms, tm) + sCms.updateBuckets(time.Unix(2, 1)) + sCms.updateBuckets(time.Unix(3, 1)) + + assert.Equal(t, tc.SoftLimit, sCms.Count(cmsKey)) + + for i := 0; i < tc.NumOfProbes; i++ { + assert.Equal(t, 1, testCms[i].(*StubCms).CountReq) + } + for i := tc.NumOfProbes; i < len(cmsData); i++ { + assert.Equal(t, 0, testCms[i].(*StubCms).CountReq) + } + }) + } +} + +func TestNewSlidingCMSCountWithSoftLimitOverflow(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 2, + } + cmsData := []CountMinSketch{ + NewCmsStubWithCounts(1, CntMap{"1": 2}), + NewCmsStubWithCounts(2, CntMap{"1": 3}), + 
NewCmsStubWithCounts(3, CntMap{"1": 4}), + } + cmsKey := []byte(strconv.Itoa(1)) + + testCases := []struct { + Name string + SoftLimit uint32 + NumOfProbes int + ExpectedCount uint32 + }{ + { + Name: "soft_limit_1", + SoftLimit: 1, + NumOfProbes: 1, + ExpectedCount: 2, + }, + + { + Name: "soft_limit_2", + SoftLimit: 3, + NumOfProbes: 2, + ExpectedCount: 5, + }, + + { + Name: "soft_limit_3", + SoftLimit: 6, + NumOfProbes: 3, + ExpectedCount: 9, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + bCfg.EstimationSoftLimit = tc.SoftLimit + testCms := CopyCmsStubSlice(cmsData) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, testCms, tm) + sCms.updateBuckets(time.Unix(2, 1)) + sCms.updateBuckets(time.Unix(3, 1)) + + assert.Equal(t, tc.ExpectedCount, sCms.Count(cmsKey)) + + for i := 0; i < tc.NumOfProbes; i++ { + assert.Equal(t, 1, testCms[i].(*StubCms).CountReq) + } + for i := tc.NumOfProbes; i < len(cmsData); i++ { + assert.Equal(t, 0, testCms[i].(*StubCms).CountReq) + } + }) + } +} + +func TestNewSlidingCMSInsertCountWithSoftLimitOverflow(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 2, + } + cmsData := []CountMinSketch{ + NewCmsStubWithCounts(1, CntMap{"1": 2}), + NewCmsStubWithCounts(2, CntMap{"1": 3}), + NewCmsStubWithCounts(3, CntMap{"1": 4}), + } + cmsKey := []byte(strconv.Itoa(1)) + + testCases := []struct { + Name string + SoftLimit uint32 + NumOfProbes int + ExpectedCount uint32 + }{ + { + Name: "soft_limit_1", + SoftLimit: 3, + NumOfProbes: 0, + ExpectedCount: 4, + }, + + { + Name: "soft_limit_2", + SoftLimit: 5, + NumOfProbes: 1, + ExpectedCount: 6, + }, + + { + Name: "soft_limit_3", + SoftLimit: 8, + NumOfProbes: 2, + ExpectedCount: 9, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + bCfg.EstimationSoftLimit = tc.SoftLimit + testCms := CopyCmsStubSlice(cmsData) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, testCms, tm) + sCms.updateBuckets(time.Unix(2, 1)) + sCms.updateBuckets(time.Unix(3, 1)) + + assert.Equal(t, tc.ExpectedCount, sCms.timedInsertWithCount(cmsKey, time.Unix(3, 1))) + tail, _ := sCms.Buckets().Tail() + assert.Equal(t, 1, tail.(*StubCms).InsertionsWithCnt) + + for i := 0; i < tc.NumOfProbes; i++ { + assert.Equal(t, 1, testCms[i].(*StubCms).CountReq) + } + for i := tc.NumOfProbes; i < len(cmsData); i++ { + assert.Equal(t, 0, testCms[i].(*StubCms).CountReq) + } + }) + } +} + +func TestNewSlidingCMSInsertCountWithExactSoftLimit(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 2, + } + cmsData := []CountMinSketch{ + NewCmsStubWithCounts(1, CntMap{"1": 2}), + NewCmsStubWithCounts(2, CntMap{"1": 3}), + NewCmsStubWithCounts(3, CntMap{"1": 4}), + } + cmsKey := []byte(strconv.Itoa(1)) + + testCases := []struct { + Name string + SoftLimit uint32 + NumOfProbes int + ExpectedCount uint32 + }{ + { + Name: "soft_limit_1", + SoftLimit: 4, + NumOfProbes: 0, + ExpectedCount: 4, + }, + + { + Name: "soft_limit_2", + SoftLimit: 6, + NumOfProbes: 1, + ExpectedCount: 6, + }, + + { + Name: "soft_limit_3", + SoftLimit: 8, + NumOfProbes: 2, + ExpectedCount: 9, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + bCfg.EstimationSoftLimit = tc.SoftLimit + testCms := CopyCmsStubSlice(cmsData) + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, testCms, tm) + sCms.updateBuckets(time.Unix(2, 
1)) + sCms.updateBuckets(time.Unix(3, 1)) + + assert.Equal(t, tc.ExpectedCount, sCms.timedInsertWithCount(cmsKey, time.Unix(3, 1))) + tail, _ := sCms.Buckets().Tail() + assert.Equal(t, 1, tail.(*StubCms).InsertionsWithCnt) + + for i := 0; i < tc.NumOfProbes; i++ { + assert.Equal(t, 1, testCms[i].(*StubCms).CountReq) + } + for i := tc.NumOfProbes; i < len(cmsData); i++ { + assert.Equal(t, 0, testCms[i].(*StubCms).CountReq) + } + }) + } +} + +func TestNewSlidingCMSClearSimple(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 2, + } + cmsData := []CountMinSketch{ + NewEmptyCmsStub(1), + NewEmptyCmsStub(2), + NewEmptyCmsStub(3), + } + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + assert.Zero(t, cmsData[0].(*StubCms).ClearCnt) + assert.Zero(t, cmsData[1].(*StubCms).ClearCnt) + assert.Zero(t, cmsData[2].(*StubCms).ClearCnt) + + sCms.Clear() + assert.Equal(t, 1, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[2].(*StubCms).ClearCnt) + + sCms.updateBuckets(time.Unix(2, 1)) + sCms.Clear() + assert.Equal(t, 2, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 1, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 0, cmsData[2].(*StubCms).ClearCnt) +} + +func TestNewSlidingCMSClearOverflow(t *testing.T) { + tm := time.Unix(1, 0) + bCfg := BucketsCfg{ + ObservationInterval: 3 * time.Second, + Buckets: 3, + EstimationSoftLimit: 2, + } + cmsData := []CountMinSketch{ + NewEmptyCmsStub(1), + NewEmptyCmsStub(2), + NewEmptyCmsStub(3), + } + + sCms, _ := NewSlidingPredefinedCMSWithStartPoint(bCfg, cmsData, tm) + + assert.Zero(t, cmsData[0].(*StubCms).ClearCnt) + assert.Zero(t, cmsData[1].(*StubCms).ClearCnt) + assert.Zero(t, cmsData[2].(*StubCms).ClearCnt) + + sCms.updateBuckets(time.Unix(2, 1)) + sCms.updateBuckets(time.Unix(3, 1)) + sCms.Clear() + assert.Equal(t, 1, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 1, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 1, cmsData[2].(*StubCms).ClearCnt) + + sCms.updateBuckets(time.Unix(4, 1)) + assert.Equal(t, 2, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 1, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 1, cmsData[2].(*StubCms).ClearCnt) + + sCms.Clear() + assert.Equal(t, 3, cmsData[0].(*StubCms).ClearCnt) + assert.Equal(t, 2, cmsData[1].(*StubCms).ClearCnt) + assert.Equal(t, 2, cmsData[2].(*StubCms).ClearCnt) +} diff --git a/processor/tailsamplingprocessor/internal/cms/test_helper.go b/processor/tailsamplingprocessor/internal/cms/test_helper.go new file mode 100644 index 000000000000..7bb4d1782903 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/cms/test_helper.go @@ -0,0 +1,90 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package cms // import "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" + +import "fmt" + +func countOfInsertions(n int) int { + return (n / 2) * (1 + n) +} + +func makeTestCMSKey(n int) []byte { + return []byte(fmt.Sprintf("%d", n)) +} + +type partialVisitor[T any] struct { + currIdx int + maxNumOfEl int + visited []T +} + +func newPVisitor[T any](maxNumOfEl int) *partialVisitor[T] { + return &partialVisitor[T]{ + maxNumOfEl: maxNumOfEl, + visited: make([]T, 0), + } +} + +func (p *partialVisitor[T]) visit(val T) bool { + if p.currIdx >= p.maxNumOfEl { + return false + } + p.visited = append(p.visited, val) + p.currIdx++ + return true +} + +type 
CntMap map[string]uint32 + +type StubCms struct { + InsertionsReq int + InsertionsWithCnt int + InsertionsWithCntKeys []string + CountReq int + ClearCnt int + countResponses CntMap + id int +} + +func CopyCmsStubSlice(src []CountMinSketch) []CountMinSketch { + dst := make([]CountMinSketch, 0, len(src)) + for _, cs := range src { + s := *(cs.(*StubCms)) + dst = append(dst, &s) + } + return dst +} + +func NewEmptyCmsStub(id int) *StubCms { + return &StubCms{ + id: id, + InsertionsWithCntKeys: make([]string, 0), + } +} + +func NewCmsStubWithCounts(id int, counts CntMap) *StubCms { + return &StubCms{ + countResponses: counts, + id: id, + } +} + +func (s *StubCms) InsertWithCount(element []byte) uint32 { + s.InsertionsWithCnt++ + s.InsertionsWithCntKeys = append(s.InsertionsWithCntKeys, string(element)) + return s.countResponses[string(element)] +} + +func (s *StubCms) Count(element []byte) uint32 { + s.CountReq++ + return s.countResponses[string(element)] +} + +func (s *StubCms) Insert([]byte) { + s.InsertionsReq++ +} + +func (s *StubCms) Clear() { + s.ClearCnt++ +} diff --git a/processor/tailsamplingprocessor/internal/sampling/rare_spans.go b/processor/tailsamplingprocessor/internal/sampling/rare_spans.go new file mode 100644 index 000000000000..ceb2741d3a56 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/sampling/rare_spans.go @@ -0,0 +1,177 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package sampling // import "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/sampling" + +import ( + "context" + "time" + + "go.opentelemetry.io/collector/component" + "go.opentelemetry.io/collector/pdata/pcommon" + semconv "go.opentelemetry.io/collector/semconv/v1.6.1" + "go.uber.org/zap" + + "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" +) + +const ( + spansPerSecondSampledDefaultLimit = 1000 + spansPerSecondProcessedDefaultLimit = 100000 + spanUniqIDBufferSize = 20 * 1024 +) + +type RareSpansSampler struct { + // cntMinSketch implementation of the count-min sketch used to estimate the + // frequency of spans. + cntMinSketch cms.CountMinSketch + // maxSpanFreq span frequency value at or below which a span is sampled. + maxSpanFreq uint32 + // spsSampledLimit maximum number of sampled spans per second. + spsSampledLimit int64 + // spsSampled the number of spans already sampled in the current second. + spsSampled int64 + // spsProcessingLimit maximum number of spans that can be processed per + // second. + spsProcessingLimit int64 + // spsPrecessed the number of spans already processed in the current + // second. + spsPrecessed int64 + // currentSecond the current second as a Unix timestamp, used to track how + // many spans have already been sampled and processed. + currentSecond int64 + // idBuff reusable buffer for building the span CMS key; avoids unnecessary + // copying and allocations from string concatenation. + idBuff []byte + + tmProvider TimeProvider + + logger *zap.Logger +} + +// ShouldBeSampled reports whether a span should be sampled, based on its +// service name and span (operation) name. 
+func (r *RareSpansSampler) ShouldBeSampled(svcName, operationName string) bool { + r.idBuff = r.idBuff[:len(svcName)] + copy(r.idBuff, svcName) + copy(r.idBuff[len(svcName):len(svcName)+1], []byte{':'}) + copy(r.idBuff[len(svcName)+1:len(svcName)+1+len(operationName)], operationName) + r.idBuff = r.idBuff[:len(svcName)+1+len(operationName)] + + cnt := r.cntMinSketch.InsertWithCount(r.idBuff) + return cnt <= r.maxSpanFreq +} + +// Evaluate looks at the trace data and returns a corresponding SamplingDecision. +func (r *RareSpansSampler) Evaluate(_ context.Context, _ pcommon.TraceID, trace *TraceData) (Decision, error) { + var ( + shouldBeSampled bool + decision = NotSampled + ) + + currentSecond := r.tmProvider.getCurSecond() + if r.currentSecond != currentSecond { + r.currentSecond = currentSecond + r.spsSampled = 0 + r.spsPrecessed = 0 + } + + sps := trace.SpanCount.Load() + possibleSpsSampled := r.spsSampled + sps + possibleProcessed := r.spsPrecessed + sps + if possibleSpsSampled > r.spsSampledLimit || possibleProcessed > r.spsProcessingLimit { + return decision, nil + } + + trace.Lock() + defer trace.Unlock() + + for i := 0; i < trace.ReceivedBatches.ResourceSpans().Len(); i++ { + rs := trace.ReceivedBatches.ResourceSpans().At(i) + svcName, _ := rs.Resource().Attributes().Get(semconv.AttributeServiceName) + for j := 0; j < rs.ScopeSpans().Len(); j++ { + rss := rs.ScopeSpans().At(j).Spans() + for k := 0; k < rss.Len(); k++ { + keyLen := len(svcName.Str()) + 1 + len(rss.At(k).Name()) + if keyLen > spanUniqIDBufferSize { + r.logger.Error("too long span key", zap.Int("key_len", keyLen)) + continue + } + operationName := rss.At(k).Name() + if r.ShouldBeSampled(svcName.Str(), operationName) { + shouldBeSampled = true + } + } + } + } + + sps = trace.SpanCount.Load() + if shouldBeSampled { + decision = Sampled + r.spsSampled += sps + } + + r.spsPrecessed += sps + + return decision, nil +} + +// NewRareSpansSamplerWithCms creates a policy evaluator that samples traces +// based on spans frequency. CMS (Count-min sketch) is used to estimate the +// frequency of occurrence of a `span key`, where the `span key` consists of the +// span service name and the span name (operation name). +func NewRareSpansSamplerWithCms( + rareSpansFreq uint32, + spsSampledLimit int64, + spsProcessedLimit int64, + tmProvider TimeProvider, + cms cms.CountMinSketch, + settings component.TelemetrySettings, +) *RareSpansSampler { + return &RareSpansSampler{ + cntMinSketch: cms, + maxSpanFreq: rareSpansFreq, + spsSampledLimit: spsSampledLimit, + spsProcessingLimit: spsProcessedLimit, + currentSecond: tmProvider.getCurSecond(), + idBuff: make([]byte, 0, spanUniqIDBufferSize), + tmProvider: tmProvider, + logger: settings.Logger, + } +} + +// NewRareSpansSampler creates a policy evaluator that samples traces based +// on spans frequency. Unlike NewRareSpansSamplerWithCms, this function takes +// explicit CMS parameters as input. 
+func NewRareSpansSampler( + cmsCfg cms.CountMinSketchCfg, + bucketsCfg cms.BucketsCfg, + spanFreq uint32, + spsSampledLimit int64, + spsProcessedLimit int64, + tmProvider TimeProvider, + settings component.TelemetrySettings, +) (*RareSpansSampler, error) { + if spsSampledLimit == 0 { + spsSampledLimit = spansPerSecondSampledDefaultLimit + } + + if spsProcessedLimit == 0 { + spsProcessedLimit = spansPerSecondProcessedDefaultLimit + } + + slidingCms, err := cms.NewSlidingCMSWithStartPoint(bucketsCfg, cmsCfg, time.Now()) + if err != nil { + return nil, err + } + + return NewRareSpansSamplerWithCms( + spanFreq, + spsSampledLimit, + spsProcessedLimit, + tmProvider, + slidingCms, + settings, + ), nil +} diff --git a/processor/tailsamplingprocessor/internal/sampling/rare_spans_test.go b/processor/tailsamplingprocessor/internal/sampling/rare_spans_test.go new file mode 100644 index 000000000000..cb007e79f5e7 --- /dev/null +++ b/processor/tailsamplingprocessor/internal/sampling/rare_spans_test.go @@ -0,0 +1,478 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package sampling + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/collector/component/componenttest" + "go.opentelemetry.io/collector/pdata/pcommon" + "go.opentelemetry.io/collector/pdata/ptrace" + semconv "go.opentelemetry.io/collector/semconv/v1.9.0" + + "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" +) + +type mockTimeProvider struct { + seconds []int64 + nReq int +} + +func (f *mockTimeProvider) getCurSecond() int64 { + v := f.seconds[f.nReq%len(f.seconds)] + f.nReq++ + return v +} + +func newMockTimer(sec ...int64) *mockTimeProvider { + return &mockTimeProvider{ + seconds: append([]int64{}, sec...), + nReq: 0, + } +} + +func TestRareSpansSamplerSimple(t *testing.T) { + serviceName := "test_svc" + spanName := "test_span" + key := serviceName + ":" + spanName + spanCnt := int64(1) + + tmProvider := newMockTimer(0) + traceMock := newMockTrace(serviceName, []string{spanName}, spanCnt) + cmsStub := cms.NewCmsStubWithCounts(1, cms.CntMap{key: 1}) + + sampler := NewRareSpansSamplerWithCms( + 1, + 1, + 1, + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + + assert.NoError(t, err) + assert.Equal(t, Sampled, des) + assert.Equal(t, 1, cmsStub.InsertionsWithCnt) + assert.Equal(t, 0, cmsStub.CountReq) + assert.Equal(t, 0, cmsStub.InsertionsReq) + assert.Equal(t, 0, cmsStub.ClearCnt) + assert.Len(t, cmsStub.InsertionsWithCntKeys, 1) + assert.Equal(t, key, cmsStub.InsertionsWithCntKeys[0]) +} + +func TestRareSpansSamplerSampleOneSpanInTrace(t *testing.T) { + serviceName := "test_svc" + spanName1 := "test_span1" + spanName2 := "test_span2" + key1 := serviceName + ":" + spanName1 + key2 := serviceName + ":" + spanName2 + spanCnt := int64(1) + + tmProvider := newMockTimer(0) + traceMock := newMockTrace(serviceName, []string{spanName1, spanName2}, spanCnt) + cmsStub := cms.NewCmsStubWithCounts(1, cms.CntMap{key1: 2, key2: 1}) + + sampler := NewRareSpansSamplerWithCms( + 1, + 1, + 1, + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + + assert.NoError(t, err) + assert.Equal(t, Sampled, des) + assert.Equal(t, 2, cmsStub.InsertionsWithCnt) + assert.Equal(t, 0, cmsStub.CountReq) + 
assert.Equal(t, 0, cmsStub.InsertionsReq) + assert.Equal(t, 0, cmsStub.ClearCnt) + assert.Len(t, cmsStub.InsertionsWithCntKeys, 2) + assert.Equal(t, key1, cmsStub.InsertionsWithCntKeys[0]) + assert.Equal(t, key2, cmsStub.InsertionsWithCntKeys[1]) +} + +func TestRareSpansSamplerFreqLimit(t *testing.T) { + serviceName := "test_svc" + spanName := "test_span" + key := serviceName + ":" + spanName + spanCnt := int64(1) + + tmProvider := newMockTimer(0) + traceMock := newMockTrace(serviceName, []string{spanName}, spanCnt) + + testCases := []struct { + caseName string + cmsReturnValue uint32 + decision Decision + cmsFreqLimit uint32 + }{ + { + caseName: "below_limit", + cmsReturnValue: 1, + decision: Sampled, + cmsFreqLimit: 2, + }, + { + caseName: "equal_to_limit", + cmsReturnValue: 2, + decision: Sampled, + cmsFreqLimit: 2, + }, + + { + caseName: "above_limit", + cmsReturnValue: 3, + decision: NotSampled, + cmsFreqLimit: 2, + }, + } + + for _, tCase := range testCases { + t.Run(tCase.caseName, func(t *testing.T) { + cmsStub := cms.NewCmsStubWithCounts(1, cms.CntMap{key: tCase.cmsReturnValue}) + sampler := NewRareSpansSamplerWithCms( + tCase.cmsFreqLimit, + int64(tCase.cmsFreqLimit+1), + int64(tCase.cmsFreqLimit+1), + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + assert.NoError(t, err) + assert.Equal(t, tCase.decision, des) + }) + } +} + +func TestRareSpansSamplerKeyLenLimit(t *testing.T) { + serviceName := "test_svc" + spanCnt := int64(1) + tmProvider := newMockTimer(0) + + testCases := []struct { + caseName string + decision Decision + cmsFreqLimit uint32 + evalErr error + spanNames []string + cmsCounts int + }{ + { + caseName: "below_limit", + decision: Sampled, + cmsFreqLimit: 1, + spanNames: []string{string(make([]byte, spanUniqIDBufferSize-len(serviceName)-2))}, + cmsCounts: 1, + }, + { + caseName: "equal_to_limit", + decision: Sampled, + cmsFreqLimit: 1, + spanNames: []string{string(make([]byte, spanUniqIDBufferSize-len(serviceName)-1))}, + cmsCounts: 1, + }, + { + caseName: "above_limit", + decision: NotSampled, + cmsFreqLimit: 1, + spanNames: []string{string(make([]byte, spanUniqIDBufferSize-len(serviceName)))}, + cmsCounts: 0, + }, + + { + caseName: "one_span_above_limit", + decision: Sampled, + cmsFreqLimit: 1, + spanNames: []string{ + string(make([]byte, spanUniqIDBufferSize-len(serviceName))), + string(make([]byte, spanUniqIDBufferSize-len(serviceName)-1)), + }, + cmsCounts: 1, + }, + } + + for _, tCase := range testCases { + t.Run(tCase.caseName, func(t *testing.T) { + cmsStub := cms.NewEmptyCmsStub(1) + sampler := NewRareSpansSamplerWithCms( + tCase.cmsFreqLimit, + int64(tCase.cmsFreqLimit+1), + int64(tCase.cmsFreqLimit+1), + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + traceMock := newMockTrace(serviceName, tCase.spanNames, spanCnt) + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + assert.NoError(t, err) + assert.Equal(t, tCase.decision, des) + assert.Equal(t, tCase.cmsCounts, cmsStub.InsertionsWithCnt) + }) + } +} + +func TestRareSpansSamplerProcessedLimitSameSecond(t *testing.T) { + serviceName := "test_svc" + spanName := "test_span" + + tmProvider := newMockTimer(0) + + testCases := []struct { + caseName string + decision Decision + processingLimit int64 + spansInTrace int64 + cmsFreqLimit uint32 + cmsProbes int + }{ + { + caseName: "below_limit", + decision: Sampled, + processingLimit: 2, + spansInTrace: 
1, + cmsFreqLimit: 1, + cmsProbes: 1, + }, + { + caseName: "equal_to_limit", + decision: Sampled, + processingLimit: 2, + spansInTrace: 2, + cmsFreqLimit: 1, + cmsProbes: 1, + }, + { + caseName: "above_to_limit", + decision: NotSampled, + processingLimit: 2, + spansInTrace: 3, + cmsFreqLimit: 1, + cmsProbes: 0, + }, + } + + for _, tCase := range testCases { + t.Run(tCase.caseName, func(t *testing.T) { + traceMock := newMockTrace(serviceName, []string{spanName}, tCase.spansInTrace) + cmsStub := cms.NewEmptyCmsStub(1) + sampler := NewRareSpansSamplerWithCms( + tCase.cmsFreqLimit, + tCase.processingLimit+1, + tCase.processingLimit, + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + assert.NoError(t, err) + assert.Equal(t, tCase.decision, des) + assert.Equal(t, tCase.cmsProbes, cmsStub.InsertionsWithCnt) + }) + } +} + +func TestRareSpansSamplerSampledLimitSameSecond(t *testing.T) { + serviceName := "test_svc" + spanName := "test_span" + + tmProvider := newMockTimer(0) + + testCases := []struct { + caseName string + decision Decision + sampledLimit int64 + spansInTrace int64 + cmsFreqLimit uint32 + cmsProbes int + }{ + { + caseName: "below_limit", + decision: Sampled, + sampledLimit: 2, + spansInTrace: 1, + cmsFreqLimit: 1, + cmsProbes: 1, + }, + { + caseName: "equal_to_limit", + decision: Sampled, + sampledLimit: 2, + spansInTrace: 2, + cmsFreqLimit: 1, + cmsProbes: 1, + }, + { + caseName: "above_to_limit", + decision: NotSampled, + sampledLimit: 2, + spansInTrace: 3, + cmsFreqLimit: 1, + cmsProbes: 0, + }, + } + + for _, tCase := range testCases { + t.Run(tCase.caseName, func(t *testing.T) { + traceMock := newMockTrace(serviceName, []string{spanName}, tCase.spansInTrace) + cmsStub := cms.NewEmptyCmsStub(1) + sampler := NewRareSpansSamplerWithCms( + tCase.cmsFreqLimit, + tCase.sampledLimit, + tCase.sampledLimit+1, + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + assert.NoError(t, err) + assert.Equal(t, tCase.decision, des) + assert.Equal(t, tCase.cmsProbes, cmsStub.InsertionsWithCnt) + }) + } +} + +func TestRareSpansSamplerSampledLimitDifferentSeconds(t *testing.T) { + serviceName := "test_svc" + spanName := "test_span" + + tmProvider := newMockTimer(0, 1) + + testCases := []struct { + caseName string + decisions []Decision + sampledLimit int64 + spansInTrace []int64 + cmsFreqLimit uint32 + cmsProbes int + }{ + { + caseName: "below_limit", + decisions: []Decision{Sampled, Sampled}, + sampledLimit: 1, + spansInTrace: []int64{1, 1}, + cmsFreqLimit: 1, + cmsProbes: 2, + }, + { + caseName: "above_limit_below_limit", + decisions: []Decision{NotSampled, Sampled}, + sampledLimit: 1, + spansInTrace: []int64{3, 1}, + cmsFreqLimit: 1, + cmsProbes: 1, + }, + } + + for _, tCase := range testCases { + t.Run(tCase.caseName, func(t *testing.T) { + cmsStub := cms.NewEmptyCmsStub(1) + sampler := NewRareSpansSamplerWithCms( + tCase.cmsFreqLimit, + tCase.sampledLimit, + tCase.sampledLimit+1, + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + for i := 0; i < len(tCase.spansInTrace); i++ { + traceMock := newMockTrace(serviceName, []string{spanName}, tCase.spansInTrace[i]) + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + assert.NoError(t, err) + assert.Equal(t, tCase.decisions[i], des) + } + assert.Equal(t, tCase.cmsProbes, 
cmsStub.InsertionsWithCnt) + }) + } +} + +func TestRareSpansSamplerProcessedLimitDifferentSeconds(t *testing.T) { + serviceName := "test_svc" + spanName := "test_span" + + tmProvider := newMockTimer(0, 1) + + testCases := []struct { + caseName string + decisions []Decision + processedLimit int64 + spansInTrace []int64 + cmsFreqLimit uint32 + cmsProbes int + }{ + { + caseName: "below_limit", + decisions: []Decision{Sampled, Sampled}, + processedLimit: 1, + spansInTrace: []int64{1, 1}, + cmsFreqLimit: 1, + cmsProbes: 2, + }, + { + caseName: "above_limit_below_limit", + decisions: []Decision{NotSampled, Sampled}, + processedLimit: 1, + spansInTrace: []int64{3, 1}, + cmsFreqLimit: 1, + cmsProbes: 1, + }, + } + + for _, tCase := range testCases { + t.Run(tCase.caseName, func(t *testing.T) { + cmsStub := cms.NewEmptyCmsStub(1) + sampler := NewRareSpansSamplerWithCms( + tCase.cmsFreqLimit, + tCase.processedLimit+1, + tCase.processedLimit, + tmProvider, + cmsStub, + componenttest.NewNopTelemetrySettings(), + ) + + for i := 0; i < len(tCase.spansInTrace); i++ { + traceMock := newMockTrace(serviceName, []string{spanName}, tCase.spansInTrace[i]) + des, err := sampler.Evaluate(context.Background(), pcommon.TraceID{}, traceMock) + assert.NoError(t, err) + assert.Equal(t, tCase.decisions[i], des) + } + assert.Equal(t, tCase.cmsProbes, cmsStub.InsertionsWithCnt) + }) + } +} + +func newMockTrace(svcName string, spansNames []string, spansCnt int64) *TraceData { + traces := ptrace.NewTraces() + rs := traces.ResourceSpans().AppendEmpty() + rs.Resource().Attributes().PutStr(semconv.AttributeServiceName, svcName) + ils := rs.ScopeSpans().AppendEmpty() + + for i, sp := range spansNames { + span := ils.Spans().AppendEmpty() + span.SetName(sp) + span.SetTraceID([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) + span.SetSpanID([8]byte{byte(i), 2, 3, 4, 5, 6, 7, 8}) + } + + traceSpanCount := &atomic.Int64{} + traceSpanCount.Store(spansCnt) + + return &TraceData{ + ReceivedBatches: traces, + SpanCount: traceSpanCount, + } +} diff --git a/processor/tailsamplingprocessor/processor.go b/processor/tailsamplingprocessor/processor.go index 4515290198ac..a11e6f586b33 100644 --- a/processor/tailsamplingprocessor/processor.go +++ b/processor/tailsamplingprocessor/processor.go @@ -23,6 +23,7 @@ import ( "github.com/open-telemetry/opentelemetry-collector-contrib/internal/coreinternal/timeutils" "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cache" + "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/cms" "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/idbatcher" "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/metadata" "github.com/open-telemetry/opentelemetry-collector-contrib/processor/tailsamplingprocessor/internal/sampling" @@ -235,6 +236,25 @@ func getSharedPolicyEvaluator(settings component.TelemetrySettings, cfg *sharedP case OTTLCondition: ottlfCfg := cfg.OTTLConditionCfg return sampling.NewOTTLConditionFilter(settings, ottlfCfg.SpanConditions, ottlfCfg.SpanEventConditions, ottlfCfg.ErrorMode) + case RareSpans: + rsCfg := cfg.RareSpansCfg + return sampling.NewRareSpansSampler( + cms.CountMinSketchCfg{ + ErrorProbability: rsCfg.ErrorProbability, + TotalFreq: rsCfg.TotalFreq, + MaxErr: rsCfg.MaxErrValue, + }, + cms.BucketsCfg{ + ObservationInterval: rsCfg.ObservationInterval, + Buckets: rsCfg.Buckets, + 
EstimationSoftLimit: rsCfg.RareSpanFrequency, + }, + rsCfg.RareSpanFrequency, + rsCfg.SpsSampledLimit, + rsCfg.SpsProcessedLimit, + &sampling.MonotonicClock{}, + settings, + ) default: return nil, fmt.Errorf("unknown sampling policy type %s", cfg.Type)
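
For reviewers unfamiliar with count-min sketch, the standalone sketch below illustrates the estimation step the rare_spans policy builds on: each "service:operation" key is hashed into several counter rows, the minimum counter is taken as the frequency estimate, and a span counts as rare while that estimate stays at or below rare_span_frequency. This is an illustration only, not the processor's internal cms package: the names miniCMS, newMiniCMS, and insertWithCount are hypothetical, and the real evaluator additionally rotates per-bucket sketches over observation_interval and enforces the sampled/processed per-second limits shown in the diff above.

// rare_spans_sketch_example.go (illustrative only, not part of the PR)
package main

import (
	"fmt"
	"hash/fnv"
	"math"
)

type miniCMS struct {
	width  uint32     // counters per row, derived from epsilon
	depth  uint32     // number of rows (hash functions), derived from delta
	counts [][]uint32 // depth x width counter matrix
}

// newMiniCMS sizes the sketch from the standard CMS bounds:
// width = ceil(e / epsilon), depth = ceil(ln(1 / delta)).
// Here epsilon roughly plays the role of max_error_value / total_frequency
// and delta the role of error_probability in the policy config.
func newMiniCMS(epsilon, delta float64) *miniCMS {
	w := uint32(math.Ceil(math.E / epsilon))
	d := uint32(math.Ceil(math.Log(1 / delta)))
	counts := make([][]uint32, d)
	for i := range counts {
		counts[i] = make([]uint32, w)
	}
	return &miniCMS{width: w, depth: d, counts: counts}
}

// insertWithCount increments the counters for key in every row and returns
// the minimum counter value, i.e. the (possibly over-)estimated frequency.
func (c *miniCMS) insertWithCount(key []byte) uint32 {
	est := uint32(math.MaxUint32)
	for row := uint32(0); row < c.depth; row++ {
		h := fnv.New32a()
		h.Write([]byte{byte(row)}) // cheap per-row salt for this example
		h.Write(key)
		idx := h.Sum32() % c.width
		c.counts[row][idx]++
		if c.counts[row][idx] < est {
			est = c.counts[row][idx]
		}
	}
	return est
}

func main() {
	// Roughly mirrors total_frequency=1000, max_error_value=1,
	// error_probability=0.01 from the README example above.
	sketch := newMiniCMS(1.0/1000.0, 0.01)
	rareSpanFrequency := uint32(2)

	key := []byte("checkout-service" + ":" + "HTTP GET /rare-endpoint")
	for i := 0; i < 4; i++ {
		est := sketch.insertWithCount(key)
		fmt.Printf("observation %d: estimated frequency %d, sampled=%v\n",
			i+1, est, est <= rareSpanFrequency)
	}
	// The first two observations stay within rare_span_frequency and would be
	// sampled; later ones exceed it and would not.
}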