diff --git a/caboose.go b/caboose.go index 869fc11..2a49981 100644 --- a/caboose.go +++ b/caboose.go @@ -2,13 +2,13 @@ package caboose import ( "context" + "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" - "strings" "time" "github.com/filecoin-saturn/caboose/tieredhashing" @@ -32,7 +32,7 @@ type Config struct { // OrchestratorClient is the HTTP client to use when communicating with the Saturn orchestrator. OrchestratorClient *http.Client // OrchestratorOverride replaces calls to the orchestrator with a fixed response. - OrchestratorOverride []string + OrchestratorOverride []tieredhashing.NodeInfo // LoggingEndpoint is the URL of the logging endpoint where we submit logs pertaining to our Saturn retrieval requests. LoggingEndpoint url.URL @@ -75,6 +75,8 @@ type Config struct { SaturnNodeCoolOff time.Duration TieredHashingOpts []tieredhashing.Option + + ComplianceCidPeriod int64 } const DefaultLoggingInterval = 5 * time.Second @@ -90,7 +92,7 @@ const defaultMaxRetries = 3 const defaultMirrorFraction = 0.01 const maxBlockSize = 4194305 // 4 Mib + 1 byte -const DefaultOrchestratorEndpoint = "https://orchestrator.strn.pl/nodes/nearby?count=200" +const DefaultOrchestratorEndpoint = "https://orchestrator.strn.pl/nodes?maxNodes=200" const DefaultPoolRefreshInterval = 5 * time.Minute // we cool off sending requests to Saturn for a cid for a certain duration @@ -104,6 +106,10 @@ const defaultFetchKeyCoolDownDuration = 1 * time.Minute // how long will a sane // however, only upto a certain max number of cool-offs. const defaultSaturnNodeCoolOff = 5 * time.Minute +// This represents, on average, how many requests caboose makes before requesting a compliance cid. +// Example: a period of 100 implies Caboose will on average make a compliance CID request once every 100 requests. 
+const DefaultComplianceCidPeriod = int64(100) + var ErrNotImplemented error = errors.New("not implemented") var ErrNoBackend error = errors.New("no available saturn backend") var ErrContentProviderNotFound error = errors.New("saturn failed to find content providers") @@ -196,7 +202,13 @@ func NewCaboose(config *Config) (*Caboose, error) { config.MirrorFraction = defaultMirrorFraction } if override := os.Getenv(BackendOverrideKey); len(override) > 0 { - config.OrchestratorOverride = strings.Split(override, ",") + var overrideNodes []tieredhashing.NodeInfo + err := json.Unmarshal([]byte(override), &overrideNodes) + if err != nil { + goLogger.Warnw("error parsing BackendOverrideKey", "err", err) + return nil, err + } + config.OrchestratorOverride = overrideNodes } c := Caboose{ @@ -219,6 +231,10 @@ func NewCaboose(config *Config) (*Caboose, error) { } } + if c.config.ComplianceCidPeriod == 0 { + c.config.ComplianceCidPeriod = DefaultComplianceCidPeriod + } + if c.config.PoolRefresh == 0 { c.config.PoolRefresh = DefaultPoolRefreshInterval } diff --git a/caboose_test.go b/caboose_test.go index 3fe47b7..3d89285 100644 --- a/caboose_test.go +++ b/caboose_test.go @@ -6,6 +6,15 @@ import ( "crypto/tls" "encoding/json" "fmt" + "io" + "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + "github.com/filecoin-saturn/caboose" "github.com/filecoin-saturn/caboose/tieredhashing" "github.com/ipfs/go-cid" @@ -17,13 +26,6 @@ import ( selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/multiformats/go-multicodec" "github.com/stretchr/testify/require" - "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" ) func TestCidCoolDown(t *testing.T) { @@ -212,11 +214,19 @@ func BuildCabooseHarness(t *testing.T, n int, maxRetries int, opts ...HarnessOpt ch := &CabooseHarness{} ch.pool = make([]*ep, n) - purls := make([]string, n) + purls := make([]tieredhashing.NodeInfo, n) for i := 0; i < len(ch.pool); i++ { ch.pool[i] = &ep{} ch.pool[i].Setup() - purls[i] = strings.TrimPrefix(ch.pool[i].server.URL, "https://") + ip := strings.TrimPrefix(ch.pool[i].server.URL, "https://") + cid, _ := cid.V1Builder{Codec: uint64(multicodec.Raw), MhType: uint64(multicodec.Sha2_256)}.Sum([]byte(ip)) + purls[i] = tieredhashing.NodeInfo{ + IP: ip, + ID: "node-id", + Weight: rand.Intn(100), + Distance: rand.Float32(), + ComplianceCid: cid.String(), + } } ch.goodOrch = true orch := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/failure_test.go b/failure_test.go index 5543f3e..dca7255 100644 --- a/failure_test.go +++ b/failure_test.go @@ -3,15 +3,16 @@ package caboose_test import ( "context" "errors" - "github.com/filecoin-saturn/caboose" - "github.com/ipfs/go-cid" - "github.com/multiformats/go-multicodec" - "github.com/stretchr/testify/require" "net/http" "net/http/httptest" "sync" "testing" "time" + + "github.com/filecoin-saturn/caboose" + "github.com/ipfs/go-cid" + "github.com/multiformats/go-multicodec" + "github.com/stretchr/testify/require" ) var expRetryAfter = 1 * time.Second diff --git a/fetcher.go b/fetcher.go index 28ce84d..b9173ef 100644 --- a/fetcher.go +++ b/fetcher.go @@ -32,6 +32,7 @@ const ( saturnRetryAfterKey = "Retry-After" resourceTypeCar = "car" resourceTypeBlock = "block" + complianceCidPeriod = 200 ) var ( diff --git a/metrics.go b/metrics.go index 07404a6..b13331b 100644 --- a/metrics.go +++ b/metrics.go @@ -238,6 +238,10 @@ var ( mirroredTrafficTotalMetric =
prometheus.NewCounterVec(prometheus.CounterOpts{ Name: prometheus.BuildFQName("ipfs", "caboose", "mirrored_traffic_total"), }, []string{"error_status"}) + + complianceCidCallsTotalMetric = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: prometheus.BuildFQName("ipfs", "caboose", "compliance_cids_total"), + }, []string{"error_status"}) ) var CabooseMetrics = prometheus.NewRegistry() @@ -292,6 +296,7 @@ func init() { CabooseMetrics.MustRegister(saturnCallsTotalMetric) CabooseMetrics.MustRegister(saturnCallsFailureTotalMetric) CabooseMetrics.MustRegister(saturnConnectionFailureTotalMetric) + CabooseMetrics.MustRegister(complianceCidCallsTotalMetric) CabooseMetrics.MustRegister(saturnCallsSuccessTotalMetric) diff --git a/pool.go b/pool.go index b20e156..4b8956d 100644 --- a/pool.go +++ b/pool.go @@ -2,11 +2,14 @@ package caboose import ( "context" + cryptoRand "crypto/rand" "encoding/json" "errors" "fmt" "io" + "math/big" "math/rand" + "net/http" "net/url" "os" "sync" @@ -30,25 +33,38 @@ const ( BackendOverrideKey = "CABOOSE_BACKEND_OVERRIDE" ) +var complianceCidReqTemplate = "/ipfs/%s?format=raw" + // loadPool refreshes the set of Saturn endpoints in the pool by fetching an updated list of responsive Saturn nodes from the // Saturn Orchestrator. -func (p *pool) loadPool() ([]string, error) { +func (p *pool) loadPool() ([]tieredhashing.NodeInfo, error) { + if p.config.OrchestratorOverride != nil { return p.config.OrchestratorOverride, nil } + client := p.config.OrchestratorClient + + req, err := http.NewRequest("GET", p.config.OrchestratorEndpoint.String(), nil) - resp, err := p.config.OrchestratorClient.Get(p.config.OrchestratorEndpoint.String()) if err != nil { - goLogger.Warnw("failed to get backends from orchestrator", "err", err, "endpoint", p.config.OrchestratorEndpoint.String()) + goLogger.Warnw("failed to create request to orchestrator", "err", err, "endpoint", p.config.OrchestratorEndpoint) + return nil, err + } + resp, err := client.Do(req) + + if err != nil { + goLogger.Warnw("failed to get backends from orchestrator", "err", err, "endpoint", p.config.OrchestratorEndpoint) return nil, err } defer resp.Body.Close() - responses := make([]string, 0) + responses := make([]tieredhashing.NodeInfo, 0) + if err := json.NewDecoder(resp.Body).Decode(&responses); err != nil { goLogger.Warnw("failed to decode backends from orchestrator", "err", err, "endpoint", p.config.OrchestratorEndpoint.String()) return nil, err } + goLogger.Infow("got backends from orchestrators", "cnt", len(responses), "endpoint", p.config.OrchestratorEndpoint.String()) return responses, nil } @@ -116,7 +132,7 @@ func (p *pool) doRefresh() { } } -func (p *pool) refreshWithNodes(newEP []string) { +func (p *pool) refreshWithNodes(newEP []tieredhashing.NodeInfo) { p.lk.Lock() defer p.lk.Unlock() @@ -188,6 +204,20 @@ func (p *pool) refreshPool() { } } +func (p *pool) fetchComplianceCid(node string) error { + sc, err := p.th.GetComplianceCid(node) + if err != nil { + goLogger.Warnw("failed to find compliance cid ", "err", err) + return err + } + trialTimeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) + reqUrl := fmt.Sprintf(complianceCidReqTemplate, sc) + goLogger.Debugw("fetching compliance cid", "cid", reqUrl, "from", node) + err = p.fetchResourceAndUpdate(trialTimeout, node, reqUrl, 0, p.mirrorValidator) + cancel() + return err +} + func (p *pool) checkPool() { for { select { @@ -204,7 +234,25 @@ func (p *pool) checkPool() { continue } trialTimeout, cancel := 
context.WithTimeout(context.Background(), 30*time.Second) - err := p.fetchResourceAndUpdate(trialTimeout, testNodes[0], msg.path, 0, p.mirrorValidator) + + node := testNodes[0] + err := p.fetchResourceAndUpdate(trialTimeout, node, msg.path, 0, p.mirrorValidator) + + rand := big.NewInt(1) + if p.config.ComplianceCidPeriod > 0 { + rand, _ = cryptoRand.Int(cryptoRand.Reader, big.NewInt(p.config.ComplianceCidPeriod)) + } + + if rand.Cmp(big.NewInt(0)) == 0 { + err := p.fetchComplianceCid(node) + if err != nil { + goLogger.Warnw("failed to fetch compliance cid ", "err", err) + complianceCidCallsTotalMetric.WithLabelValues("error").Add(1) + } else { + complianceCidCallsTotalMetric.WithLabelValues("success").Add(1) + } + } + cancel() if err != nil { mirroredTrafficTotalMetric.WithLabelValues("error").Inc() @@ -277,6 +325,7 @@ func cidToKey(c cid.Cid) string { } func (p *pool) fetchBlockWith(ctx context.Context, c cid.Cid, with string) (blk blocks.Block, err error) { + fetchCalledTotalMetric.WithLabelValues(resourceTypeBlock).Add(1) if recordIfContextErr(resourceTypeBlock, ctx, "fetchBlockWith") { return nil, ctx.Err() diff --git a/pool_refresh_test.go b/pool_refresh_test.go index 1012585..c3c014c 100644 --- a/pool_refresh_test.go +++ b/pool_refresh_test.go @@ -1,9 +1,11 @@ package caboose import ( + "math/rand" + "testing" + "github.com/filecoin-saturn/caboose/tieredhashing" "github.com/stretchr/testify/require" - "testing" ) func TestPoolRefresh(t *testing.T) { @@ -59,7 +61,20 @@ func TestPoolRefreshWithLatencyDistribution(t *testing.T) { } func andAndAssertPool(t *testing.T, p *pool, nodes []string, expectedMain, expectedUnknown, expectedTotal, expectedNew int) { - p.refreshWithNodes(nodes) + + parsedNodes := make([]tieredhashing.NodeInfo, 0) + + for _, n := range nodes { + parsedNodes = append(parsedNodes, tieredhashing.NodeInfo{ + IP: n, + ID: n, + Weight: rand.Intn(100), + Distance: rand.Float32(), + ComplianceCid: n, + }) + } + + p.refreshWithNodes(parsedNodes) nds := p.th.GetPerf() require.Equal(t, expectedTotal, len(nds)) mts := p.th.GetPoolMetrics() diff --git a/pool_test.go b/pool_test.go index fa10a90..4cf0cc0 100644 --- a/pool_test.go +++ b/pool_test.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" + "math/rand" "net/http" "net/http/httptest" "net/url" @@ -22,6 +24,7 @@ import ( "github.com/ipld/go-ipld-prime/storage/memstore" selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/multiformats/go-multicodec" + "github.com/stretchr/testify/assert" ) type ep struct { @@ -66,6 +69,130 @@ func TestPoolMiroring(t *testing.T) { tieredhashing.WithCorrectnessWindowSize(2), tieredhashing.WithLatencyWindowSize(2), tieredhashing.WithMaxMainTierSize(1), } + ph := BuildPoolHarness(t, 2, opts) + + p := ph.p + p.config.ComplianceCidPeriod = 0 + nodes := ph.p.config.OrchestratorOverride + p.doRefresh() + p.config.OrchestratorOverride = nil + p.Start() + + // promote one node to main pool. other will remain in unknown pool.
+ eURL := nodes[0].IP + p.th.RecordSuccess(eURL, tieredhashing.ResponseMetrics{Success: true, TTFBMs: 30, SpeedPerMs: 30}) + p.th.RecordSuccess(eURL, tieredhashing.ResponseMetrics{Success: true, TTFBMs: 30, SpeedPerMs: 30}) + p.th.UpdateMainTierWithTopN() + + ls := cidlink.DefaultLinkSystem() + lsm := memstore.Store{} + ls.SetReadStorage(&lsm) + ls.SetWriteStorage(&lsm) + finalCL := ls.MustStore(ipld.LinkContext{}, cidlink.LinkPrototype{Prefix: cid.NewPrefixV1(uint64(multicodec.Raw), uint64(multicodec.Sha2_256))}, basicnode.NewBytes(testBlock)) + finalC := finalCL.(cidlink.Link).Cid + + _, err := p.fetchBlockWith(context.Background(), finalC, "") + if err != nil { + t.Fatal(err) + } + + time.Sleep(100 * time.Millisecond) + p.Close() + + for _, e := range ph.eps { + e.lk.Lock() + defer e.lk.Unlock() + if e.cnt != 1 { + t.Fatalf("expected 1 primary fetch, got %d", e.cnt) + } + } +} + +func TestLoadPool(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cid, _ := cid.V1Builder{Codec: uint64(multicodec.Raw), MhType: uint64(multicodec.Sha2_256)}.Sum([]byte("node")) + response := [1]tieredhashing.NodeInfo{{ + IP: "node", + ID: "node", + Weight: rand.Intn(100), + Distance: rand.Float32(), + ComplianceCid: cid.String(), + }} + + w.Header().Set("Content-Type", "application/json") + + // Encoding the response to JSON + json.NewEncoder(w).Encode(response) + })) + + endpoint, _ := url.Parse(server.URL) + p := &pool{ + config: &Config{ + OrchestratorEndpoint: endpoint, + OrchestratorClient: http.DefaultClient, + }, + } + + _, err := p.loadPool() + + assert.NoError(t, err) +} + +func TestFetchComplianceCid(t *testing.T) { + if unsafe.Sizeof(unsafe.Pointer(nil)) <= 4 { + t.Skip("skipping for 32bit architectures because too slow") + } + opts := []tieredhashing.Option{ + tieredhashing.WithCorrectnessWindowSize(2), + tieredhashing.WithLatencyWindowSize(2), + tieredhashing.WithMaxMainTierSize(1), + } + ph := BuildPoolHarness(t, 2, opts) + + p := ph.p + p.config.ComplianceCidPeriod = 1 + nodes := ph.p.config.OrchestratorOverride + p.doRefresh() + p.config.OrchestratorOverride = nil + p.Start() + + // promote one node to main pool. other will remain in unknown pool.
+ eURL := nodes[0].IP + p.th.RecordSuccess(eURL, tieredhashing.ResponseMetrics{Success: true, TTFBMs: 30, SpeedPerMs: 30}) + p.th.RecordSuccess(eURL, tieredhashing.ResponseMetrics{Success: true, TTFBMs: 30, SpeedPerMs: 30}) + p.th.UpdateMainTierWithTopN() + + ls := cidlink.DefaultLinkSystem() + lsm := memstore.Store{} + ls.SetReadStorage(&lsm) + ls.SetWriteStorage(&lsm) + finalCL := ls.MustStore(ipld.LinkContext{}, cidlink.LinkPrototype{Prefix: cid.NewPrefixV1(uint64(multicodec.Raw), uint64(multicodec.Sha2_256))}, basicnode.NewBytes(testBlock)) + finalC := finalCL.(cidlink.Link).Cid + + _, err := p.fetchBlockWith(context.Background(), finalC, "") + if err != nil { + t.Fatal(err) + } + + time.Sleep(100 * time.Millisecond) + p.Close() + + e := ph.eps[1] + e.lk.Lock() + defer e.lk.Unlock() + + if e.cnt != 2 { + t.Fatalf("expected 2 primary fetch, got %d", e.cnt) + } +} + +type PoolHarness struct { + p *pool + eps []*ep +} + +func BuildPoolHarness(t *testing.T, n int, opts []tieredhashing.Option) *PoolHarness { saturnClient := &http.Client{ Transport: &http.Transport{ @@ -75,12 +202,11 @@ func TestPoolMiroring(t *testing.T) { }, } - data := []byte("hello world") ls := cidlink.DefaultLinkSystem() lsm := memstore.Store{} ls.SetReadStorage(&lsm) ls.SetWriteStorage(&lsm) - finalCL := ls.MustStore(ipld.LinkContext{}, cidlink.LinkPrototype{Prefix: cid.NewPrefixV1(uint64(multicodec.Raw), uint64(multicodec.Sha2_256))}, basicnode.NewBytes(data)) + finalCL := ls.MustStore(ipld.LinkContext{}, cidlink.LinkPrototype{Prefix: cid.NewPrefixV1(uint64(multicodec.Raw), uint64(multicodec.Sha2_256))}, basicnode.NewBytes(testBlock)) finalC := finalCL.(cidlink.Link).Cid cw, err := car.NewSelectiveWriter(context.TODO(), &ls, finalC, selectorparse.CommonSelector_MatchAllRecursively) if err != nil { @@ -89,28 +215,33 @@ func TestPoolMiroring(t *testing.T) { carBytes := bytes.NewBuffer(nil) cw.WriteTo(carBytes) - e := ep{} - e.Setup() - e.lk.Lock() - e.resp = carBytes.Bytes() - eURL := strings.TrimPrefix(e.server.URL, "https://") - e.lk.Unlock() + nodeInfos := make([]tieredhashing.NodeInfo, n) + eps := make([]*ep, n) - e2 := ep{} - e2.Setup() - e2.lk.Lock() - e2.resp = carBytes.Bytes() - e2URL := strings.TrimPrefix(e2.server.URL, "https://") - e2.lk.Unlock() + for i := 0; i < n; i++ { + eps[i] = &ep{} + eps[i].Setup() + eps[i].lk.Lock() + eps[i].resp = carBytes.Bytes() + eURL := strings.TrimPrefix(eps[i].server.URL, "https://") + nodeInfos[i] = tieredhashing.NodeInfo{ + IP: eURL, + ID: eURL, + Weight: rand.Intn(100), + Distance: rand.Float32(), + ComplianceCid: finalC.String(), + } + eps[i].lk.Unlock() + + } conf := Config{ OrchestratorEndpoint: &url.URL{}, OrchestratorClient: http.DefaultClient, - OrchestratorOverride: []string{eURL, e2URL}, + OrchestratorOverride: nodeInfos, LoggingEndpoint: url.URL{}, LoggingClient: http.DefaultClient, LoggingInterval: time.Hour, - SaturnClient: saturnClient, DoValidation: false, PoolRefresh: time.Minute, @@ -119,32 +250,9 @@ func TestPoolMiroring(t *testing.T) { MirrorFraction: 1.0, } - p := newPool(&conf) - p.doRefresh() - p.config.OrchestratorOverride = nil - p.Start() - - // promote one node to main pool. other will remain in uknown pool. 
- p.th.RecordSuccess(eURL, tieredhashing.ResponseMetrics{Success: true, TTFBMs: 30, SpeedPerMs: 30}) - p.th.RecordSuccess(eURL, tieredhashing.ResponseMetrics{Success: true, TTFBMs: 30, SpeedPerMs: 30}) - p.th.UpdateMainTierWithTopN() - - _, err = p.fetchBlockWith(context.Background(), finalC, "") - if err != nil { - t.Fatal(err) - } - - time.Sleep(100 * time.Millisecond) - p.Close() - - e.lk.Lock() - defer e.lk.Unlock() - if e.cnt != 1 { - t.Fatalf("expected 1 primary fetch, got %d", e.cnt) - } - e2.lk.Lock() - defer e2.lk.Unlock() - if e2.cnt != 1 { - t.Fatalf("expected 1 mirrored fetch, got %d", e2.cnt) + ph := &PoolHarness{ + p: newPool(&conf), + eps: eps, } + return ph } diff --git a/tieredhashing/tiered_hashing.go b/tieredhashing/tiered_hashing.go index d3b7f40..d254ca8 100644 --- a/tieredhashing/tiered_hashing.go +++ b/tieredhashing/tiered_hashing.go @@ -1,6 +1,7 @@ package tieredhashing import ( + "fmt" "math" "net/http" "time" @@ -40,6 +41,14 @@ const ( type Tier string +type NodeInfo struct { + ID string `json:"id"` + IP string `json:"ip"` + Distance float32 `json:"distance"` + Weight int `json:"weight"` + ComplianceCid string `json:"complianceCid"` +} + type NodePerf struct { LatencyDigest *rolling.PointPolicy NLatencyDigest float64 @@ -55,6 +64,9 @@ type NodePerf struct { connFailures int networkErrors int responseCodes int + + // Node Info + NodeInfo } // locking is left to the caller @@ -241,7 +253,21 @@ func (t *TieredHashing) GetPerf() map[string]*NodePerf { return t.nodes } -func (t *TieredHashing) AddOrchestratorNodes(nodes []string) (added, alreadyRemoved, removedAndAddedBack int) { +func (t *TieredHashing) GetComplianceCid(ip string) (string, error) { + if node, ok := t.nodes[ip]; ok { + if len(node.ComplianceCid) > 0 { + return node.ComplianceCid, nil + } else { + return "", fmt.Errorf("compliance cid doesn't exist for node: %s ", ip) + } + + } else { + return "", fmt.Errorf("node with IP: %s is not in Caboose pool ", ip) + } +} + +func (t *TieredHashing) AddOrchestratorNodes(nodes []NodeInfo) (added, alreadyRemoved, removedAndAddedBack int) { + for _, node := range nodes { // TODO Add nodes that are closer than the ones we have even if the pool is full if len(t.nodes) >= t.cfg.MaxPoolSize { @@ -249,23 +275,25 @@ func (t *TieredHashing) AddOrchestratorNodes(nodes []string) (added, alreadyRemo } // do we already have this node ? - if _, ok := t.nodes[node]; ok { + if _, ok := t.nodes[node.IP]; ok { continue } // have we kicked this node out for bad correctness or latency ? - if _, ok := t.removedNodesTimeCache.Get(node); ok { + if _, ok := t.removedNodesTimeCache.Get(node.IP); ok { alreadyRemoved++ continue } added++ - t.nodes[node] = &NodePerf{ + t.nodes[node.IP] = &NodePerf{ LatencyDigest: rolling.NewPointPolicy(rolling.NewWindow(int(t.cfg.LatencyWindowSize))), CorrectnessDigest: rolling.NewPointPolicy(rolling.NewWindow(int(t.cfg.CorrectnessWindowSize))), Tier: TierUnknown, + + NodeInfo: node, } - t.unknownSet = t.unknownSet.AddNode(node) + t.unknownSet = t.unknownSet.AddNode(node.IP) } // Avoid Pool starvation -> if we still don't have enough nodes, add the ones we have already removed @@ -276,23 +304,23 @@ func (t *TieredHashing) AddOrchestratorNodes(nodes []string) (added, alreadyRemo } // do we already have this node ? 
- if _, ok := t.nodes[node]; ok { + if _, ok := t.nodes[node.IP]; ok { continue } - if _, ok := t.removedNodesTimeCache.Get(node); !ok { + if _, ok := t.removedNodesTimeCache.Get(node.IP); !ok { continue } added++ removedAndAddedBack++ - t.nodes[node] = &NodePerf{ + t.nodes[node.IP] = &NodePerf{ LatencyDigest: rolling.NewPointPolicy(rolling.NewWindow(int(t.cfg.LatencyWindowSize))), CorrectnessDigest: rolling.NewPointPolicy(rolling.NewWindow(int(t.cfg.CorrectnessWindowSize))), Tier: TierUnknown, } - t.unknownSet = t.unknownSet.AddNode(node) - t.removedNodesTimeCache.Delete(node) + t.unknownSet = t.unknownSet.AddNode(node.IP) + t.removedNodesTimeCache.Delete(node.IP) } return diff --git a/tieredhashing/tiered_hashing_test.go b/tieredhashing/tiered_hashing_test.go index 0d21bc3..2584263 100644 --- a/tieredhashing/tiered_hashing_test.go +++ b/tieredhashing/tiered_hashing_test.go @@ -2,11 +2,15 @@ package tieredhashing import ( "fmt" + "math/rand" "net/http" "sort" "testing" "github.com/asecurityteam/rolling" + "github.com/ipfs/go-cid" + "github.com/multiformats/go-multicodec" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -115,6 +119,43 @@ func TestMoveBestUnknownToMain(t *testing.T) { th.h.nodes[nodes[0]].Tier = TierUnknown } +func TestComplianceCids(t *testing.T) { + + th := NewTieredHashingHarness() + + nodes := th.genAndAddAll(t, 10) + th.h.AddOrchestratorNodes(genNodeStructs(nodes)) + + t.Run("compliance cids exist for existing nodes", func(t *testing.T) { + for _, node := range nodes { + _, err := th.h.GetComplianceCid(node) + assert.NoError(t, err, "Compliance Cids should always exist for nodes that are part of the pool") + } + }) + + newNodes := []string{"new-node1", "new-node2"} + th.addNewNodesAll(t, newNodes) + th.h.AddOrchestratorNodes(genNodeStructs(newNodes)) + t.Run("compliance cids exist for new nodes", func(t *testing.T) { + for _, node := range newNodes { + _, err := th.h.GetComplianceCid(node) + assert.NoError(t, err, "Compliance Cids should always exist for new added nodes") + + } + }) + + for _, node := range newNodes { + th.h.removeFailedNode(node) + } + + t.Run("compliance cids do not exist for removed nodes", func(t *testing.T) { + for _, node := range newNodes { + _, err := th.h.GetComplianceCid(node) + assert.Error(t, err, "Compliance cids do not exist for removed nodes") + } + }) +} + func TestNodeNotRemovedWithVar(t *testing.T) { window := 2 th := NewTieredHashingHarness(WithCorrectnessWindowSize(window), WithFailureDebounce(0), WithNoRemove(true)) @@ -467,13 +508,13 @@ func TestAddOrchestratorNodesMax(t *testing.T) { // empty -> 10 get added nodes := th.genNodes(t, 30) - a, _, _ := th.h.AddOrchestratorNodes(nodes) + a, _, _ := th.h.AddOrchestratorNodes(genNodeStructs(nodes)) require.EqualValues(t, 10, a) th.assertSize(t, 0, 10) // nothing gets added as we are full nodes2 := th.genNodes(t, 30) - a, _, _ = th.h.AddOrchestratorNodes(append(nodes, nodes2...)) + a, _, _ = th.h.AddOrchestratorNodes(append(genNodeStructs(nodes), genNodeStructs(nodes2)...)) require.EqualValues(t, 0, a) th.assertSize(t, 0, 10) @@ -484,7 +525,7 @@ func TestAddOrchestratorNodesMax(t *testing.T) { th.assertSize(t, 0, 8) // 2 get added now - a, _, _ = th.h.AddOrchestratorNodes(append(nodes, nodes2...)) + a, _, _ = th.h.AddOrchestratorNodes(append(genNodeStructs(nodes), genNodeStructs(nodes2)...)) require.EqualValues(t, 2, a) th.assertSize(t, 0, 10) @@ -492,7 +533,7 @@ func TestAddOrchestratorNodesMax(t *testing.T) { th.assertSize(t, 0, 9) // removed node does 
not get added back as we are already full without it - a, ar, back := th.h.AddOrchestratorNodes(append(nodes, "newnode")) + a, ar, back := th.h.AddOrchestratorNodes(append(genNodeStructs(nodes), genNodeStructs([]string{"newNode"})...)) require.EqualValues(t, 1, a) require.EqualValues(t, 3, ar) th.assertSize(t, 0, 10) @@ -526,6 +567,22 @@ func (th *TieredHashingHarness) genNodes(t *testing.T, n int) []string { return nodes } +func genNodeStructs(nodes []string) []NodeInfo { + var nodeStructs []NodeInfo + + for _, node := range nodes { + cid, _ := cid.V1Builder{Codec: uint64(multicodec.Raw), MhType: uint64(multicodec.Sha2_256)}.Sum([]byte(node)) + nodeStructs = append(nodeStructs, NodeInfo{ + IP: node, + ID: node, + Weight: rand.Intn(100), + Distance: rand.Float32(), + ComplianceCid: cid.String(), + }) + } + return nodeStructs +} + func (th *TieredHashingHarness) addNewNodesAll(t *testing.T, nodes []string) { var old []string @@ -533,13 +590,13 @@ func (th *TieredHashingHarness) addNewNodesAll(t *testing.T, nodes []string) { old = append(old, key) } - added, already, _ := th.h.AddOrchestratorNodes(append(nodes, old...)) + added, already, _ := th.h.AddOrchestratorNodes(append(genNodeStructs(nodes), genNodeStructs(old)...)) require.Zero(t, already) require.EqualValues(t, len(nodes), added) } func (th *TieredHashingHarness) addAndAssert(t *testing.T, nodes []string, added, already, ab int, main, unknown int) { - a, ar, addedBack := th.h.AddOrchestratorNodes(nodes) + a, ar, addedBack := th.h.AddOrchestratorNodes(genNodeStructs(nodes)) require.EqualValues(t, added, a) require.EqualValues(t, already, ar)
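For reference, this change alters the format of the CABOOSE_BACKEND_OVERRIDE environment variable: NewCaboose now expects a JSON-encoded array of tieredhashing.NodeInfo values instead of a comma-separated list of hosts. A minimal sketch of producing such a value, using hypothetical node values, might look like this:

package main

import (
	"encoding/json"
	"fmt"
	"os"

	"github.com/filecoin-saturn/caboose/tieredhashing"
)

func main() {
	// Hypothetical override entries; field names follow the json tags on tieredhashing.NodeInfo.
	nodes := []tieredhashing.NodeInfo{
		{
			ID:            "example-node-id",   // hypothetical node identifier
			IP:            "node1.example.com", // hypothetical node address
			Distance:      0.25,
			Weight:        100,
			ComplianceCid: "bafk...", // placeholder compliance CID
		},
	}

	encoded, err := json.Marshal(nodes)
	if err != nil {
		panic(err)
	}

	// NewCaboose reads BackendOverrideKey ("CABOOSE_BACKEND_OVERRIDE") and, when set,
	// uses the decoded nodes in place of an orchestrator response.
	os.Setenv("CABOOSE_BACKEND_OVERRIDE", string(encoded))
	fmt.Println(string(encoded))
}

The resulting value has the shape [{"id":"example-node-id","ip":"node1.example.com","distance":0.25,"weight":100,"complianceCid":"bafk..."}], which json.Unmarshal in NewCaboose decodes into config.OrchestratorOverride.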