From 26f533851510a9e43b336adb24eaf4ad7ae705ed Mon Sep 17 00:00:00 2001 From: Aviral Takkar Date: Mon, 15 Apr 2024 11:23:47 -0700 Subject: [PATCH 1/3] chore: refactoring --- cmd/proxy/main.go | 4 +- internal/context/context.go | 27 -- internal/context/context_test.go | 58 ---- internal/remote/reader.go | 290 ---------------- internal/remote/reader_test.go | 314 ------------------ pkg/discovery/content/consts.go | 17 + pkg/discovery/content/http.go | 14 + .../{ => content/provider}/provider.go | 32 +- .../{ => content/provider}/provider_test.go | 60 +++- .../discovery/content/reader}/interface.go | 2 +- .../mockreader.go => pkg/mocks/reader.go | 8 +- .../routing/tests/mock.go => mocks/router.go} | 5 +- 12 files changed, 128 insertions(+), 703 deletions(-) delete mode 100644 internal/remote/reader.go delete mode 100644 internal/remote/reader_test.go create mode 100644 pkg/discovery/content/consts.go create mode 100644 pkg/discovery/content/http.go rename pkg/discovery/{ => content/provider}/provider.go (84%) rename pkg/discovery/{ => content/provider}/provider_test.go (64%) rename {internal/remote => pkg/discovery/content/reader}/interface.go (97%) rename internal/remote/tests/mockreader.go => pkg/mocks/reader.go (82%) rename pkg/{discovery/routing/tests/mock.go => mocks/router.go} (93%) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index b0eb90c..9a230b9 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -19,7 +19,7 @@ import ( "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/internal/handlers" "github.com/azure/peerd/pkg/containerd" - "github.com/azure/peerd/pkg/discovery" + "github.com/azure/peerd/pkg/discovery/content/provider" "github.com/azure/peerd/pkg/discovery/routing" "github.com/azure/peerd/pkg/k8s" "github.com/azure/peerd/pkg/k8s/events" @@ -157,7 +157,7 @@ func serverCommand(ctx context.Context, args *ServerCmd) (err error) { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - discovery.Provide(ctx, r, containerdStore, filesStore.Subscribe()) + provider.Provide(ctx, r, containerdStore, filesStore.Subscribe()) return nil }) diff --git a/internal/context/context.go b/internal/context/context.go index 2dc8014..e77b87f 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -9,7 +9,6 @@ import ( "os" "strconv" "strings" - "sync" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -92,32 +91,6 @@ func SetOutboundHeaders(r *http.Request, c *gin.Context) { r.Header.Set(NodeHeaderKey, NodeName) } -// Merge merges multiple input channels into a single output channel. -// It starts a goroutine for each input channel and sends the values from each input channel to the output channel. -// Once all input channels are closed, it closes the output channel. -// The function returns the output channel. -func Merge[T any](cs ...<-chan T) <-chan T { - var wg sync.WaitGroup - out := make(chan T) - - output := func(c <-chan T) { - for n := range c { - out <- n - } - wg.Done() - } - wg.Add(len(cs)) - for _, c := range cs { - go output(c) - } - - go func() { - wg.Wait() - close(out) - }() - return out -} - // RangeStartIndex returns the start index of a byte range specified in the given range header value. // It expects the range value to be in the format "bytes=startIndex-endIndex". func RangeStartIndex(rangeValue string) (int64, error) { diff --git a/internal/context/context_test.go b/internal/context/context_test.go index fad6c7e..435d5ea 100644 --- a/internal/context/context_test.go +++ b/internal/context/context_test.go @@ -3,11 +3,9 @@ package context import ( - "fmt" "net/http" "net/http/httptest" "os" - "strings" "testing" "github.com/gin-gonic/gin" @@ -70,62 +68,6 @@ func TestSetOutboundHeaders(t *testing.T) { } } -func TestMerge(t *testing.T) { - - ch1 := make(chan string, 10) - ch2 := make(chan string) - ch3 := make(chan string, 100) - ch4 := make(chan string, 1000) - - mergedChan := Merge(ch1, ch2, ch3, ch4) - - // Write to the channels. - go func() { - for i := 0; i < 100; i++ { - ch1 <- fmt.Sprintf("ch1-%d", i) - } - close(ch1) - }() - - go func() { - for i := 0; i < 100; i++ { - ch2 <- fmt.Sprintf("ch2-%d", i) - } - close(ch2) - }() - - go func() { - for i := 0; i < 100; i++ { - ch3 <- fmt.Sprintf("ch3-%d", i) - } - close(ch3) - }() - - go func() { - for i := 0; i < 100; i++ { - ch4 <- fmt.Sprintf("ch4-%d", i) - } - close(ch4) - }() - - // Read from the merged channel. - total := 0 - for val := range mergedChan { - if strings.HasPrefix(val, "ch1-") || - strings.HasPrefix(val, "ch2-") || - strings.HasPrefix(val, "ch3-") || - strings.HasPrefix(val, "ch4-") { - total++ - } else { - t.Errorf("unexpected value: %v", val) - } - } - - if total != 400 { - t.Errorf("expected: %v, got: %v", 400, total) - } -} - func TestBlobUrl(t *testing.T) { // Create a new request with a URL that has a query string. req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) diff --git a/internal/remote/reader.go b/internal/remote/reader.go deleted file mode 100644 index 80c3a09..0000000 --- a/internal/remote/reader.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -package remote - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "time" - - p2pcontext "github.com/azure/peerd/internal/context" - "github.com/azure/peerd/pkg/discovery/routing" - "github.com/azure/peerd/pkg/metrics" - "github.com/gin-gonic/gin" - "github.com/rs/zerolog" -) - -type operation int - -const ( - operationFstatRemote = operation(iota) - operationPreadRemote -) - -var errPeerNotFound = errors.New("peer not found") - -// reader is a Reader implementation. -type reader struct { - context *gin.Context - resolveTimeout time.Duration - - router routing.Router - resolveRetries int - defaultHttpClient *http.Client - - metricsRecorder metrics.Metrics -} - -var _ Reader = &reader{} - -// Log returns the logger with context for this reader. -func (r *reader) Log() *zerolog.Logger { - l := p2pcontext.Logger(r.context) - return &l -} - -// PreadRemote is like pread but to a remote file. -func (r *reader) PreadRemote(buf []byte, offset int64) (int, error) { - key := r.context.GetString(p2pcontext.FileChunkCtxKey) - start := offset - end := int64(len(buf)) + offset - 1 - - log := r.Log().With().Str("operation", "preadremote").Str("key", key).Int64("start", start).Int64("end", end).Logger() - - count, err := r.doP2p(log, key, start, end, operationPreadRemote, buf) - if err == nil { - return int(count), nil - } - - // Could not find a peer that has this file, request origin. - startTime := time.Now() - originReq, err := r.originRequest(start, end) - if err != nil { - return -1, err - } - - count32 := int(0) - defer func() { - r.metricsRecorder.RecordUpstreamResponse(originReq.URL.Hostname(), key, "pread", time.Since(startTime).Seconds(), int64(count32)) - }() - count32, err = r.preadRemote(log, originReq, r.defaultHttpClient, buf) - return count32, err -} - -// FstatRemote stats a remote file. -func (r *reader) FstatRemote() (int64, error) { - key := r.context.GetString(p2pcontext.FileChunkCtxKey) - start := int64(0) - end := int64(0) - - log := r.Log().With().Str("operation", "fstatremote").Int64("start", start).Int64("end", end).Str("key", key).Logger() - - startTime := time.Now() - originReq, err := r.originRequest(start, end) - if err != nil { - return -1, err - } - - var count int64 - defer func() { - r.metricsRecorder.RecordUpstreamResponse(originReq.URL.Hostname(), key, "fstat", time.Since(startTime).Seconds(), count) - }() - count, err = r.fstatRemote(log, originReq, r.defaultHttpClient) - return count, err -} - -// doP2p tries to resolve the key in the p2p network and if successful, it will perform the operation on the peer, and return the result. -func (r *reader) doP2p(log zerolog.Logger, fileChunkKey string, start, end int64, o operation, buf []byte) (int64, error) { - if p2pcontext.IsRequestFromAPeer(r.context) { - log.Warn().Msg("refusing to propagate request from one peer to another") - return -1, errPeerNotFound - } - - log.Debug().Msg(p2pcontext.PeerResolutionStartLog) - defer log.Debug().Msg(p2pcontext.PeerResolutionStopLog) - - resolveCtx, cancel := context.WithTimeout(log.WithContext(r.context), r.resolveTimeout) - defer cancel() - - startTime := time.Now() - peerCount := 0 - peersCh, negCacheCallback, err := r.router.ResolveWithNegativeCacheCallback(resolveCtx, fileChunkKey, false, r.resolveRetries) - if err != nil { - //nolint:errcheck // ignore - log.Error().Err(err).Msg(p2pcontext.PeerRequestErrorLog) - return -1, err - } - - // Request a peer for this file. -peerLoop: - for { - select { - - case <-resolveCtx.Done(): - // Resolving mirror has timed out. - negCacheCallback() - log.Info().Msg(p2pcontext.PeerNotFoundLog) - break peerLoop - - case peer, ok := <-peersCh: - // Channel closed means no more mirrors will be received and max retries has been reached. - if !ok { - negCacheCallback() - log.Info().Msg(p2pcontext.PeerResolutionExhaustedLog) - break peerLoop - } - - if peerCount == 0 { - // Only report the time it took to discover the first peer. - r.metricsRecorder.RecordPeerDiscovery(peer.HttpHost, time.Since(startTime).Seconds()) - peerCount++ - } - - peerReq, err := r.peerRequest(peer.HttpHost, start, end) - if err != nil { - log.Error().Err(err).Msg(p2pcontext.PeerRequestErrorLog) - // try next peer - break - } - - client := r.router.Net().HTTPClientFor(peer.ID) - - var count int64 - startTime = time.Now() - if o == operationFstatRemote { - count, err = r.fstatRemote(log, peerReq, client) - } else if o == operationPreadRemote { - var c int - c, err = r.preadRemote(log, peerReq, client, buf) - count = int64(c) - } else { - err = fmt.Errorf("unknown operation: %v", o) - } - - if err != nil { - // try next peer - log.Error().Err(err).Msg(p2pcontext.PeerRequestErrorLog) - } else { - op := "fstat" - if o == operationPreadRemote { - op = "pread" - } - r.metricsRecorder.RecordPeerResponse(peer.HttpHost, fileChunkKey, op, time.Since(startTime).Seconds(), count) - return count, nil - } - } - } - - return -1, errPeerNotFound -} - -// fstatRemote stats the file. -func (r *reader) fstatRemote(log zerolog.Logger, req *http.Request, client *http.Client) (int64, error) { - log.Debug().Str("url", req.URL.String()).Str("range", req.Header.Get("Range")).Msg("reader fstatRemote start") - defer log.Debug().Msg("reader fstatRemote stop") - - resp, err := client.Do(req) - if err != nil { - log.Error().Err(err).Msg("reader fstatRemote error") - return 0, Error{resp, err} - } - defer resp.Body.Close() - - if resp.StatusCode == 200 { - return resp.ContentLength, nil - } - - if resp.StatusCode == 206 { - l := resp.ContentLength - rs := resp.Header.Get("Content-Range") - if rs == "" { - return l, nil - } - - pos := strings.LastIndexByte(rs, '/') - if pos < 0 { - return l, nil - } - - l, _ = strconv.ParseInt(rs[pos+1:], 10, 64) - return l, nil - } - - log.Error().Err(err).Int("status", resp.StatusCode).Msg("reader fstatRemote error") - return 0, Error{resp, fmt.Errorf("unexpected response code: %d", resp.StatusCode)} -} - -// preadRemote reads the file. -func (r *reader) preadRemote(log zerolog.Logger, req *http.Request, client *http.Client, buf []byte) (int, error) { - log.Debug().Str("url", req.URL.String()).Str("range", req.Header.Get("Range")).Msg("reader preadRemote start") - statusCode := -1 - s := time.Now() - defer func() { - log.Debug().Int("status", statusCode).Dur("duration", time.Since(s)).Msg("reader preadRemote stop") - }() - - resp, err := client.Do(req) - if resp != nil { - statusCode = resp.StatusCode - } - if err != nil { - detailedErr := Error{resp, err} - log.Error().Err(detailedErr).Str("url", req.URL.String()).Str("range", req.Header.Get("Range")).Msg("reader preadRemote error") - return 0, detailedErr - } - defer resp.Body.Close() - - if resp.StatusCode != 200 && resp.StatusCode != 206 { - log.Error().Err(err).Int("status", resp.StatusCode).Msg("reader preadRemote error") - return 0, Error{resp, fmt.Errorf("unexpected response code: %d", resp.StatusCode)} - } - - return io.ReadFull(resp.Body, buf) -} - -// originRequest will create a new request to origin. -func (r *reader) originRequest(start, end int64) (*http.Request, error) { - return r.remoteRequest(r.context.GetString(p2pcontext.BlobUrlCtxKey), start, end) -} - -// perRequest will create a new request to a peer. -func (r *reader) peerRequest(peer string, start, end int64) (*http.Request, error) { - return r.remoteRequest(fmt.Sprintf("%v/blobs/%v", peer, r.context.GetString(p2pcontext.BlobUrlCtxKey)), start, end) -} - -// remoteRequest creates a new HTTP request to a remote server. -func (r *reader) remoteRequest(u string, start, end int64) (*http.Request, error) { - req, err := http.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - - for key, vals := range r.context.Request.Header { - vals2 := make([]string, len(vals)) - copy(vals2, vals) - req.Header[key] = vals2 - } - - req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - p2pcontext.SetOutboundHeaders(req, r.context) - - return req, nil -} - -// NewReader creates a new remote reader. -func NewReader(c *gin.Context, router routing.Router, resolveRetries int, resolveTimeout time.Duration, metricsRecorder metrics.Metrics) Reader { - cc := c.Copy() - return &reader{ - context: cc, - resolveTimeout: resolveTimeout, - router: router, - resolveRetries: resolveRetries, - defaultHttpClient: router.Net().HTTPClientFor(""), - metricsRecorder: metricsRecorder, - } -} diff --git a/internal/remote/reader_test.go b/internal/remote/reader_test.go deleted file mode 100644 index 09dc320..0000000 --- a/internal/remote/reader_test.go +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -package remote - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - p2pcontext "github.com/azure/peerd/internal/context" - "github.com/azure/peerd/pkg/discovery/routing/tests" - "github.com/azure/peerd/pkg/metrics" - "github.com/gin-gonic/gin" - "github.com/prometheus/client_golang/prometheus" - "github.com/rs/zerolog" -) - -var ( - hostAndPath = "https://avtakkartest.blob.core.windows.net/d18c7a64c5158179-ff8cb2f639ff44879c12c94361a746d0-782b855128//docker/registry/v2/blobs/sha256/d1/d18c7a64c5158179bdee531a663c5b487de57ff17cff3af29a51c7e70b491d9d/data" - query = "?se=2023-09-20T01%3A14%3A49Z&sig=m4Cr%2BYTZHZQlN5LznY7nrTQ4LCIx2OqnDDM3Dpedbhs%3D&sp=r&spr=https&sr=b&sv=2018-03-28®id=01031d61e1024861afee5d512651eb9f" - u = hostAndPath + query - mr = metrics.NewPromMetrics(prometheus.DefaultRegisterer, "test", "test") -) - -func TestPreadRemoteUpstream(t *testing.T) { - // Setup - m := map[string][]string{} - key := "somekey" - expected := "expected-result" - peersTried := 0 - svr3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - peersTried++ - w.WriteHeader(http.StatusUnauthorized) - })) - defer svr3.Close() - svr2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - peersTried++ - w.WriteHeader(http.StatusNotFound) - })) - defer svr2.Close() - svr1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - peersTried++ - w.WriteHeader(http.StatusBadGateway) - })) - defer svr1.Close() - val := []string{svr1.URL, svr2.URL, svr3.URL} - m[key] = val - - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if "?"+r.URL.RawQuery == query { - w.Header().Set("Content-Type", "application/octet-stream") - // nolint:errcheck - w.Write([]byte(expected)) - } else { - w.WriteHeader(http.StatusNotFound) - } - })) - defer svr.Close() - p := svr.URL + "/some-path" - u := "http://127.0.0.1:5000/blobs/" + p + query - req, err := http.NewRequest("GET", u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - c.Params = []gin.Param{ - {Key: "url", Value: p}, - } - c.Set(p2pcontext.BlobUrlCtxKey, p2pcontext.BlobUrl(c)) - c.Set(p2pcontext.BlobRangeCtxKey, "bytes=0-10") - c.Set(p2pcontext.FileChunkCtxKey, key) - - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - b := make([]byte, 10) - - // Test - got, err := r.PreadRemote(b, 0) - - // Assert - if err != nil { - t.Fatal(err) - } else if got != 10 { - t.Fatalf("expected %v, got %v", 10, got) - } else if string(b) != expected[:10] { - t.Fatalf("expected %v, got %v", expected[:10], string(b)) - } else if peersTried != 3 { - t.Fatalf("expected %v, got %v", 3, peersTried) - } -} - -func TestFstatRemote(t *testing.T) { - m := map[string][]string{} - - expected := "expected-result" - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if "?"+r.URL.RawQuery == query { - w.Header().Set("Content-Type", "application/octet-stream") - // nolint:errcheck - w.Write([]byte(expected)) - } else { - w.WriteHeader(http.StatusNotFound) - } - })) - defer svr.Close() - p := svr.URL + "/some-path" - u := "http://127.0.0.1:5000/blobs/" + p + query - req, err := http.NewRequest("GET", u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - c.Params = []gin.Param{ - {Key: "url", Value: p}, - } - c.Set(p2pcontext.BlobUrlCtxKey, p2pcontext.BlobUrl(c)) - c.Set(p2pcontext.BlobRangeCtxKey, "bytes=0-0") - - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - - got, err := r.FstatRemote() - if err != nil { - t.Fatal(err) - } else if got != int64(len(expected)) { - t.Fatalf("expected %v, got %v", len(expected), got) - } -} - -func TestFstatRemotePartialContent(t *testing.T) { - m := map[string][]string{} - - expected := "expected-result" - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if "?"+r.URL.RawQuery == query { - w.Header().Set("Content-Type", "application/octet-stream") - // nolint:errcheck - w.WriteHeader(http.StatusPartialContent) - w.Header().Set("Content-Range", "bytes 0-10/10") - // nolint:errcheck - w.Write([]byte(expected)) - } else { - w.WriteHeader(http.StatusNotFound) - } - })) - defer svr.Close() - p := svr.URL + "/some-path" - u := "http://127.0.0.1:5000/blobs/" + p + query - req, err := http.NewRequest("GET", u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - c.Params = []gin.Param{ - {Key: "url", Value: p}, - } - c.Set(p2pcontext.BlobUrlCtxKey, p2pcontext.BlobUrl(c)) - c.Set(p2pcontext.BlobRangeCtxKey, "bytes=0-0") - - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - - got, err := r.FstatRemote() - if err != nil { - t.Fatal(err) - } else if got != int64(len(expected)) { - t.Fatalf("expected %v, got %v", len(expected), got) - } -} - -func TestP2pRetries(t *testing.T) { - l := zerolog.Nop() - m := map[string][]string{} - key := "somekey" - expected := "expected-result" - svr3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/octet-stream") - // nolint:errcheck - w.Write([]byte(expected)) - })) - defer svr3.Close() - svr2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - })) - defer svr2.Close() - svr1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - })) - defer svr1.Close() - val := []string{svr1.URL, svr2.URL, svr3.URL} - m[key] = val - - req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - b := make([]byte, 10) - - got, err := r.doP2p(l, key, 0, 10, operationPreadRemote, b) - if err != nil { - t.Fatal(err) - } - - if got != 10 { - t.Fatalf("expected %v, got %v", 10, got) - } else if string(b) != expected[:10] { - t.Fatalf("expected %v, got %v", expected[:10], string(b)) - } -} - -func TestP2pSuccess(t *testing.T) { - l := zerolog.Nop() - m := map[string][]string{} - key := "somekey" - expected := "expected-result" - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/octet-stream") - // nolint:errcheck - w.Write([]byte(expected)) - })) - defer svr.Close() - val := []string{svr.URL} - m[key] = val - - req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - b := make([]byte, 10) - - got, err := r.doP2p(l, key, 0, 10, operationPreadRemote, b) - if err != nil { - t.Fatal(err) - } - - if got != 10 { - t.Fatalf("expected %v, got %v", 10, got) - } else if string(b) != expected[:10] { - t.Fatalf("expected %v, got %v", expected[:10], string(b)) - } -} - -func TestP2pPeerNotFound(t *testing.T) { - l := zerolog.Nop() - m := map[string][]string{} - - req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - - b := make([]byte, 10) - _, err = r.doP2p(l, "key", 0, 10, operationPreadRemote, b) - if err == nil { - t.Fatal("expected error") - } - - if err != errPeerNotFound { - t.Fatalf("expected %v, got %v", errPeerNotFound, err) - } -} - -func TestP2pNoInfiniteLoops(t *testing.T) { - l := zerolog.Nop() - m := map[string][]string{} - key := "some-key" - val := []string{"http://localhost"} - m[key] = val - - req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) - if err != nil { - t.Fatal(err) - } - - router := tests.NewMockRouter(m) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = req - c.Request.Header.Add(p2pcontext.P2PHeaderKey, "true") - - r := NewReader(c, router, 3, 500*time.Millisecond, mr).(*reader) - - b := make([]byte, 10) - _, err = r.doP2p(l, key, 0, 10, operationPreadRemote, b) - if err == nil { - t.Fatal("expected error") - } - - if err != errPeerNotFound { - t.Fatalf("expected %v, got %v", errPeerNotFound, err) - } -} diff --git a/pkg/discovery/content/consts.go b/pkg/discovery/content/consts.go new file mode 100644 index 0000000..2ab6b65 --- /dev/null +++ b/pkg/discovery/content/consts.go @@ -0,0 +1,17 @@ +package content + +// Log messages. +const ( + PeerResolutionStartLog = "peer resolution start" + PeerResolutionStopLog = "peer resolution stop" + PeerNotFoundLog = "peer not found" + PeerResolutionExhaustedLog = "peer resolution exhausted" + PeerRequestErrorLog = "peer request error" +) + +// Request headers. +const ( + P2PHeaderKey = "X-MS-Peerd-RequestFromPeer" + CorrelationHeaderKey = "X-MS-Peerd-CorrelationId" + NodeHeaderKey = "X-MS-Peerd-Node" +) diff --git a/pkg/discovery/content/http.go b/pkg/discovery/content/http.go new file mode 100644 index 0000000..25d79b1 --- /dev/null +++ b/pkg/discovery/content/http.go @@ -0,0 +1,14 @@ +package content + +import ( + "net/http" + + "github.com/azure/peerd/internal/context" +) + +// SetOutboundHeaders sets the mandatory headers for all outbound requests. +func SetOutboundHeaders(r *http.Request, correlationId string) { + r.Header.Set(P2PHeaderKey, "true") + r.Header.Set(CorrelationHeaderKey, correlationId) + r.Header.Set(NodeHeaderKey, context.NodeName) +} diff --git a/pkg/discovery/provider.go b/pkg/discovery/content/provider/provider.go similarity index 84% rename from pkg/discovery/provider.go rename to pkg/discovery/content/provider/provider.go index 54fb12b..ac10a0f 100644 --- a/pkg/discovery/provider.go +++ b/pkg/discovery/content/provider/provider.go @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -package discovery +package provider import ( "context" "errors" "fmt" + "sync" "time" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/pkg/containerd" "github.com/azure/peerd/pkg/discovery/routing" "github.com/rs/zerolog" @@ -41,7 +41,7 @@ func Provide(ctx context.Context, r routing.Router, containerdStore containerd.S expirationTicker := time.NewTicker(routing.MaxRecordAge - time.Minute) defer expirationTicker.Stop() - ticker := p2pcontext.Merge(immediate, expirationTicker.C) + ticker := merge(immediate, expirationTicker.C) for { select { @@ -124,3 +124,29 @@ func provideRef(ctx context.Context, l zerolog.Logger, containerdStore container return len(keys), nil } + +// Merge merges multiple input channels into a single output channel. +// It starts a goroutine for each input channel and sends the values from each input channel to the output channel. +// Once all input channels are closed, it closes the output channel. +// The function returns the output channel. +func merge[T any](cs ...<-chan T) <-chan T { + var wg sync.WaitGroup + out := make(chan T) + + output := func(c <-chan T) { + for n := range c { + out <- n + } + wg.Done() + } + wg.Add(len(cs)) + for _, c := range cs { + go output(c) + } + + go func() { + wg.Wait() + close(out) + }() + return out +} diff --git a/pkg/discovery/provider_test.go b/pkg/discovery/content/provider/provider_test.go similarity index 64% rename from pkg/discovery/provider_test.go rename to pkg/discovery/content/provider/provider_test.go index 820c035..9d4c7c4 100644 --- a/pkg/discovery/provider_test.go +++ b/pkg/discovery/content/provider/provider_test.go @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -package discovery +package provider import ( "context" + "fmt" + "strings" "testing" "time" @@ -53,3 +55,59 @@ func TestContainerdStoreAds(t *testing.T) { } } } + +func TestMerge(t *testing.T) { + + ch1 := make(chan string, 10) + ch2 := make(chan string) + ch3 := make(chan string, 100) + ch4 := make(chan string, 1000) + + mergedChan := merge(ch1, ch2, ch3, ch4) + + // Write to the channels. + go func() { + for i := 0; i < 100; i++ { + ch1 <- fmt.Sprintf("ch1-%d", i) + } + close(ch1) + }() + + go func() { + for i := 0; i < 100; i++ { + ch2 <- fmt.Sprintf("ch2-%d", i) + } + close(ch2) + }() + + go func() { + for i := 0; i < 100; i++ { + ch3 <- fmt.Sprintf("ch3-%d", i) + } + close(ch3) + }() + + go func() { + for i := 0; i < 100; i++ { + ch4 <- fmt.Sprintf("ch4-%d", i) + } + close(ch4) + }() + + // Read from the merged channel. + total := 0 + for val := range mergedChan { + if strings.HasPrefix(val, "ch1-") || + strings.HasPrefix(val, "ch2-") || + strings.HasPrefix(val, "ch3-") || + strings.HasPrefix(val, "ch4-") { + total++ + } else { + t.Errorf("unexpected value: %v", val) + } + } + + if total != 400 { + t.Errorf("expected: %v, got: %v", 400, total) + } +} diff --git a/internal/remote/interface.go b/pkg/discovery/content/reader/interface.go similarity index 97% rename from internal/remote/interface.go rename to pkg/discovery/content/reader/interface.go index 66004b4..2c7795a 100644 --- a/internal/remote/interface.go +++ b/pkg/discovery/content/reader/interface.go @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -package remote +package reader import ( "net/http" diff --git a/internal/remote/tests/mockreader.go b/pkg/mocks/reader.go similarity index 82% rename from internal/remote/tests/mockreader.go rename to pkg/mocks/reader.go index 4183aa7..24b97a0 100644 --- a/internal/remote/tests/mockreader.go +++ b/pkg/mocks/reader.go @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -package tests +package mocks import ( - "github.com/azure/peerd/internal/remote" + "github.com/azure/peerd/pkg/discovery/content/reader" "github.com/rs/zerolog" ) @@ -13,7 +13,7 @@ type mockReader struct { data []byte } -var _ remote.Reader = &mockReader{} +var _ reader.Reader = &mockReader{} // FstatRemote implements remote.Reader. func (m *mockReader) FstatRemote() (int64, error) { @@ -34,6 +34,6 @@ func (m *mockReader) PreadRemote(buf []byte, offset int64) (int, error) { } // NewMockReader creates a new mock reader for testing purposes. -func NewMockReader(data []byte) remote.Reader { +func NewMockReader(data []byte) reader.Reader { return &mockReader{data: data} } diff --git a/pkg/discovery/routing/tests/mock.go b/pkg/mocks/router.go similarity index 93% rename from pkg/discovery/routing/tests/mock.go rename to pkg/mocks/router.go index 5531e50..f72fa1e 100644 --- a/pkg/discovery/routing/tests/mock.go +++ b/pkg/mocks/router.go @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -package tests +package mocks import ( "context" "sync" "github.com/azure/peerd/pkg/discovery/routing" - "github.com/azure/peerd/pkg/mocks" "github.com/azure/peerd/pkg/peernet" "github.com/libp2p/go-libp2p/core/peer" ) @@ -38,7 +37,7 @@ func (m *MockRouter) ResolveWithNegativeCacheCallback(ctx context.Context, key s var _ routing.Router = &MockRouter{} func NewMockRouter(resolver map[string][]string) *MockRouter { - n, err := peernet.New(&mocks.MockHost{PeerStore: &mocks.MockPeerstore{}}) + n, err := peernet.New(&MockHost{PeerStore: &MockPeerstore{}}) if err != nil { panic(err) } From b67c616e9ba3a54966d57a2b1899350f6023ed9f Mon Sep 17 00:00:00 2001 From: Aviral Takkar Date: Mon, 15 Apr 2024 14:40:12 -0700 Subject: [PATCH 2/3] feat: big refactor --- Makefile | 2 +- cmd/proxy/main.go | 8 +- internal/files/files.go | 4 +- internal/files/files_test.go | 4 +- internal/files/store/file.go | 4 +- internal/files/store/file_test.go | 22 +- internal/files/store/interface.go | 6 +- internal/files/store/store.go | 27 +- internal/files/store/store_test.go | 24 +- internal/handlers/files/handler.go | 23 +- internal/handlers/files/handler_test.go | 42 ++- internal/handlers/root.go | 15 +- internal/handlers/root_test.go | 4 +- internal/handlers/v2/handler.go | 23 +- internal/handlers/v2/handler_test.go | 42 +-- internal/handlers/v2/mirror.go | 19 +- internal/handlers/v2/mirror_test.go | 10 +- internal/handlers/v2/registry.go | 19 +- internal/handlers/v2/registry_test.go | 20 +- pkg/{ => containerd}/mocks/contentstore.go | 0 pkg/{ => containerd}/mocks/eventservice.go | 0 pkg/{ => containerd}/mocks/imagestore.go | 0 pkg/containerd/store_test.go | 2 +- {internal => pkg}/context/context.go | 53 ++- {internal => pkg}/context/context_test.go | 36 +- pkg/discovery/content/consts.go | 17 - pkg/discovery/content/http.go | 14 - .../content/provider/provider_test.go | 4 +- .../content/reader}/mocks/reader.go | 0 pkg/discovery/content/reader/reader.go | 288 ++++++++++++++++ pkg/discovery/content/reader/reader_test.go | 323 ++++++++++++++++++ pkg/{ => discovery/routing}/mocks/router.go | 3 +- pkg/{ => peernet}/mocks/host.go | 0 pkg/{ => peernet}/mocks/peerstore.go | 0 pkg/peernet/network_test.go | 2 +- tests/cmd/main.go | 4 +- 36 files changed, 842 insertions(+), 222 deletions(-) rename pkg/{ => containerd}/mocks/contentstore.go (100%) rename pkg/{ => containerd}/mocks/eventservice.go (100%) rename pkg/{ => containerd}/mocks/imagestore.go (100%) rename {internal => pkg}/context/context.go (77%) rename {internal => pkg}/context/context_test.go (91%) delete mode 100644 pkg/discovery/content/consts.go delete mode 100644 pkg/discovery/content/http.go rename pkg/{ => discovery/content/reader}/mocks/reader.go (100%) create mode 100644 pkg/discovery/content/reader/reader.go create mode 100644 pkg/discovery/content/reader/reader_test.go rename pkg/{ => discovery/routing}/mocks/router.go (94%) rename pkg/{ => peernet}/mocks/host.go (100%) rename pkg/{ => peernet}/mocks/peerstore.go (100%) diff --git a/Makefile b/Makefile index df7aaff..9d4a6ff 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ GOLINT = golangci-lint run # Source repository variables. ROOT_DIR := $(shell git rev-parse --show-toplevel) BIN_DIR = $(ROOT_DIR)/bin -TEST_PKGS = $(shell go list ./... | grep -v 'github.com/azure/peerd/api\|github.com/azure/peerd/pkg/mocks') # Exclude generated and mock code. +TEST_PKGS = $(shell go list ./... | grep -v 'github.com/azure/peerd/api\|github.com/azure/peerd/pkg/discovery/routing/mocks') # Exclude generated and mock code. TESTS_BIN_DIR = $(BIN_DIR)/tests COVERAGE_DIR=$(BIN_DIR)/coverage SCRIPTS_DIR=$(ROOT_DIR)/build/ci/scripts diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 9a230b9..8de1a73 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -15,10 +15,10 @@ import ( "time" "github.com/alexflint/go-arg" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/internal/handlers" "github.com/azure/peerd/pkg/containerd" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/content/provider" "github.com/azure/peerd/pkg/discovery/routing" "github.com/azure/peerd/pkg/k8s" @@ -43,10 +43,10 @@ func main() { zerolog.SetGlobalLevel(ll) zerolog.TimeFieldFormat = time.RFC3339Nano - l := zerolog.New(os.Stdout).With().Timestamp().Str("self", p2pcontext.NodeName).Str("version", version).Logger() + l := zerolog.New(os.Stdout).With().Timestamp().Str("self", pcontext.NodeName).Str("version", version).Logger() ctx := l.WithContext(context.Background()) - ctx, err = metrics.WithContext(ctx, p2pcontext.NodeName, "peerd") + ctx, err = metrics.WithContext(ctx, pcontext.NodeName, "peerd") if err != nil { l.Error().Err(err).Msg("failed to initialize metrics") os.Exit(1) @@ -86,7 +86,7 @@ func serverCommand(ctx context.Context, args *ServerCmd) (err error) { return err } - clientset, err := k8s.NewKubernetesInterface(p2pcontext.KubeConfigPath, p2pcontext.NodeName) + clientset, err := k8s.NewKubernetesInterface(pcontext.KubeConfigPath, pcontext.NodeName) if err != nil { return err } diff --git a/internal/files/files.go b/internal/files/files.go index 65d22c5..a1ed00f 100644 --- a/internal/files/files.go +++ b/internal/files/files.go @@ -6,7 +6,7 @@ import ( "fmt" "io" - "github.com/azure/peerd/internal/remote" + "github.com/azure/peerd/pkg/discovery/content/reader" "github.com/azure/peerd/pkg/math" ) @@ -23,7 +23,7 @@ func FileChunkKey(name string, offset, cacheBlockSize int64) string { } // Fetchfile gets the content of a file from the given offset using a remote reader. -func FetchFile(r remote.Reader, name string, offset int64, count int) ([]byte, error) { +func FetchFile(r reader.Reader, name string, offset int64, count int) ([]byte, error) { d := make([]byte, count) l := r.Log().With().Str("name", name).Int64("offset", offset).Int("count", count).Logger() l.Debug().Msg("fetch file start") diff --git a/internal/files/files_test.go b/internal/files/files_test.go index f405535..119257c 100644 --- a/internal/files/files_test.go +++ b/internal/files/files_test.go @@ -8,7 +8,7 @@ import ( "strconv" "testing" - "github.com/azure/peerd/internal/remote" + "github.com/azure/peerd/pkg/discovery/content/reader" "github.com/rs/zerolog" ) @@ -83,4 +83,4 @@ func (m *mockReader) PreadRemote(buf []byte, offset int64) (int, error) { } } -var _ remote.Reader = &mockReader{} +var _ reader.Reader = &mockReader{} diff --git a/internal/files/store/file.go b/internal/files/store/file.go index 355622a..b1e9162 100644 --- a/internal/files/store/file.go +++ b/internal/files/store/file.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/azure/peerd/internal/files" - "github.com/azure/peerd/internal/remote" + "github.com/azure/peerd/pkg/discovery/content/reader" "github.com/azure/peerd/pkg/math" ) @@ -27,7 +27,7 @@ type file struct { chunkOffset int64 - reader remote.Reader + reader reader.Reader store *store } diff --git a/internal/files/store/file_test.go b/internal/files/store/file_test.go index 0c377f3..4c6053f 100644 --- a/internal/files/store/file_test.go +++ b/internal/files/store/file_test.go @@ -10,9 +10,9 @@ import ( "testing" "github.com/azure/peerd/internal/files" - remotetests "github.com/azure/peerd/internal/remote/tests" "github.com/azure/peerd/pkg/cache" - "github.com/azure/peerd/pkg/discovery/routing/tests" + readermocks "github.com/azure/peerd/pkg/discovery/content/reader/mocks" + "github.com/azure/peerd/pkg/discovery/routing/mocks" ) func TestReadAtWithChunkOffset(t *testing.T) { @@ -20,14 +20,14 @@ func TestReadAtWithChunkOffset(t *testing.T) { files.CacheBlockSize = 1 // 1 byte - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } fWithChunkOffset := &file{ Name: "test", - reader: remotetests.NewMockReader(data), + reader: readermocks.NewMockReader(data), store: s.(*store), chunkOffset: 4, } @@ -78,14 +78,14 @@ func TestReadAt(t *testing.T) { files.CacheBlockSize = 1 // 1 byte - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } f := &file{ Name: "test", - reader: remotetests.NewMockReader(data), + reader: readermocks.NewMockReader(data), store: s.(*store), } size, err := f.Fstat() @@ -128,14 +128,14 @@ func TestReadAt(t *testing.T) { func TestSeek(t *testing.T) { data := []byte("hello world") - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } f := &file{ Name: "test", - reader: remotetests.NewMockReader(data), + reader: readermocks.NewMockReader(data), store: s.(*store), } size, err := f.Fstat() @@ -192,14 +192,14 @@ func TestFstat(t *testing.T) { t.Fatal(err) } - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } f := &file{ Name: "test", - reader: remotetests.NewMockReader(data), + reader: readermocks.NewMockReader(data), store: s.(*store), } @@ -212,7 +212,7 @@ func TestFstat(t *testing.T) { f = &file{ Name: "test2", - reader: remotetests.NewMockReader(data), + reader: readermocks.NewMockReader(data), store: s.(*store), chunkOffset: 14, } diff --git a/internal/files/store/interface.go b/internal/files/store/interface.go index 62d67ef..83712bd 100644 --- a/internal/files/store/interface.go +++ b/internal/files/store/interface.go @@ -5,17 +5,17 @@ package store import ( "time" - "github.com/gin-gonic/gin" + "github.com/azure/peerd/pkg/context" "github.com/opencontainers/go-digest" ) // FilesStore describes a store for files. type FilesStore interface { // Key tries to find the cache key for the requested content or returns empty. - Key(c *gin.Context) (key string, d digest.Digest, err error) + Key(c context.Context) (key string, d digest.Digest, err error) // Open opens the requested file and starts prefetching it. It also returns the size of the file. - Open(c *gin.Context) (File, error) + Open(c context.Context) (File, error) // Subscribe returns a channel that will be notified when a blob is added to the store. Subscribe() chan string diff --git a/internal/files/store/store.go b/internal/files/store/store.go index 53e60b2..420df9d 100644 --- a/internal/files/store/store.go +++ b/internal/files/store/store.go @@ -9,14 +9,13 @@ import ( "strings" "time" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/internal/files" - "github.com/azure/peerd/internal/remote" "github.com/azure/peerd/pkg/cache" + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/content/reader" "github.com/azure/peerd/pkg/discovery/routing" "github.com/azure/peerd/pkg/metrics" "github.com/azure/peerd/pkg/urlparser" - "github.com/gin-gonic/gin" "github.com/opencontainers/go-digest" "github.com/rs/zerolog" ) @@ -58,7 +57,7 @@ type prefetchableSegment struct { offset int64 count int - reader remote.Reader + reader reader.Reader } // store describes a content store whose contents can come from disk or a remote source. @@ -82,15 +81,15 @@ func (s *store) Subscribe() chan string { } // Open opens the requested file and starts prefetching it. -func (s *store) Open(c *gin.Context) (File, error) { +func (s *store) Open(c pcontext.Context) (File, error) { - chunkKey := c.GetString(p2pcontext.FileChunkCtxKey) + chunkKey := c.GetString(pcontext.FileChunkCtxKey) tokens := strings.Split(chunkKey, files.FileChunkKeySep) name := tokens[0] alignedOff, _ := strconv.ParseInt(tokens[1], 10, 64) - log := p2pcontext.Logger(c) - if p2pcontext.IsRequestFromAPeer(c) { + log := pcontext.Logger(c) + if pcontext.IsRequestFromAPeer(c) { // This request came from a peer. Don't serve it unless we have the requested range cached. if ok := s.cache.Exists(name, alignedOff); !ok { log.Info().Str("name", name).Msg("peer request not cached") @@ -103,10 +102,10 @@ func (s *store) Open(c *gin.Context) (File, error) { store: s, cur: 0, size: 0, - reader: remote.NewReader(c, s.router, s.resolveRetries, s.resolveTimeout, s.metricsRecorder), + reader: reader.NewReader(c, s.router, s.resolveRetries, s.resolveTimeout, s.metricsRecorder), } - if p2pcontext.IsRequestFromAPeer(c) { + if pcontext.IsRequestFromAPeer(c) { // Ensure this file can only serve the requested chunk. // This is to prevent infinite loops when a peer requests a file that is not cached. f.chunkOffset = alignedOff @@ -122,10 +121,10 @@ func (s *store) Open(c *gin.Context) (File, error) { } // Key tries to find the cache key for the requested content or returns empty. -func (s *store) Key(c *gin.Context) (string, digest.Digest, error) { - log := p2pcontext.Logger(c) +func (s *store) Key(c pcontext.Context) (string, digest.Digest, error) { + log := pcontext.Logger(c) - blobUrl := p2pcontext.BlobUrl(c) + blobUrl := pcontext.BlobUrl(c) d, err := s.parser.ParseDigest(blobUrl) if err != nil { log.Error().Err(err).Msg("store key") @@ -133,7 +132,7 @@ func (s *store) Key(c *gin.Context) (string, digest.Digest, error) { startIndex := int64(0) // Default to 0 for HEADs. if c.Request.Method == "GET" { - startIndex, err = p2pcontext.RangeStartIndex(c.Request.Header.Get("Range")) + startIndex, err = pcontext.RangeStartIndex(c.Request.Header.Get("Range")) if err != nil { return "", "", err } diff --git a/internal/files/store/store_test.go b/internal/files/store/store_test.go index 8fc7694..8ea0a57 100644 --- a/internal/files/store/store_test.go +++ b/internal/files/store/store_test.go @@ -9,9 +9,9 @@ import ( "os" "testing" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/internal/files" - "github.com/azure/peerd/pkg/discovery/routing/tests" + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/routing/mocks" "github.com/gin-gonic/gin" "github.com/opencontainers/go-digest" ) @@ -29,7 +29,7 @@ func TestOpenP2p(t *testing.T) { t.Fatal(err) } req.Header.Set("Range", fmt.Sprintf("bytes=%v-%v", files.CacheBlockSize, files.CacheBlockSize+172)) - req.Header.Set(p2pcontext.P2PHeaderKey, "true") + req.Header.Set(pcontext.P2PHeaderKey, "true") expD := "sha256:d18c7a64c5158179bdee531a663c5b487de57ff17cff3af29a51c7e70b491d9d" expK := fmt.Sprintf("%v%v%v", expD, files.FileChunkKeySep, files.CacheBlockSize) @@ -40,15 +40,15 @@ func TestOpenP2p(t *testing.T) { ctx.Params = []gin.Param{ {Key: "url", Value: hostAndPath}, } - ctx.Set(p2pcontext.FileChunkCtxKey, expK) + ctx.Set(pcontext.FileChunkCtxKey, expK) PrefetchWorkers = 0 // turn off prefetching - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } - _, err = s.Open(ctx) + _, err = s.Open(pcontext.Context{Context: ctx}) if err != os.ErrNotExist { t.Errorf("expected %v, got %v", os.ErrNotExist, err) } @@ -71,17 +71,17 @@ func TestOpenNonP2p(t *testing.T) { ctx.Params = []gin.Param{ {Key: "url", Value: hostAndPath}, } - ctx.Set(p2pcontext.FileChunkCtxKey, expK) + ctx.Set(pcontext.FileChunkCtxKey, expK) PrefetchWorkers = 0 // turn off prefetching - s, err := NewMockStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewMockStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } s.Cache().PutSize(expD, 200) - _, err = s.Open(ctx) + _, err = s.Open(pcontext.Context{Context: ctx}) if err != nil { t.Fatal(err) } @@ -105,12 +105,12 @@ func TestKey(t *testing.T) { {Key: "url", Value: hostAndPath}, } - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } - k, d, err := s.Key(ctx) + k, d, err := s.Key(pcontext.Context{Context: ctx}) if err != nil { t.Fatal(err) } @@ -125,7 +125,7 @@ func TestKey(t *testing.T) { } func TestSubscribe(t *testing.T) { - s, err := NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } diff --git a/internal/handlers/files/handler.go b/internal/handlers/files/handler.go index 0d9eaf9..b3ce9be 100644 --- a/internal/handlers/files/handler.go +++ b/internal/handlers/files/handler.go @@ -8,10 +8,9 @@ import ( "os" "time" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/internal/files/store" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/metrics" - "github.com/gin-gonic/gin" ) // FilesHandler describes a handler for files. @@ -20,11 +19,9 @@ type FilesHandler struct { metricsRecorder metrics.Metrics } -var _ gin.HandlerFunc = (&FilesHandler{}).Handle - // Handle handles a request for a file. -func (h *FilesHandler) Handle(c *gin.Context) { - log := p2pcontext.Logger(c).With().Str("blob", p2pcontext.BlobUrl(c)).Bool("p2p", p2pcontext.IsRequestFromAPeer(c)).Logger() +func (h *FilesHandler) Handle(c pcontext.Context) { + log := pcontext.Logger(c).With().Str("blob", pcontext.BlobUrl(c)).Bool("p2p", pcontext.IsRequestFromAPeer(c)).Logger() log.Debug().Msg("files handler start") s := time.Now() defer func() { @@ -56,14 +53,14 @@ func (h *FilesHandler) Handle(c *gin.Context) { w.Header().Set("Content-Type", "application/octet-stream") w.Header().Del("Content-Length") - w.Header().Set(p2pcontext.NodeHeaderKey, p2pcontext.NodeName) - w.Header().Set(p2pcontext.CorrelationHeaderKey, c.GetString(p2pcontext.CorrelationIdCtxKey)) + w.Header().Set(pcontext.NodeHeaderKey, pcontext.NodeName) + w.Header().Set(pcontext.CorrelationHeaderKey, c.GetString(pcontext.CorrelationIdCtxKey)) http.ServeContent(w, c.Request, "file", time.Now(), f) } // fill fills the context with handler specific information. -func (h *FilesHandler) fill(c *gin.Context) error { +func (h *FilesHandler) fill(c pcontext.Context) error { c.Set("handler", "files") key, d, err := h.store.Key(c) @@ -71,10 +68,10 @@ func (h *FilesHandler) fill(c *gin.Context) error { return err } - c.Set(p2pcontext.DigestCtxKey, d.String()) - c.Set(p2pcontext.FileChunkCtxKey, key) - c.Set(p2pcontext.BlobUrlCtxKey, p2pcontext.BlobUrl(c)) - c.Set(p2pcontext.BlobRangeCtxKey, c.Request.Header.Get("Range")) + c.Set(pcontext.DigestCtxKey, d.String()) + c.Set(pcontext.FileChunkCtxKey, key) + c.Set(pcontext.BlobUrlCtxKey, pcontext.BlobUrl(c)) + c.Set(pcontext.BlobRangeCtxKey, c.Request.Header.Get("Range")) return nil } diff --git a/internal/handlers/files/handler_test.go b/internal/handlers/files/handler_test.go index 9d5924b..66b98c7 100644 --- a/internal/handlers/files/handler_test.go +++ b/internal/handlers/files/handler_test.go @@ -10,10 +10,10 @@ import ( "net/http/httptest" "testing" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/internal/files" "github.com/azure/peerd/internal/files/store" - "github.com/azure/peerd/pkg/discovery/routing/tests" + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/routing/mocks" "github.com/azure/peerd/pkg/metrics" "github.com/gin-gonic/gin" ) @@ -34,7 +34,7 @@ func TestPartialContentResponseInP2PMode(t *testing.T) { } expRange := fmt.Sprintf("bytes=%v-%v", 12, 100) req.Header.Set("Range", expRange) - req.Header.Set(p2pcontext.P2PHeaderKey, "true") + req.Header.Set(pcontext.P2PHeaderKey, "true") expD := "sha256:d18c7a64c5158179bdee531a663c5b487de57ff17cff3af29a51c7e70b491d9d" @@ -47,7 +47,7 @@ func TestPartialContentResponseInP2PMode(t *testing.T) { } store.PrefetchWorkers = 0 // turn off prefetching - s, err := store.NewMockStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := store.NewMockStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } @@ -63,11 +63,13 @@ func TestPartialContentResponseInP2PMode(t *testing.T) { return []byte(content), nil }) - h.Handle(ctx) + pctx := pcontext.FromContext(ctx) + + h.Handle(pctx) resp := recorder.Result() if resp.StatusCode != http.StatusPartialContent { - t.Errorf("expected %v, got %v", http.StatusOK, ctx.Writer.Status()) + t.Errorf("expected %v, got %v", http.StatusOK, pctx.Writer.Status()) } ret, err := io.ReadAll(resp.Body) @@ -87,7 +89,7 @@ func TestNotFoundInP2PMode(t *testing.T) { } expRange := fmt.Sprintf("bytes=%v-%v", files.CacheBlockSize, files.CacheBlockSize+172) req.Header.Set("Range", expRange) - req.Header.Set(p2pcontext.P2PHeaderKey, "true") + req.Header.Set(pcontext.P2PHeaderKey, "true") // Create a new context with the request. ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) @@ -97,14 +99,16 @@ func TestNotFoundInP2PMode(t *testing.T) { } store.PrefetchWorkers = 0 // turn off prefetching - s, err := store.NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := store.NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } h := New(ctxWithMetrics, s) - h.Handle(ctx) + pmc := pcontext.FromContext(ctx) + + h.Handle(pmc) if ctx.Writer.Status() != http.StatusNotFound { t.Errorf("expected %v, got %v", http.StatusNotFound, ctx.Writer.Status()) } @@ -118,7 +122,7 @@ func TestFill(t *testing.T) { } expRange := fmt.Sprintf("bytes=%v-%v", files.CacheBlockSize, files.CacheBlockSize+172) req.Header.Set("Range", expRange) - req.Header.Set(p2pcontext.P2PHeaderKey, "true") + req.Header.Set(pcontext.P2PHeaderKey, "true") expD := "sha256:d18c7a64c5158179bdee531a663c5b487de57ff17cff3af29a51c7e70b491d9d" expK := fmt.Sprintf("%v_%v", expD, files.CacheBlockSize) @@ -131,26 +135,28 @@ func TestFill(t *testing.T) { } store.PrefetchWorkers = 0 // turn off prefetching - s, err := store.NewFilesStore(ctxWithMetrics, tests.NewMockRouter(make(map[string][]string))) + s, err := store.NewFilesStore(ctxWithMetrics, mocks.NewMockRouter(make(map[string][]string))) if err != nil { t.Fatal(err) } h := New(ctxWithMetrics, s) - err = h.fill(ctx) + pmc := pcontext.FromContext(ctx) + + err = h.fill(pmc) if err != nil { t.Fatal(err) } - if ctx.GetString(p2pcontext.FileChunkCtxKey) != expK { - t.Errorf("expected %v, got %v", expK, ctx.GetString(p2pcontext.FileChunkCtxKey)) + if ctx.GetString(pcontext.FileChunkCtxKey) != expK { + t.Errorf("expected %v, got %v", expK, ctx.GetString(pcontext.FileChunkCtxKey)) } - if ctx.GetString(p2pcontext.BlobRangeCtxKey) != expRange { - t.Errorf("expected %v, got %v", expRange, ctx.GetString(p2pcontext.BlobRangeCtxKey)) + if ctx.GetString(pcontext.BlobRangeCtxKey) != expRange { + t.Errorf("expected %v, got %v", expRange, ctx.GetString(pcontext.BlobRangeCtxKey)) } - if ctx.GetString(p2pcontext.BlobUrlCtxKey) != hostAndPath+query { - t.Errorf("expected %v, got %v", hostAndPath+query, ctx.GetString(p2pcontext.BlobUrlCtxKey)) + if ctx.GetString(pcontext.BlobUrlCtxKey) != hostAndPath+query { + t.Errorf("expected %v, got %v", hostAndPath+query, ctx.GetString(pcontext.BlobUrlCtxKey)) } } diff --git a/internal/handlers/root.go b/internal/handlers/root.go index d07a380..ea72475 100644 --- a/internal/handlers/root.go +++ b/internal/handlers/root.go @@ -7,11 +7,11 @@ import ( "net/http" "time" - p2pcontext "github.com/azure/peerd/internal/context" filesStore "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/internal/handlers/files" v2 "github.com/azure/peerd/internal/handlers/v2" "github.com/azure/peerd/pkg/containerd" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/routing" "github.com/gin-gonic/gin" "github.com/rs/zerolog" @@ -44,10 +44,13 @@ func newEngine(ctx context.Context) *gin.Engine { baseLog := zerolog.Ctx(ctx) engine.Use(func(c *gin.Context) { - p2pcontext.FillCorrelationId(c) - c.Set(p2pcontext.LoggerCtxKey, baseLog) - l := p2pcontext.Logger(c) + pc := pcontext.FromContext(c) + + pcontext.FillCorrelationId(pc) + c.Set(pcontext.LoggerCtxKey, baseLog) + + l := pcontext.Logger(pc) l.Debug().Msg("request start") s := time.Now() @@ -94,7 +97,7 @@ func registerRoutes(engine *gin.Engine, f, v gin.HandlerFunc) { // @Failure 404 {string} string "Not Found" // @Router /blobs/{url} [get] func fileHandler(c *gin.Context) { - fh.Handle(c) + fh.Handle(pcontext.FromContext(c)) } // v2Handler is a handler function for the /v2 API @@ -107,5 +110,5 @@ func fileHandler(c *gin.Context) { // @Router /v2/{repo}/manifests/{reference} [get] // @Router /v2/{repo}/blobs/{digest} [get] func v2Handler(c *gin.Context) { - v2h.Handle(c) + v2h.Handle(pcontext.FromContext(c)) } diff --git a/internal/handlers/root_test.go b/internal/handlers/root_test.go index 1665fe5..1ec7b71 100644 --- a/internal/handlers/root_test.go +++ b/internal/handlers/root_test.go @@ -8,7 +8,7 @@ import ( "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/pkg/containerd" - "github.com/azure/peerd/pkg/discovery/routing/tests" + "github.com/azure/peerd/pkg/discovery/routing/mocks" "github.com/azure/peerd/pkg/metrics" "github.com/gin-gonic/gin" ) @@ -102,7 +102,7 @@ func TestNewEngine(t *testing.T) { } func TestHandler(t *testing.T) { - mr := tests.NewMockRouter(map[string][]string{}) + mr := mocks.NewMockRouter(map[string][]string{}) ms := containerd.NewMockContainerdStore(nil) mfs, err := store.NewMockStore(ctxWithMetrics, mr) if err != nil { diff --git a/internal/handlers/v2/handler.go b/internal/handlers/v2/handler.go index 8fd1521..988b866 100644 --- a/internal/handlers/v2/handler.go +++ b/internal/handlers/v2/handler.go @@ -8,12 +8,11 @@ import ( "path" "time" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/pkg/containerd" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/routing" "github.com/azure/peerd/pkg/metrics" "github.com/azure/peerd/pkg/oci/distribution" - "github.com/gin-gonic/gin" ) // V2Handler describes a handler for OCI content. @@ -23,17 +22,15 @@ type V2Handler struct { metricsRecorder metrics.Metrics } -var _ gin.HandlerFunc = (&V2Handler{}).Handle - // Handle handles a request for a file. -func (h *V2Handler) Handle(c *gin.Context) { - l := p2pcontext.Logger(c).With().Bool("p2p", p2pcontext.IsRequestFromAPeer(c)).Logger() +func (h *V2Handler) Handle(c pcontext.Context) { + l := pcontext.Logger(c).With().Bool("p2p", pcontext.IsRequestFromAPeer(c)).Logger() l.Debug().Msg("v2 handler start") s := time.Now() defer func() { dur := time.Since(s) h.metricsRecorder.RecordRequest(c.Request.Method, "oci", dur.Seconds()) - l.Debug().Dur("duration", dur).Str("ns", c.GetString(p2pcontext.NamespaceCtxKey)).Str("ref", c.GetString(p2pcontext.ReferenceCtxKey)).Str("digest", c.GetString(p2pcontext.DigestCtxKey)).Msg("v2 handler stop") + l.Debug().Dur("duration", dur).Str("ns", c.GetString(pcontext.NamespaceCtxKey)).Str("ref", c.GetString(pcontext.ReferenceCtxKey)).Str("digest", c.GetString(pcontext.DigestCtxKey)).Msg("v2 handler stop") }() p := path.Clean(c.Request.URL.Path) @@ -54,7 +51,7 @@ func (h *V2Handler) Handle(c *gin.Context) { return } - if p2pcontext.IsRequestFromAPeer(c) { + if pcontext.IsRequestFromAPeer(c) { h.registry.Handle(c) return } else { @@ -64,7 +61,7 @@ func (h *V2Handler) Handle(c *gin.Context) { } // fill fills the context with handler specific information. -func (h *V2Handler) fill(c *gin.Context) error { +func (h *V2Handler) fill(c pcontext.Context) error { c.Set("handler", "v2") ns := c.Query("ns") @@ -72,16 +69,16 @@ func (h *V2Handler) fill(c *gin.Context) error { ns = "docker.io" } - c.Set(p2pcontext.NamespaceCtxKey, ns) + c.Set(pcontext.NamespaceCtxKey, ns) ref, dgst, refType, err := distribution.ParsePathComponents(ns, c.Request.URL.Path) if err != nil { return err } - c.Set(p2pcontext.ReferenceCtxKey, ref) - c.Set(p2pcontext.DigestCtxKey, dgst.String()) - c.Set(p2pcontext.RefTypeCtxKey, refType) + c.Set(pcontext.ReferenceCtxKey, ref) + c.Set(pcontext.DigestCtxKey, dgst.String()) + c.Set(pcontext.RefTypeCtxKey, refType) return nil } diff --git a/internal/handlers/v2/handler_test.go b/internal/handlers/v2/handler_test.go index f8df5e4..55309e1 100644 --- a/internal/handlers/v2/handler_test.go +++ b/internal/handlers/v2/handler_test.go @@ -6,9 +6,9 @@ import ( "net/http/httptest" "testing" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/pkg/containerd" - "github.com/azure/peerd/pkg/discovery/routing/tests" + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/routing/mocks" "github.com/azure/peerd/pkg/metrics" "github.com/azure/peerd/pkg/oci/distribution" "github.com/gin-gonic/gin" @@ -19,7 +19,7 @@ var ( ) func TestNew(t *testing.T) { - mr := tests.NewMockRouter(nil) + mr := mocks.NewMockRouter(nil) ms := containerd.NewMockContainerdStore(nil) h, err := New(ctxWithMetrics, mr, ms) @@ -33,7 +33,7 @@ func TestNew(t *testing.T) { } func TestFillDefault(t *testing.T) { - mr := tests.NewMockRouter(nil) + mr := mocks.NewMockRouter(nil) ms := containerd.NewMockContainerdStore(nil) h, err := New(ctxWithMetrics, mr, ms) @@ -50,25 +50,27 @@ func TestFillDefault(t *testing.T) { } mc.Request = req - err = h.fill(mc) + pmc := pcontext.FromContext(mc) + + err = h.fill(pmc) if err != nil { t.Fatalf("unexpected error: %v", err) } - gotNs := mc.GetString(p2pcontext.NamespaceCtxKey) + gotNs := mc.GetString(pcontext.NamespaceCtxKey) if gotNs != "docker.io" { t.Fatalf("expected docker.io, got %s", gotNs) } - if mc.GetString(p2pcontext.ReferenceCtxKey) != "docker.io/library/alpine:3.18.0" { - t.Fatalf("expected library/alpine, got %s", mc.GetString(p2pcontext.ReferenceCtxKey)) + if mc.GetString(pcontext.ReferenceCtxKey) != "docker.io/library/alpine:3.18.0" { + t.Fatalf("expected library/alpine, got %s", mc.GetString(pcontext.ReferenceCtxKey)) } - if mc.GetString(p2pcontext.DigestCtxKey) != "" { - t.Fatalf("expected empty string, got %s", mc.GetString(p2pcontext.DigestCtxKey)) + if mc.GetString(pcontext.DigestCtxKey) != "" { + t.Fatalf("expected empty string, got %s", mc.GetString(pcontext.DigestCtxKey)) } - gotRefType, ok := mc.Get(p2pcontext.RefTypeCtxKey) + gotRefType, ok := mc.Get(pcontext.RefTypeCtxKey) if !ok { t.Fatalf("expected reference type, got nil") } @@ -84,24 +86,26 @@ func TestFillDefault(t *testing.T) { } mc2.Request = req2 - err = h.fill(mc2) + pmc2 := pcontext.FromContext(mc2) + + err = h.fill(pmc2) if err != nil { t.Fatalf("unexpected error: %v", err) } - if mc2.GetString(p2pcontext.NamespaceCtxKey) != "k8s.io" { - t.Fatalf("expected k8s.io, got %s", mc2.GetString(p2pcontext.NamespaceCtxKey)) + if mc2.GetString(pcontext.NamespaceCtxKey) != "k8s.io" { + t.Fatalf("expected k8s.io, got %s", mc2.GetString(pcontext.NamespaceCtxKey)) } - if mc2.GetString(p2pcontext.ReferenceCtxKey) != "" { - t.Fatalf("expected empty string, got %s", mc2.GetString(p2pcontext.ReferenceCtxKey)) + if mc2.GetString(pcontext.ReferenceCtxKey) != "" { + t.Fatalf("expected empty string, got %s", mc2.GetString(pcontext.ReferenceCtxKey)) } - if mc2.GetString(p2pcontext.DigestCtxKey) != "sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30" { - t.Fatalf("expected sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30, got %s", mc2.GetString(p2pcontext.DigestCtxKey)) + if mc2.GetString(pcontext.DigestCtxKey) != "sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30" { + t.Fatalf("expected sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30, got %s", mc2.GetString(pcontext.DigestCtxKey)) } - gotRefType, ok = mc2.Get(p2pcontext.RefTypeCtxKey) + gotRefType, ok = mc2.Get(pcontext.RefTypeCtxKey) if !ok { t.Fatalf("expected reference type, got nil") } diff --git a/internal/handlers/v2/mirror.go b/internal/handlers/v2/mirror.go index 3fed11d..0699b3c 100644 --- a/internal/handlers/v2/mirror.go +++ b/internal/handlers/v2/mirror.go @@ -11,10 +11,9 @@ import ( "net/url" "time" - p2pcontext "github.com/azure/peerd/internal/context" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/routing" "github.com/azure/peerd/pkg/peernet" - "github.com/gin-gonic/gin" ) var ( @@ -34,16 +33,14 @@ type Mirror struct { n peernet.Network } -var _ gin.HandlerFunc = (&Mirror{}).Handle - // Handle handles a request to this registry mirror. -func (m *Mirror) Handle(c *gin.Context) { - key := c.GetString(p2pcontext.DigestCtxKey) +func (m *Mirror) Handle(c pcontext.Context) { + key := c.GetString(pcontext.DigestCtxKey) if key == "" { - key = c.GetString(p2pcontext.ReferenceCtxKey) + key = c.GetString(pcontext.ReferenceCtxKey) } - l := p2pcontext.Logger(c).With().Str("handler", "mirror").Str("ref", key).Logger() + l := pcontext.Logger(c).With().Str("handler", "mirror").Str("ref", key).Logger() l.Debug().Msg("mirror handler start") s := time.Now() defer func() { @@ -71,14 +68,14 @@ func (m *Mirror) Handle(c *gin.Context) { case <-resolveCtx.Done(): // Resolving mirror has timed out. //nolint - c.AbortWithError(http.StatusNotFound, fmt.Errorf(p2pcontext.PeerNotFoundLog)) + c.AbortWithError(http.StatusNotFound, fmt.Errorf(pcontext.PeerNotFoundLog)) return case peer, ok := <-peersChan: // Channel closed means no more mirrors will be received and max retries has been reached. if !ok { //nolint - c.AbortWithError(http.StatusInternalServerError, fmt.Errorf(p2pcontext.PeerResolutionExhaustedLog)) + c.AbortWithError(http.StatusInternalServerError, fmt.Errorf(pcontext.PeerResolutionExhaustedLog)) return } @@ -95,7 +92,7 @@ func (m *Mirror) Handle(c *gin.Context) { r.URL = u r.URL.Path = c.Request.URL.Path r.URL.RawQuery = c.Request.URL.RawQuery - p2pcontext.SetOutboundHeaders(r, c) + pcontext.SetOutboundHeaders(r, c) } proxy.ModifyResponse = func(resp *http.Response) error { if resp.StatusCode != http.StatusOK { diff --git a/internal/handlers/v2/mirror_test.go b/internal/handlers/v2/mirror_test.go index 3ff8feb..484d4d4 100644 --- a/internal/handlers/v2/mirror_test.go +++ b/internal/handlers/v2/mirror_test.go @@ -9,8 +9,8 @@ import ( "net/http/httptest" "testing" - p2pcontext "github.com/azure/peerd/internal/context" - "github.com/azure/peerd/pkg/discovery/routing/tests" + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/routing/mocks" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -62,7 +62,7 @@ func TestMirrorHandler(t *testing.T) { "first-peer-error": {"foo", goodSvr.URL}, "last-peer-working": {badSvr.URL, badSvr.URL, goodSvr.URL}, } - router := tests.NewMockRouter(resolver) + router := mocks.NewMockRouter(resolver) m := &Mirror{ router: router, resolveRetries: ResolveRetries, @@ -120,8 +120,8 @@ func TestMirrorHandler(t *testing.T) { c, _ := gin.CreateTestContext(rw) target := fmt.Sprintf("http://example.com/%s", tt.key) c.Request = httptest.NewRequest(method, target, nil) - c.Set(p2pcontext.DigestCtxKey, tt.key) - m.Handle(c) + c.Set(pcontext.DigestCtxKey, tt.key) + m.Handle(pcontext.FromContext(c)) resp := rw.Result() defer resp.Body.Close() diff --git a/internal/handlers/v2/registry.go b/internal/handlers/v2/registry.go index bba7dae..4a812c7 100644 --- a/internal/handlers/v2/registry.go +++ b/internal/handlers/v2/registry.go @@ -8,10 +8,9 @@ import ( "strconv" "time" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/pkg/containerd" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/oci/distribution" - "github.com/gin-gonic/gin" "github.com/opencontainers/go-digest" ) @@ -28,16 +27,14 @@ type Registry struct { containerdStore containerd.Store } -var _ gin.HandlerFunc = (&Registry{}).Handle - // Handle handles a request to this registry. -func (r *Registry) Handle(c *gin.Context) { - dgstStr := c.GetString(p2pcontext.DigestCtxKey) - ref := c.GetString(p2pcontext.ReferenceCtxKey) +func (r *Registry) Handle(c pcontext.Context) { + dgstStr := c.GetString(pcontext.DigestCtxKey) + ref := c.GetString(pcontext.ReferenceCtxKey) var d digest.Digest var err error - l := p2pcontext.Logger(c).With().Str("handler", "registry").Str("ref", ref).Str("digest", dgstStr).Logger() + l := pcontext.Logger(c).With().Str("handler", "registry").Str("ref", ref).Str("digest", dgstStr).Logger() l.Debug().Msg("registry handler start") s := time.Now() defer func() { @@ -61,7 +58,7 @@ func (r *Registry) Handle(c *gin.Context) { } } - refType, ok := c.Get(p2pcontext.RefTypeCtxKey) + refType, ok := c.Get(pcontext.RefTypeCtxKey) if !ok { //nolint c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("ref type not found in context")) @@ -83,7 +80,7 @@ func (r *Registry) Handle(c *gin.Context) { } // handleManifest handles a manifest request. -func (r *Registry) handleManifest(c *gin.Context, dgst digest.Digest) { +func (r *Registry) handleManifest(c pcontext.Context, dgst digest.Digest) { size, err := r.containerdStore.Size(c, dgst) if err != nil { //nolint @@ -118,7 +115,7 @@ func (r *Registry) handleManifest(c *gin.Context, dgst digest.Digest) { } // handleBlob handles a blob request. -func (r *Registry) handleBlob(c *gin.Context, dgst digest.Digest) { +func (r *Registry) handleBlob(c pcontext.Context, dgst digest.Digest) { size, err := r.containerdStore.Size(c, dgst) if err != nil { //nolint diff --git a/internal/handlers/v2/registry_test.go b/internal/handlers/v2/registry_test.go index 9f958d6..eacea3c 100644 --- a/internal/handlers/v2/registry_test.go +++ b/internal/handlers/v2/registry_test.go @@ -7,8 +7,8 @@ import ( "net/http/httptest" "testing" - p2pcontext "github.com/azure/peerd/internal/context" "github.com/azure/peerd/pkg/containerd" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/oci/distribution" "github.com/gin-gonic/gin" ) @@ -43,7 +43,9 @@ func TestHandleManifest(t *testing.T) { mc.Request = req - r.handleManifest(mc, "sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30") + pmc := pcontext.Context{Context: mc} + + r.handleManifest(pmc, "sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30") if mr.Code != 200 { t.Fatalf("expected 200, got %d", mr.Code) @@ -83,7 +85,9 @@ func TestHandleBlob(t *testing.T) { mc.Request = req - r.handleBlob(mc, "sha256:blob") + pmc := pcontext.Context{Context: mc} + + r.handleBlob(pmc, "sha256:blob") if mr.Code != 200 { t.Fatalf("expected 200, got %d", mr.Code) @@ -122,11 +126,13 @@ func TestHandle(t *testing.T) { } mc.Request = req - mc.Set(p2pcontext.DigestCtxKey, "sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30") - mc.Set(p2pcontext.ReferenceCtxKey, "library/alpine:3.18.0") - mc.Set(p2pcontext.RefTypeCtxKey, distribution.ReferenceType(distribution.ReferenceTypeManifest)) + mc.Set(pcontext.DigestCtxKey, "sha256:bb863d6b95453b6b10dfaa1a52cb53f453d9a97ee775808ebaf6533bb4c9bb30") + mc.Set(pcontext.ReferenceCtxKey, "library/alpine:3.18.0") + mc.Set(pcontext.RefTypeCtxKey, distribution.ReferenceType(distribution.ReferenceTypeManifest)) + + pmc := pcontext.Context{Context: mc} - r.Handle(mc) + r.Handle(pmc) if mr.Code != 200 { t.Fatalf("expected 200, got %d", mr.Code) diff --git a/pkg/mocks/contentstore.go b/pkg/containerd/mocks/contentstore.go similarity index 100% rename from pkg/mocks/contentstore.go rename to pkg/containerd/mocks/contentstore.go diff --git a/pkg/mocks/eventservice.go b/pkg/containerd/mocks/eventservice.go similarity index 100% rename from pkg/mocks/eventservice.go rename to pkg/containerd/mocks/eventservice.go diff --git a/pkg/mocks/imagestore.go b/pkg/containerd/mocks/imagestore.go similarity index 100% rename from pkg/mocks/imagestore.go rename to pkg/containerd/mocks/imagestore.go diff --git a/pkg/containerd/store_test.go b/pkg/containerd/store_test.go index 039fc52..4081db1 100644 --- a/pkg/containerd/store_test.go +++ b/pkg/containerd/store_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/azure/peerd/pkg/mocks" + "github.com/azure/peerd/pkg/containerd/mocks" "github.com/containerd/containerd" eventtypes "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/events" diff --git a/internal/context/context.go b/pkg/context/context.go similarity index 77% rename from internal/context/context.go rename to pkg/context/context.go index e77b87f..5f70052 100644 --- a/internal/context/context.go +++ b/pkg/context/context.go @@ -15,6 +15,11 @@ import ( "github.com/rs/zerolog" ) +const ( + // KubeConfigPath is the path of the kubeconfig file. + KubeConfigPath = "/opt/peerd/kubeconfig" +) + // Context keys. const ( CorrelationIdCtxKey = "correlation_id" @@ -30,9 +35,9 @@ const ( // Request headers. const ( - P2PHeaderKey = "X-MS-Cluster-P2P-RequestFromPeer" - CorrelationHeaderKey = "X-MS-Cluster-P2P-CorrelationId" - NodeHeaderKey = "X-MS-Cluster-P2P-Node" + P2PHeaderKey = "X-MS-Peerd-RequestFromPeer" + CorrelationHeaderKey = "X-MS-Peerd-CorrelationId" + NodeHeaderKey = "X-MS-Peerd-Node" ) // Log messages. @@ -46,17 +51,31 @@ const ( var ( NodeName, _ = os.Hostname() - - // KubeConfigPath is the path of the kubeconfig file. - KubeConfigPath = "/opt/peerd/kubeconfig" ) +// Context is the request context that can be passed around to various components to provide request specific information. +type Context struct { + *gin.Context +} + +// FromContext creates a new context from the given gin context. +func FromContext(c *gin.Context) Context { + return Context{Context: c} +} + +// Copy creates a copy of the context that can be safely used outside the request's scope. +func (c Context) Copy() Context { + cc := c.Context.Copy() + return Context{Context: cc} +} + // IsRequestFromAPeer indicates if the current request is from a peer. -func IsRequestFromAPeer(c *gin.Context) bool { +func IsRequestFromAPeer(c Context) bool { return c.Request.Header.Get(P2PHeaderKey) == "true" } -func FillCorrelationId(c *gin.Context) { +// FillCorrelationId fills the correlation ID in the context. +func FillCorrelationId(c Context) { correlationId := c.Request.Header.Get(CorrelationHeaderKey) if correlationId == "" { correlationId = uuid.New().String() @@ -64,8 +83,15 @@ func FillCorrelationId(c *gin.Context) { c.Set(CorrelationIdCtxKey, correlationId) } +// SetOutboundHeaders sets the mandatory headers for all outbound requests. +func SetOutboundHeaders(r *http.Request, c Context) { + r.Header.Set(P2PHeaderKey, "true") + r.Header.Set(CorrelationHeaderKey, c.GetString(CorrelationIdCtxKey)) + r.Header.Set(NodeHeaderKey, NodeName) +} + // Logger gets the logger with request specific fields. -func Logger(c *gin.Context) zerolog.Logger { +func Logger(c Context) zerolog.Logger { var l zerolog.Logger obj, ok := c.Get(LoggerCtxKey) if !ok { @@ -80,17 +106,10 @@ func Logger(c *gin.Context) zerolog.Logger { } // BlobUrl extracts the blob URL from the incoming request URL. -func BlobUrl(c *gin.Context) string { +func BlobUrl(c Context) string { return strings.TrimPrefix(c.Param("url"), "/") + "?" + c.Request.URL.RawQuery } -// SetOutboundHeaders sets the mandatory headers for all outbound requests. -func SetOutboundHeaders(r *http.Request, c *gin.Context) { - r.Header.Set(P2PHeaderKey, "true") - r.Header.Set(CorrelationHeaderKey, c.GetString(CorrelationIdCtxKey)) - r.Header.Set(NodeHeaderKey, NodeName) -} - // RangeStartIndex returns the start index of a byte range specified in the given range header value. // It expects the range value to be in the format "bytes=startIndex-endIndex". func RangeStartIndex(rangeValue string) (int64, error) { diff --git a/internal/context/context_test.go b/pkg/context/context_test.go similarity index 91% rename from internal/context/context_test.go rename to pkg/context/context_test.go index 435d5ea..e125c02 100644 --- a/internal/context/context_test.go +++ b/pkg/context/context_test.go @@ -27,7 +27,9 @@ func TestLogger(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = req - l := Logger(c) + pc := FromContext(c) + + l := Logger(pc) if l.Info().Enabled() { t.Fatal("expected logger to be disabled") } @@ -35,7 +37,7 @@ func TestLogger(t *testing.T) { testL := zerolog.New(os.Stdout).With().Timestamp().Logger() c.Set(LoggerCtxKey, &testL) - l = Logger(c) + l = Logger(pc) if !l.Info().Enabled() { t.Fatal("expected logger to be enabled") } @@ -51,9 +53,12 @@ func TestSetOutboundHeaders(t *testing.T) { // Create a new context with the request. ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) ctx.Request = req - FillCorrelationId(ctx) - SetOutboundHeaders(req, ctx) + pc := FromContext(ctx) + + FillCorrelationId(pc) + + SetOutboundHeaders(req, pc) if req.Header.Get(P2PHeaderKey) != "true" { t.Errorf("expected: %v, got: %v", "true", req.Header.Get(P2PHeaderKey)) @@ -82,8 +87,10 @@ func TestBlobUrl(t *testing.T) { {Key: "url", Value: hostAndPath}, } + pc := FromContext(ctx) + // Call BlobUrl and verify the result. - got := BlobUrl(ctx) + got := BlobUrl(pc) if got != u { t.Errorf("expected: %v, got: %v", u, got) } @@ -99,8 +106,10 @@ func TestFillCorrelationId(t *testing.T) { ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) ctx.Request = req - FillCorrelationId(ctx) - cid, ok := ctx.Get(CorrelationIdCtxKey) + pc := FromContext(ctx) + + FillCorrelationId(pc) + cid, ok := pc.Get(CorrelationIdCtxKey) if !ok || cid == "" { t.Fatal("expected correlation ID to be set") } @@ -110,8 +119,11 @@ func TestFillCorrelationId(t *testing.T) { ctx, _ = gin.CreateTestContext(httptest.NewRecorder()) ctx.Request = req ctx.Request.Header.Set(CorrelationHeaderKey, sample) - FillCorrelationId(ctx) - cid, ok = ctx.Get(CorrelationIdCtxKey) + + pc = FromContext(ctx) + + FillCorrelationId(pc) + cid, ok = pc.Get(CorrelationIdCtxKey) if !ok || cid == "" { t.Fatal("expected correlation ID to be set") } else if cid != sample { @@ -129,12 +141,14 @@ func TestIsRequestFromPeer(t *testing.T) { ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) ctx.Request = req - if IsRequestFromAPeer(ctx) { + pc := FromContext(ctx) + + if IsRequestFromAPeer(pc) { t.Fatal("expected request to not be from a peer") } ctx.Request.Header.Set(P2PHeaderKey, "true") - if !IsRequestFromAPeer(ctx) { + if !IsRequestFromAPeer(pc) { t.Fatal("expected request to be from a peer") } } diff --git a/pkg/discovery/content/consts.go b/pkg/discovery/content/consts.go deleted file mode 100644 index 2ab6b65..0000000 --- a/pkg/discovery/content/consts.go +++ /dev/null @@ -1,17 +0,0 @@ -package content - -// Log messages. -const ( - PeerResolutionStartLog = "peer resolution start" - PeerResolutionStopLog = "peer resolution stop" - PeerNotFoundLog = "peer not found" - PeerResolutionExhaustedLog = "peer resolution exhausted" - PeerRequestErrorLog = "peer request error" -) - -// Request headers. -const ( - P2PHeaderKey = "X-MS-Peerd-RequestFromPeer" - CorrelationHeaderKey = "X-MS-Peerd-CorrelationId" - NodeHeaderKey = "X-MS-Peerd-Node" -) diff --git a/pkg/discovery/content/http.go b/pkg/discovery/content/http.go deleted file mode 100644 index 25d79b1..0000000 --- a/pkg/discovery/content/http.go +++ /dev/null @@ -1,14 +0,0 @@ -package content - -import ( - "net/http" - - "github.com/azure/peerd/internal/context" -) - -// SetOutboundHeaders sets the mandatory headers for all outbound requests. -func SetOutboundHeaders(r *http.Request, correlationId string) { - r.Header.Set(P2PHeaderKey, "true") - r.Header.Set(CorrelationHeaderKey, correlationId) - r.Header.Set(NodeHeaderKey, context.NodeName) -} diff --git a/pkg/discovery/content/provider/provider_test.go b/pkg/discovery/content/provider/provider_test.go index 9d4c7c4..236ccaf 100644 --- a/pkg/discovery/content/provider/provider_test.go +++ b/pkg/discovery/content/provider/provider_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/azure/peerd/pkg/containerd" - "github.com/azure/peerd/pkg/discovery/routing/tests" + "github.com/azure/peerd/pkg/discovery/routing/mocks" "github.com/stretchr/testify/require" ) @@ -33,7 +33,7 @@ func TestContainerdStoreAds(t *testing.T) { } containerdStore := containerd.NewMockContainerdStore(refs) - router := tests.NewMockRouter(map[string][]string{}) + router := mocks.NewMockRouter(map[string][]string{}) ctx, cancel := context.WithCancel(context.TODO()) go func() { diff --git a/pkg/mocks/reader.go b/pkg/discovery/content/reader/mocks/reader.go similarity index 100% rename from pkg/mocks/reader.go rename to pkg/discovery/content/reader/mocks/reader.go diff --git a/pkg/discovery/content/reader/reader.go b/pkg/discovery/content/reader/reader.go new file mode 100644 index 0000000..e8a0d98 --- /dev/null +++ b/pkg/discovery/content/reader/reader.go @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package reader + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/routing" + "github.com/azure/peerd/pkg/metrics" + "github.com/rs/zerolog" +) + +type operation int + +const ( + operationFstatRemote = operation(iota) + operationPreadRemote +) + +var errPeerNotFound = errors.New("peer not found") + +// reader is a Reader implementation. +type reader struct { + context pcontext.Context + resolveTimeout time.Duration + + router routing.Router + resolveRetries int + defaultHttpClient *http.Client + + metricsRecorder metrics.Metrics +} + +var _ Reader = &reader{} + +// Log returns the logger with context for this reader. +func (r *reader) Log() *zerolog.Logger { + l := pcontext.Logger(r.context) + return &l +} + +// PreadRemote is like pread but to a remote file. +func (r *reader) PreadRemote(buf []byte, offset int64) (int, error) { + key := r.context.GetString(pcontext.FileChunkCtxKey) + start := offset + end := int64(len(buf)) + offset - 1 + + log := r.Log().With().Str("operation", "preadremote").Str("key", key).Int64("start", start).Int64("end", end).Logger() + + count, err := r.doP2p(log, key, start, end, operationPreadRemote, buf) + if err == nil { + return int(count), nil + } + + // Could not find a peer that has this file, request origin. + startTime := time.Now() + originReq, err := r.originRequest(start, end) + if err != nil { + return -1, err + } + + count32 := int(0) + defer func() { + r.metricsRecorder.RecordUpstreamResponse(originReq.URL.Hostname(), key, "pread", time.Since(startTime).Seconds(), int64(count32)) + }() + count32, err = r.preadRemote(log, originReq, r.defaultHttpClient, buf) + return count32, err +} + +// FstatRemote stats a remote file. +func (r *reader) FstatRemote() (int64, error) { + key := r.context.GetString(pcontext.FileChunkCtxKey) + start := int64(0) + end := int64(0) + + log := r.Log().With().Str("operation", "fstatremote").Int64("start", start).Int64("end", end).Str("key", key).Logger() + + startTime := time.Now() + originReq, err := r.originRequest(start, end) + if err != nil { + return -1, err + } + + var count int64 + defer func() { + r.metricsRecorder.RecordUpstreamResponse(originReq.URL.Hostname(), key, "fstat", time.Since(startTime).Seconds(), count) + }() + count, err = r.fstatRemote(log, originReq, r.defaultHttpClient) + return count, err +} + +// doP2p tries to resolve the key in the p2p network and if successful, it will perform the operation on the peer, and return the result. +func (r *reader) doP2p(log zerolog.Logger, fileChunkKey string, start, end int64, o operation, buf []byte) (int64, error) { + if pcontext.IsRequestFromAPeer(r.context) { + log.Warn().Msg("refusing to propagate request from one peer to another") + return -1, errPeerNotFound + } + + log.Debug().Msg(pcontext.PeerResolutionStartLog) + defer log.Debug().Msg(pcontext.PeerResolutionStopLog) + + resolveCtx, cancel := context.WithTimeout(log.WithContext(r.context), r.resolveTimeout) + defer cancel() + + startTime := time.Now() + peerCount := 0 + peersCh, negCacheCallback, err := r.router.ResolveWithNegativeCacheCallback(resolveCtx, fileChunkKey, false, r.resolveRetries) + if err != nil { + //nolint:errcheck // ignore + log.Error().Err(err).Msg(pcontext.PeerRequestErrorLog) + return -1, err + } + + // Request a peer for this file. +peerLoop: + for { + select { + + case <-resolveCtx.Done(): + // Resolving mirror has timed out. + negCacheCallback() + log.Info().Msg(pcontext.PeerNotFoundLog) + break peerLoop + + case peer, ok := <-peersCh: + // Channel closed means no more mirrors will be received and max retries has been reached. + if !ok { + negCacheCallback() + log.Info().Msg(pcontext.PeerResolutionExhaustedLog) + break peerLoop + } + + if peerCount == 0 { + // Only report the time it took to discover the first peer. + r.metricsRecorder.RecordPeerDiscovery(peer.HttpHost, time.Since(startTime).Seconds()) + peerCount++ + } + + peerReq, err := r.peerRequest(peer.HttpHost, start, end) + if err != nil { + log.Error().Err(err).Msg(pcontext.PeerRequestErrorLog) + // try next peer + break + } + + client := r.router.Net().HTTPClientFor(peer.ID) + + var count int64 + startTime = time.Now() + if o == operationFstatRemote { + count, err = r.fstatRemote(log, peerReq, client) + } else if o == operationPreadRemote { + var c int + c, err = r.preadRemote(log, peerReq, client, buf) + count = int64(c) + } else { + err = fmt.Errorf("unknown operation: %v", o) + } + + if err != nil { + // try next peer + log.Error().Err(err).Msg(pcontext.PeerRequestErrorLog) + } else { + op := "fstat" + if o == operationPreadRemote { + op = "pread" + } + r.metricsRecorder.RecordPeerResponse(peer.HttpHost, fileChunkKey, op, time.Since(startTime).Seconds(), count) + return count, nil + } + } + } + + return -1, errPeerNotFound +} + +// fstatRemote stats the file. +func (r *reader) fstatRemote(log zerolog.Logger, req *http.Request, client *http.Client) (int64, error) { + log.Debug().Str("url", req.URL.String()).Str("range", req.Header.Get("Range")).Msg("reader fstatRemote start") + defer log.Debug().Msg("reader fstatRemote stop") + + resp, err := client.Do(req) + if err != nil { + log.Error().Err(err).Msg("reader fstatRemote error") + return 0, Error{resp, err} + } + defer resp.Body.Close() + + if resp.StatusCode == 200 { + return resp.ContentLength, nil + } + + if resp.StatusCode == 206 { + l := resp.ContentLength + rs := resp.Header.Get("Content-Range") + if rs == "" { + return l, nil + } + + pos := strings.LastIndexByte(rs, '/') + if pos < 0 { + return l, nil + } + + l, _ = strconv.ParseInt(rs[pos+1:], 10, 64) + return l, nil + } + + log.Error().Err(err).Int("status", resp.StatusCode).Msg("reader fstatRemote error") + return 0, Error{resp, fmt.Errorf("unexpected response code: %d", resp.StatusCode)} +} + +// preadRemote reads the file. +func (r *reader) preadRemote(log zerolog.Logger, req *http.Request, client *http.Client, buf []byte) (int, error) { + log.Debug().Str("url", req.URL.String()).Str("range", req.Header.Get("Range")).Msg("reader preadRemote start") + statusCode := -1 + s := time.Now() + defer func() { + log.Debug().Int("status", statusCode).Dur("duration", time.Since(s)).Msg("reader preadRemote stop") + }() + + resp, err := client.Do(req) + if resp != nil { + statusCode = resp.StatusCode + } + if err != nil { + detailedErr := Error{resp, err} + log.Error().Err(detailedErr).Str("url", req.URL.String()).Str("range", req.Header.Get("Range")).Msg("reader preadRemote error") + return 0, detailedErr + } + defer resp.Body.Close() + + if resp.StatusCode != 200 && resp.StatusCode != 206 { + log.Error().Err(err).Int("status", resp.StatusCode).Msg("reader preadRemote error") + return 0, Error{resp, fmt.Errorf("unexpected response code: %d", resp.StatusCode)} + } + + return io.ReadFull(resp.Body, buf) +} + +// originRequest will create a new request to origin. +func (r *reader) originRequest(start, end int64) (*http.Request, error) { + return r.remoteRequest(r.context.GetString(pcontext.BlobUrlCtxKey), start, end) +} + +// perRequest will create a new request to a peer. +func (r *reader) peerRequest(peer string, start, end int64) (*http.Request, error) { + return r.remoteRequest(fmt.Sprintf("%v/blobs/%v", peer, r.context.GetString(pcontext.BlobUrlCtxKey)), start, end) +} + +// remoteRequest creates a new HTTP request to a remote server. +func (r *reader) remoteRequest(u string, start, end int64) (*http.Request, error) { + req, err := http.NewRequest("GET", u, nil) + if err != nil { + return nil, err + } + + for key, vals := range r.context.Request.Header { + vals2 := make([]string, len(vals)) + copy(vals2, vals) + req.Header[key] = vals2 + } + + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + pcontext.SetOutboundHeaders(req, r.context) + + return req, nil +} + +// NewReader creates a new remote reader. +func NewReader(c pcontext.Context, router routing.Router, resolveRetries int, resolveTimeout time.Duration, metricsRecorder metrics.Metrics) Reader { + return &reader{ + context: c.Copy(), + resolveTimeout: resolveTimeout, + router: router, + resolveRetries: resolveRetries, + defaultHttpClient: router.Net().HTTPClientFor(""), + metricsRecorder: metricsRecorder, + } +} diff --git a/pkg/discovery/content/reader/reader_test.go b/pkg/discovery/content/reader/reader_test.go new file mode 100644 index 0000000..21ad3c7 --- /dev/null +++ b/pkg/discovery/content/reader/reader_test.go @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package reader + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/discovery/routing/mocks" + "github.com/azure/peerd/pkg/metrics" + "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" +) + +var ( + hostAndPath = "https://avtakkartest.blob.core.windows.net/d18c7a64c5158179-ff8cb2f639ff44879c12c94361a746d0-782b855128//docker/registry/v2/blobs/sha256/d1/d18c7a64c5158179bdee531a663c5b487de57ff17cff3af29a51c7e70b491d9d/data" + query = "?se=2023-09-20T01%3A14%3A49Z&sig=m4Cr%2BYTZHZQlN5LznY7nrTQ4LCIx2OqnDDM3Dpedbhs%3D&sp=r&spr=https&sr=b&sv=2018-03-28®id=01031d61e1024861afee5d512651eb9f" + u = hostAndPath + query + mr = metrics.NewPromMetrics(prometheus.DefaultRegisterer, "test", "test") +) + +func TestPreadRemoteUpstream(t *testing.T) { + // Setup + m := map[string][]string{} + key := "somekey" + expected := "expected-result" + peersTried := 0 + svr3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + peersTried++ + w.WriteHeader(http.StatusUnauthorized) + })) + defer svr3.Close() + svr2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + peersTried++ + w.WriteHeader(http.StatusNotFound) + })) + defer svr2.Close() + svr1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + peersTried++ + w.WriteHeader(http.StatusBadGateway) + })) + defer svr1.Close() + val := []string{svr1.URL, svr2.URL, svr3.URL} + m[key] = val + + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "?"+r.URL.RawQuery == query { + w.Header().Set("Content-Type", "application/octet-stream") + // nolint:errcheck + w.Write([]byte(expected)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer svr.Close() + p := svr.URL + "/some-path" + u := "http://127.0.0.1:5000/blobs/" + p + query + req, err := http.NewRequest("GET", u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + c.Params = []gin.Param{ + {Key: "url", Value: p}, + } + + pc := pcontext.FromContext(c) + + pc.Set(pcontext.BlobUrlCtxKey, pcontext.BlobUrl(pc)) + pc.Set(pcontext.BlobRangeCtxKey, "bytes=0-10") + pc.Set(pcontext.FileChunkCtxKey, key) + + r := NewReader(pc, router, 3, 500*time.Millisecond, mr).(*reader) + b := make([]byte, 10) + + // Test + got, err := r.PreadRemote(b, 0) + + // Assert + if err != nil { + t.Fatal(err) + } else if got != 10 { + t.Fatalf("expected %v, got %v", 10, got) + } else if string(b) != expected[:10] { + t.Fatalf("expected %v, got %v", expected[:10], string(b)) + } else if peersTried != 3 { + t.Fatalf("expected %v, got %v", 3, peersTried) + } +} + +func TestFstatRemote(t *testing.T) { + m := map[string][]string{} + + expected := "expected-result" + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "?"+r.URL.RawQuery == query { + w.Header().Set("Content-Type", "application/octet-stream") + // nolint:errcheck + w.Write([]byte(expected)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer svr.Close() + p := svr.URL + "/some-path" + u := "http://127.0.0.1:5000/blobs/" + p + query + req, err := http.NewRequest("GET", u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + c.Params = []gin.Param{ + {Key: "url", Value: p}, + } + + pc := pcontext.FromContext(c) + + pc.Set(pcontext.BlobUrlCtxKey, pcontext.BlobUrl(pc)) + pc.Set(pcontext.BlobRangeCtxKey, "bytes=0-0") + + r := NewReader(pc, router, 3, 500*time.Millisecond, mr).(*reader) + + got, err := r.FstatRemote() + if err != nil { + t.Fatal(err) + } else if got != int64(len(expected)) { + t.Fatalf("expected %v, got %v", len(expected), got) + } +} + +func TestFstatRemotePartialContent(t *testing.T) { + m := map[string][]string{} + + expected := "expected-result" + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "?"+r.URL.RawQuery == query { + w.Header().Set("Content-Type", "application/octet-stream") + // nolint:errcheck + w.WriteHeader(http.StatusPartialContent) + w.Header().Set("Content-Range", "bytes 0-10/10") + // nolint:errcheck + w.Write([]byte(expected)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer svr.Close() + p := svr.URL + "/some-path" + u := "http://127.0.0.1:5000/blobs/" + p + query + req, err := http.NewRequest("GET", u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + c.Params = []gin.Param{ + {Key: "url", Value: p}, + } + + pc := pcontext.FromContext(c) + + pc.Set(pcontext.BlobUrlCtxKey, pcontext.BlobUrl(pc)) + pc.Set(pcontext.BlobRangeCtxKey, "bytes=0-0") + + r := NewReader(pc, router, 3, 500*time.Millisecond, mr).(*reader) + + got, err := r.FstatRemote() + if err != nil { + t.Fatal(err) + } else if got != int64(len(expected)) { + t.Fatalf("expected %v, got %v", len(expected), got) + } +} + +func TestP2pRetries(t *testing.T) { + l := zerolog.Nop() + m := map[string][]string{} + key := "somekey" + expected := "expected-result" + svr3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + // nolint:errcheck + w.Write([]byte(expected)) + })) + defer svr3.Close() + svr2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer svr2.Close() + svr1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + defer svr1.Close() + val := []string{svr1.URL, svr2.URL, svr3.URL} + m[key] = val + + req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + r := NewReader(pcontext.FromContext(c), router, 3, 500*time.Millisecond, mr).(*reader) + b := make([]byte, 10) + + got, err := r.doP2p(l, key, 0, 10, operationPreadRemote, b) + if err != nil { + t.Fatal(err) + } + + if got != 10 { + t.Fatalf("expected %v, got %v", 10, got) + } else if string(b) != expected[:10] { + t.Fatalf("expected %v, got %v", expected[:10], string(b)) + } +} + +func TestP2pSuccess(t *testing.T) { + l := zerolog.Nop() + m := map[string][]string{} + key := "somekey" + expected := "expected-result" + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + // nolint:errcheck + w.Write([]byte(expected)) + })) + defer svr.Close() + val := []string{svr.URL} + m[key] = val + + req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + r := NewReader(pcontext.FromContext(c), router, 3, 500*time.Millisecond, mr).(*reader) + b := make([]byte, 10) + + got, err := r.doP2p(l, key, 0, 10, operationPreadRemote, b) + if err != nil { + t.Fatal(err) + } + + if got != 10 { + t.Fatalf("expected %v, got %v", 10, got) + } else if string(b) != expected[:10] { + t.Fatalf("expected %v, got %v", expected[:10], string(b)) + } +} + +func TestP2pPeerNotFound(t *testing.T) { + l := zerolog.Nop() + m := map[string][]string{} + + req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + + r := NewReader(pcontext.FromContext(c), router, 3, 500*time.Millisecond, mr).(*reader) + + b := make([]byte, 10) + _, err = r.doP2p(l, "key", 0, 10, operationPreadRemote, b) + if err == nil { + t.Fatal("expected error") + } + + if err != errPeerNotFound { + t.Fatalf("expected %v, got %v", errPeerNotFound, err) + } +} + +func TestP2pNoInfiniteLoops(t *testing.T) { + l := zerolog.Nop() + m := map[string][]string{} + key := "some-key" + val := []string{"http://localhost"} + m[key] = val + + req, err := http.NewRequest("GET", "http://127.0.0.1:5000/blobs/"+u, nil) + if err != nil { + t.Fatal(err) + } + + router := mocks.NewMockRouter(m) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = req + c.Request.Header.Add(pcontext.P2PHeaderKey, "true") + + r := NewReader(pcontext.FromContext(c), router, 3, 500*time.Millisecond, mr).(*reader) + + b := make([]byte, 10) + _, err = r.doP2p(l, key, 0, 10, operationPreadRemote, b) + if err == nil { + t.Fatal("expected error") + } + + if err != errPeerNotFound { + t.Fatalf("expected %v, got %v", errPeerNotFound, err) + } +} diff --git a/pkg/mocks/router.go b/pkg/discovery/routing/mocks/router.go similarity index 94% rename from pkg/mocks/router.go rename to pkg/discovery/routing/mocks/router.go index f72fa1e..e49ba28 100644 --- a/pkg/mocks/router.go +++ b/pkg/discovery/routing/mocks/router.go @@ -8,6 +8,7 @@ import ( "github.com/azure/peerd/pkg/discovery/routing" "github.com/azure/peerd/pkg/peernet" + "github.com/azure/peerd/pkg/peernet/mocks" "github.com/libp2p/go-libp2p/core/peer" ) @@ -37,7 +38,7 @@ func (m *MockRouter) ResolveWithNegativeCacheCallback(ctx context.Context, key s var _ routing.Router = &MockRouter{} func NewMockRouter(resolver map[string][]string) *MockRouter { - n, err := peernet.New(&MockHost{PeerStore: &MockPeerstore{}}) + n, err := peernet.New(&mocks.MockHost{PeerStore: &mocks.MockPeerstore{}}) if err != nil { panic(err) } diff --git a/pkg/mocks/host.go b/pkg/peernet/mocks/host.go similarity index 100% rename from pkg/mocks/host.go rename to pkg/peernet/mocks/host.go diff --git a/pkg/mocks/peerstore.go b/pkg/peernet/mocks/peerstore.go similarity index 100% rename from pkg/mocks/peerstore.go rename to pkg/peernet/mocks/peerstore.go diff --git a/pkg/peernet/network_test.go b/pkg/peernet/network_test.go index 1783400..117396a 100644 --- a/pkg/peernet/network_test.go +++ b/pkg/peernet/network_test.go @@ -7,7 +7,7 @@ import ( "net/http" "testing" - "github.com/azure/peerd/pkg/mocks" + "github.com/azure/peerd/pkg/peernet/mocks" ) func TestNew(t *testing.T) { diff --git a/tests/cmd/main.go b/tests/cmd/main.go index a938994..8142392 100644 --- a/tests/cmd/main.go +++ b/tests/cmd/main.go @@ -10,7 +10,7 @@ import ( "syscall" "github.com/alexflint/go-arg" - p2pcontext "github.com/azure/peerd/internal/context" + pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/tests/random" "github.com/azure/peerd/tests/scanner" "github.com/rs/zerolog" @@ -21,7 +21,7 @@ func main() { arg.MustParse(args) zerolog.SetGlobalLevel(zerolog.InfoLevel) - l := zerolog.New(os.Stdout).With().Timestamp().Str("node", p2pcontext.NodeName).Str("version", version).Logger() + l := zerolog.New(os.Stdout).With().Timestamp().Str("node", pcontext.NodeName).Str("version", version).Logger() ctx := l.WithContext(context.Background()) err := run(ctx, args) From 7de14623c040558f4047baca870e19d579cec0ca Mon Sep 17 00:00:00 2001 From: Aviral Takkar Date: Mon, 15 Apr 2024 14:56:02 -0700 Subject: [PATCH 3/3] refactor: move files to pkg --- cmd/proxy/main.go | 2 +- internal/handlers/files/handler.go | 2 +- internal/handlers/files/handler_test.go | 4 ++-- internal/handlers/root.go | 2 +- internal/handlers/root_test.go | 2 +- {internal => pkg}/files/files.go | 0 {internal => pkg}/files/files_test.go | 0 {internal => pkg}/files/store/file.go | 2 +- {internal => pkg}/files/store/file_test.go | 2 +- {internal => pkg}/files/store/interface.go | 0 {internal => pkg}/files/store/main_test.go | 0 {internal => pkg}/files/store/mockstore.go | 0 {internal => pkg}/files/store/store.go | 2 +- {internal => pkg}/files/store/store_test.go | 2 +- 14 files changed, 10 insertions(+), 10 deletions(-) rename {internal => pkg}/files/files.go (100%) rename {internal => pkg}/files/files_test.go (100%) rename {internal => pkg}/files/store/file.go (98%) rename {internal => pkg}/files/store/file_test.go (99%) rename {internal => pkg}/files/store/interface.go (100%) rename {internal => pkg}/files/store/main_test.go (100%) rename {internal => pkg}/files/store/mockstore.go (100%) rename {internal => pkg}/files/store/store.go (99%) rename {internal => pkg}/files/store/store_test.go (98%) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 8de1a73..fb1f6cc 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -15,12 +15,12 @@ import ( "time" "github.com/alexflint/go-arg" - "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/internal/handlers" "github.com/azure/peerd/pkg/containerd" pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/content/provider" "github.com/azure/peerd/pkg/discovery/routing" + "github.com/azure/peerd/pkg/files/store" "github.com/azure/peerd/pkg/k8s" "github.com/azure/peerd/pkg/k8s/events" "github.com/azure/peerd/pkg/metrics" diff --git a/internal/handlers/files/handler.go b/internal/handlers/files/handler.go index b3ce9be..8ad1c4b 100644 --- a/internal/handlers/files/handler.go +++ b/internal/handlers/files/handler.go @@ -8,8 +8,8 @@ import ( "os" "time" - "github.com/azure/peerd/internal/files/store" pcontext "github.com/azure/peerd/pkg/context" + "github.com/azure/peerd/pkg/files/store" "github.com/azure/peerd/pkg/metrics" ) diff --git a/internal/handlers/files/handler_test.go b/internal/handlers/files/handler_test.go index 66b98c7..5f4b4b7 100644 --- a/internal/handlers/files/handler_test.go +++ b/internal/handlers/files/handler_test.go @@ -10,10 +10,10 @@ import ( "net/http/httptest" "testing" - "github.com/azure/peerd/internal/files" - "github.com/azure/peerd/internal/files/store" pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/routing/mocks" + "github.com/azure/peerd/pkg/files" + "github.com/azure/peerd/pkg/files/store" "github.com/azure/peerd/pkg/metrics" "github.com/gin-gonic/gin" ) diff --git a/internal/handlers/root.go b/internal/handlers/root.go index ea72475..c739fab 100644 --- a/internal/handlers/root.go +++ b/internal/handlers/root.go @@ -7,12 +7,12 @@ import ( "net/http" "time" - filesStore "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/internal/handlers/files" v2 "github.com/azure/peerd/internal/handlers/v2" "github.com/azure/peerd/pkg/containerd" pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/routing" + filesStore "github.com/azure/peerd/pkg/files/store" "github.com/gin-gonic/gin" "github.com/rs/zerolog" ) diff --git a/internal/handlers/root_test.go b/internal/handlers/root_test.go index 1ec7b71..ae7c606 100644 --- a/internal/handlers/root_test.go +++ b/internal/handlers/root_test.go @@ -6,9 +6,9 @@ import ( "net/http/httptest" "testing" - "github.com/azure/peerd/internal/files/store" "github.com/azure/peerd/pkg/containerd" "github.com/azure/peerd/pkg/discovery/routing/mocks" + "github.com/azure/peerd/pkg/files/store" "github.com/azure/peerd/pkg/metrics" "github.com/gin-gonic/gin" ) diff --git a/internal/files/files.go b/pkg/files/files.go similarity index 100% rename from internal/files/files.go rename to pkg/files/files.go diff --git a/internal/files/files_test.go b/pkg/files/files_test.go similarity index 100% rename from internal/files/files_test.go rename to pkg/files/files_test.go diff --git a/internal/files/store/file.go b/pkg/files/store/file.go similarity index 98% rename from internal/files/store/file.go rename to pkg/files/store/file.go index b1e9162..56b0add 100644 --- a/internal/files/store/file.go +++ b/pkg/files/store/file.go @@ -8,8 +8,8 @@ import ( "sync" - "github.com/azure/peerd/internal/files" "github.com/azure/peerd/pkg/discovery/content/reader" + "github.com/azure/peerd/pkg/files" "github.com/azure/peerd/pkg/math" ) diff --git a/internal/files/store/file_test.go b/pkg/files/store/file_test.go similarity index 99% rename from internal/files/store/file_test.go rename to pkg/files/store/file_test.go index 4c6053f..547213f 100644 --- a/internal/files/store/file_test.go +++ b/pkg/files/store/file_test.go @@ -9,10 +9,10 @@ import ( "strings" "testing" - "github.com/azure/peerd/internal/files" "github.com/azure/peerd/pkg/cache" readermocks "github.com/azure/peerd/pkg/discovery/content/reader/mocks" "github.com/azure/peerd/pkg/discovery/routing/mocks" + "github.com/azure/peerd/pkg/files" ) func TestReadAtWithChunkOffset(t *testing.T) { diff --git a/internal/files/store/interface.go b/pkg/files/store/interface.go similarity index 100% rename from internal/files/store/interface.go rename to pkg/files/store/interface.go diff --git a/internal/files/store/main_test.go b/pkg/files/store/main_test.go similarity index 100% rename from internal/files/store/main_test.go rename to pkg/files/store/main_test.go diff --git a/internal/files/store/mockstore.go b/pkg/files/store/mockstore.go similarity index 100% rename from internal/files/store/mockstore.go rename to pkg/files/store/mockstore.go diff --git a/internal/files/store/store.go b/pkg/files/store/store.go similarity index 99% rename from internal/files/store/store.go rename to pkg/files/store/store.go index 420df9d..769684f 100644 --- a/internal/files/store/store.go +++ b/pkg/files/store/store.go @@ -9,11 +9,11 @@ import ( "strings" "time" - "github.com/azure/peerd/internal/files" "github.com/azure/peerd/pkg/cache" pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/content/reader" "github.com/azure/peerd/pkg/discovery/routing" + "github.com/azure/peerd/pkg/files" "github.com/azure/peerd/pkg/metrics" "github.com/azure/peerd/pkg/urlparser" "github.com/opencontainers/go-digest" diff --git a/internal/files/store/store_test.go b/pkg/files/store/store_test.go similarity index 98% rename from internal/files/store/store_test.go rename to pkg/files/store/store_test.go index 8ea0a57..2edd20c 100644 --- a/internal/files/store/store_test.go +++ b/pkg/files/store/store_test.go @@ -9,9 +9,9 @@ import ( "os" "testing" - "github.com/azure/peerd/internal/files" pcontext "github.com/azure/peerd/pkg/context" "github.com/azure/peerd/pkg/discovery/routing/mocks" + "github.com/azure/peerd/pkg/files" "github.com/gin-gonic/gin" "github.com/opencontainers/go-digest" )