diff --git a/internal/strategy/gomod.go b/internal/strategy/gomod.go new file mode 100644 index 0000000..2380f85 --- /dev/null +++ b/internal/strategy/gomod.go @@ -0,0 +1,168 @@ +package strategy + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/httputil" + "github.com/block/cachew/internal/jobscheduler" + "github.com/block/cachew/internal/logging" + "github.com/block/cachew/internal/strategy/handler" +) + +func init() { + Register("gomod", "Caches Go module proxy requests.", NewGoMod) +} + +// GoModConfig represents the configuration for the Go module proxy strategy. +// +// In HCL it looks like: +// +// gomod { +// proxy = "https://proxy.golang.org" +// } +type GoModConfig struct { + Proxy string `hcl:"proxy,optional" help:"Upstream Go module proxy URL (defaults to proxy.golang.org)" default:"https://proxy.golang.org"` + MutableTTL time.Duration `hcl:"mutable-ttl,optional" help:"TTL for mutable Go module proxy endpoints (list, latest). Defaults to 5m." default:"5m"` + ImmutableTTL time.Duration `hcl:"immutable-ttl,optional" help:"TTL for immutable Go module proxy endpoints (versioned info, mod, zip). Defaults to 168h (7 days)." default:"168h"` +} + +// The GoMod strategy implements a caching proxy for the Go module proxy protocol. +// +// It supports all standard GOPROXY endpoints: +// - /$module/@v/list - Lists available versions +// - /$module/@v/$version.info - Version metadata JSON +// - /$module/@v/$version.mod - go.mod file +// - /$module/@v/$version.zip - Module source code +// - /$module/@latest - Latest version info +// +// The strategy uses differential caching: short TTL (5 minutes) for mutable +// endpoints (list, latest) and long TTL (7 days) for immutable versioned content. +type GoMod struct { + config GoModConfig + cache cache.Cache + client *http.Client + logger *slog.Logger + proxy *url.URL +} + +var _ Strategy = (*GoMod)(nil) + +// NewGoMod creates a new Go module proxy strategy. +func NewGoMod(ctx context.Context, config GoModConfig, scheduler jobscheduler.Scheduler, cache cache.Cache, mux Mux) (*GoMod, error) { + parsedURL, err := url.Parse(config.Proxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + + g := &GoMod{ + config: config, + cache: cache, + client: http.DefaultClient, + logger: logging.FromContext(ctx), + proxy: parsedURL, + } + + g.logger.InfoContext(ctx, "Initialized Go module proxy strategy", + slog.String("proxy", g.proxy.String())) + + // Create handler with caching configuration + h := handler.New(g.client, g.cache). + CacheKey(func(r *http.Request) string { + return g.buildUpstreamURL(r).String() + }). + Transform(g.transformRequest). + TTL(g.calculateTTL) + + // Register a namespaced handler for Go module proxy patterns + mux.Handle("GET /gomod/{path...}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + // Check if this is a valid Go module proxy endpoint + if g.isGoModulePath(path) { + h.ServeHTTP(w, r) + return + } + http.NotFound(w, r) + })) + + return g, nil +} + +// isGoModulePath checks if the path matches a valid Go module proxy endpoint pattern. +func (g *GoMod) isGoModulePath(path string) bool { + // Strip the /gomod prefix before checking the pattern + path = strings.TrimPrefix(path, "/gomod") + + // Valid patterns: + // - /@v/list + // - /@v/{version}.info + // - /@v/{version}.mod + // - /@v/{version}.zip + // - /@latest + return strings.HasSuffix(path, "/@v/list") || + strings.HasSuffix(path, "/@latest") || + (strings.Contains(path, "/@v/") && + (strings.HasSuffix(path, ".info") || + strings.HasSuffix(path, ".mod") || + strings.HasSuffix(path, ".zip"))) +} + +func (g *GoMod) String() string { + return "gomod:" + g.proxy.Host +} + +// buildUpstreamURL constructs the full upstream URL from the incoming request. +func (g *GoMod) buildUpstreamURL(r *http.Request) *url.URL { + // The full path includes the module path and the endpoint + // e.g., /gomod/github.com/user/repo/@v/v1.0.0.info + // We need to strip the /gomod prefix before forwarding to the upstream proxy + path := r.URL.Path + path = strings.TrimPrefix(path, "/gomod") + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + targetURL := *g.proxy + targetURL.Path = g.proxy.Path + path + targetURL.RawQuery = r.URL.RawQuery + + return &targetURL +} + +// transformRequest creates the upstream request to the Go module proxy. +func (g *GoMod) transformRequest(r *http.Request) (*http.Request, error) { + targetURL := g.buildUpstreamURL(r) + + g.logger.DebugContext(r.Context(), "Transforming Go module request", + slog.String("original_path", r.URL.Path), + slog.String("upstream_url", targetURL.String())) + + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil) + if err != nil { + return nil, httputil.Errorf(http.StatusInternalServerError, "create upstream request: %w", err) + } + + return req, nil +} + +// calculateTTL returns the appropriate cache TTL based on the endpoint type. +// +// Mutable endpoints (list, latest) get short TTL (5 minutes). +// Immutable versioned content (info, mod, zip) gets long TTL (7 days). +func (g *GoMod) calculateTTL(r *http.Request) time.Duration { + path := r.URL.Path + + // Short TTL for mutable endpoints + if strings.HasSuffix(path, "/@v/list") || strings.HasSuffix(path, "/@latest") { + return g.config.MutableTTL + } + + // Long TTL for immutable versioned content (.info, .mod, .zip) + return g.config.ImmutableTTL +} diff --git a/internal/strategy/gomod_test.go b/internal/strategy/gomod_test.go new file mode 100644 index 0000000..b60ac85 --- /dev/null +++ b/internal/strategy/gomod_test.go @@ -0,0 +1,324 @@ +package strategy_test + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/jobscheduler" + "github.com/block/cachew/internal/logging" + "github.com/block/cachew/internal/strategy" +) + +type mockGoModServer struct { + server *httptest.Server + requestCount map[string]int // Track requests by path + lastPath string + responses map[string]mockResponse +} + +type mockResponse struct { + status int + content string +} + +func newMockGoModServer() *mockGoModServer { + m := &mockGoModServer{ + requestCount: make(map[string]int), + responses: make(map[string]mockResponse), + } + + // Set up default responses for common endpoints + m.responses["/@v/list"] = mockResponse{ + status: http.StatusOK, + content: "v1.0.0\nv1.0.1\nv1.1.0\n", + } + m.responses["/@v/v1.0.0.info"] = mockResponse{ + status: http.StatusOK, + content: `{"Version":"v1.0.0","Time":"2023-01-01T00:00:00Z"}`, + } + m.responses["/@v/v1.0.0.mod"] = mockResponse{ + status: http.StatusOK, + content: "module github.com/example/test\n\ngo 1.21\n", + } + m.responses["/@v/v1.0.0.zip"] = mockResponse{ + status: http.StatusOK, + content: "PK\x03\x04...", // Mock zip content + } + m.responses["/@latest"] = mockResponse{ + status: http.StatusOK, + content: `{"Version":"v1.1.0","Time":"2023-06-01T00:00:00Z"}`, + } + + mux := http.NewServeMux() + mux.HandleFunc("/", m.handleRequest) + m.server = httptest.NewServer(mux) + + return m +} + +func (m *mockGoModServer) handleRequest(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + m.lastPath = path + m.requestCount[path]++ + + // Find matching response + var resp mockResponse + found := false + + // Try exact match first + if r, ok := m.responses[path]; ok { + resp = r + found = true + } else { + // Try suffix match for module paths + for suffix, r := range m.responses { + if len(path) >= len(suffix) && path[len(path)-len(suffix):] == suffix { + resp = r + found = true + break + } + } + } + + // If still not found, try pattern matching for any version + if !found && strings.Contains(path, "/@v/") { + switch { + case strings.HasSuffix(path, ".info"): + resp = mockResponse{ + status: http.StatusOK, + content: `{"Version":"v1.0.0","Time":"2023-01-01T00:00:00Z"}`, + } + found = true + case strings.HasSuffix(path, ".mod"): + resp = mockResponse{ + status: http.StatusOK, + content: "module github.com/example/test\n\ngo 1.21\n", + } + found = true + case strings.HasSuffix(path, ".zip"): + resp = mockResponse{ + status: http.StatusOK, + content: "PK\x03\x04...", + } + found = true + } + } + + if !found { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("not found")) + return + } + + w.WriteHeader(resp.status) + _, _ = w.Write([]byte(resp.content)) +} + +func (m *mockGoModServer) close() { + m.server.Close() +} + +func (m *mockGoModServer) setResponse(path string, status int, content string) { + m.responses[path] = mockResponse{ + status: status, + content: content, + } +} + +func setupGoModTest(t *testing.T) (*mockGoModServer, *http.ServeMux, context.Context) { + t.Helper() + + mock := newMockGoModServer() + t.Cleanup(mock.close) + + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: 24 * time.Hour}) + assert.NoError(t, err) + t.Cleanup(func() { _ = memCache.Close() }) + + mux := http.NewServeMux() + _, err = strategy.NewGoMod(ctx, strategy.GoModConfig{ + Proxy: mock.server.URL, + MutableTTL: 5 * time.Minute, + ImmutableTTL: 168 * time.Hour, + }, jobscheduler.New(ctx, jobscheduler.Config{}), memCache, mux) + assert.NoError(t, err) + + return mock, mux, ctx +} + +func TestGoModList(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + req := httptest.NewRequest(http.MethodGet, "/gomod/github.com/example/test/@v/list", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "v1.0.0\nv1.0.1\nv1.1.0\n", w.Body.String()) + assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/list"]) +} + +func TestGoModInfo(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + req := httptest.NewRequest(http.MethodGet, "/gomod/github.com/example/test/@v/v1.0.0.info", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, `{"Version":"v1.0.0","Time":"2023-01-01T00:00:00Z"}`, w.Body.String()) + assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/v1.0.0.info"]) +} + +func TestGoModMod(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + req := httptest.NewRequest(http.MethodGet, "/gomod/github.com/example/test/@v/v1.0.0.mod", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "module github.com/example/test\n\ngo 1.21\n", w.Body.String()) + assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/v1.0.0.mod"]) +} + +func TestGoModZip(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + req := httptest.NewRequest(http.MethodGet, "/gomod/github.com/example/test/@v/v1.0.0.zip", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "PK\x03\x04...", w.Body.String()) + assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@v/v1.0.0.zip"]) +} + +func TestGoModLatest(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + req := httptest.NewRequest(http.MethodGet, "/gomod/github.com/example/test/@latest", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, `{"Version":"v1.1.0","Time":"2023-06-01T00:00:00Z"}`, w.Body.String()) + assert.Equal(t, 1, mock.requestCount["/github.com/example/test/@latest"]) +} + +func TestGoModCaching(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + path := "/gomod/github.com/example/test/@v/v1.0.0.info" + upstreamPath := "/github.com/example/test/@v/v1.0.0.info" + + // First request + req1 := httptest.NewRequest(http.MethodGet, path, nil) + req1 = req1.WithContext(ctx) + w1 := httptest.NewRecorder() + mux.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, 1, mock.requestCount[upstreamPath]) + + // Second request should hit cache + req2 := httptest.NewRequest(http.MethodGet, path, nil) + req2 = req2.WithContext(ctx) + w2 := httptest.NewRecorder() + mux.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, w1.Body.String(), w2.Body.String()) + assert.Equal(t, 1, mock.requestCount[upstreamPath], "second request should be served from cache") +} + +func TestGoModComplexModulePath(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + // Test module path with multiple slashes + req := httptest.NewRequest(http.MethodGet, "/gomod/golang.org/x/tools/@v/v0.1.0.info", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, 1, mock.requestCount["/golang.org/x/tools/@v/v0.1.0.info"]) +} + +func TestGoModNonOKResponse(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + // Set up 404 response + upstreamPath := "/github.com/example/nonexistent/@v/v99.0.0.info" + notFoundPath := "/gomod" + upstreamPath + mock.setResponse(upstreamPath, http.StatusNotFound, "not found") + + // First request should return 404 + req1 := httptest.NewRequest(http.MethodGet, notFoundPath, nil) + req1 = req1.WithContext(ctx) + w1 := httptest.NewRecorder() + mux.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusNotFound, w1.Code) + assert.Equal(t, 1, mock.requestCount[upstreamPath]) + + // Second request should also hit upstream (404s are not cached) + req2 := httptest.NewRequest(http.MethodGet, notFoundPath, nil) + req2 = req2.WithContext(ctx) + w2 := httptest.NewRecorder() + mux.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusNotFound, w2.Code) + assert.Equal(t, 2, mock.requestCount[upstreamPath], "404 responses should not be cached") +} + +func TestGoModMultipleConcurrentRequests(t *testing.T) { + mock, mux, ctx := setupGoModTest(t) + + path := "/gomod/github.com/example/test/@v/v1.0.0.zip" + upstreamPath := "/github.com/example/test/@v/v1.0.0.zip" + + // Make multiple concurrent requests + results := make(chan *httptest.ResponseRecorder, 3) + for range 3 { + go func() { + req := httptest.NewRequest(http.MethodGet, path, nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + results <- w + }() + } + + // Collect results + for range 3 { + w := <-results + assert.Equal(t, http.StatusOK, w.Code) + } + + // First request should have created the cache entry + // Subsequent requests might hit cache or might be in-flight + // We just verify all requests succeeded + assert.True(t, mock.requestCount[upstreamPath] >= 1, "at least one request should have been made to upstream") +}