diff --git a/integration_test.go b/integration_test.go index e6823b8..eeaf662 100644 --- a/integration_test.go +++ b/integration_test.go @@ -23,7 +23,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" - "golang.org/x/sync/singleflight" ) const containerName string = "ctile_integration_test_minio" @@ -144,7 +143,7 @@ func TestIntegration(t *testing.T) { t.Fatal(err) } - ctile := makeTCH(server.URL, s3Service) + ctile := makeTCH(t, server.URL, s3Service) // Invalid URL; should 404 passed through to backend and 400 resp := getResp(ctile, "/foo") @@ -277,7 +276,7 @@ func TestIntegration(t *testing.T) { })) defer server.Close() - erroringCTile := makeTCH(errorCTLog.URL, s3Service) + erroringCTile := makeTCH(t, errorCTLog.URL, s3Service) resp = getResp(erroringCTile, "/ct/v1/get-entries?start=0&end=1") if resp.StatusCode != 500 { t.Errorf("expected 500 got %d", resp.StatusCode) @@ -285,7 +284,7 @@ func TestIntegration(t *testing.T) { expectAndResetMetric(t, erroringCTile.requestsMetric, 1, "error", "ct_log_get") } -func getResp(ctile tileCachingHandler, url string) *http.Response { +func getResp(ctile *tileCachingHandler, url string) *http.Response { req := httptest.NewRequest("GET", url, nil) w := httptest.NewRecorder() @@ -294,7 +293,7 @@ func getResp(ctile tileCachingHandler, url string) *http.Response { return w.Result() } -func getAndParseResp(t *testing.T, ctile tileCachingHandler, url string) (entries, http.Header, error) { +func getAndParseResp(t *testing.T, ctile *tileCachingHandler, url string) (entries, http.Header, error) { t.Helper() resp := getResp(ctile, url) body, _ := io.ReadAll(resp.Body) @@ -321,21 +320,10 @@ func expectAndResetMetric(t *testing.T, metric *prometheus.CounterVec, expected metric.Reset() } -func makeTCH(url string, s3Service *s3.Client) tileCachingHandler { - return tileCachingHandler{ - logURL: url, - tileSize: 3, - - s3Service: s3Service, - s3Prefix: "test", - s3Bucket: "bucket", - - cacheGroup: &singleflight.Group{}, - - fullRequestTimeout: 10 * time.Second, - - requestsMetric: prometheus.NewCounterVec(prometheus.CounterOpts{Help: "foo", Name: "ctile_requests"}, []string{"result", "source"}), - partialTiles: prometheus.NewCounter(prometheus.CounterOpts{Name: "ctile_partial_tiles"}), - singleFlightShared: prometheus.NewCounter(prometheus.CounterOpts{Name: "ctile_singleflight_shared"}), +func makeTCH(t *testing.T, url string, s3Service *s3.Client) *tileCachingHandler { + tch, err := newTileCachingHandler(url, 3, s3Service, "test", "bucket", 10*time.Second, prometheus.NewRegistry()) + if err != nil { + t.Fatal(err) } + return tch } diff --git a/main.go b/main.go index 9deece0..021f722 100644 --- a/main.go +++ b/main.go @@ -294,6 +294,70 @@ type tileCachingHandler struct { fullRequestTimeout time.Duration } +func newTileCachingHandler( + logURL string, + tileSize int, + s3Service *s3.Client, + s3Prefix string, + s3Bucket string, + fullRequestTimeout time.Duration, + promRegisterer prometheus.Registerer, +) (*tileCachingHandler, error) { + if logURL == "" { + return nil, errors.New("logURL must not be empty") + } + if tileSize == 0 { + return nil, errors.New("tileSize must not be zero") + } + if s3Service == nil { + return nil, errors.New("s3Service must not be nil") + } + if s3Prefix == "" { + return nil, errors.New("s3Prefix must not be empty") + } + if s3Bucket == "" { + return nil, errors.New("s3Bucket must not be empty") + } + if fullRequestTimeout == 0 { + return nil, errors.New("fullRequestTimeout must not be zero") + } + requestsMetric := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "ctile_requests", + Help: "total number of requests, by result and source", + }, + []string{"result", "source"}, + ) + promRegisterer.MustRegister(requestsMetric) + + partialTiles := prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "ctile_partial_tiles", + Help: "number of requests not cached due to partial tile returned from CT log", + }) + promRegisterer.MustRegister(partialTiles) + + singleFlightShared := prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "ctile_single_flight_shared", + Help: "number of inbound requests coalesced into a single set of backend requests", + }) + promRegisterer.MustRegister(singleFlightShared) + + return &tileCachingHandler{ + logURL: logURL, + tileSize: tileSize, + s3Service: s3Service, + s3Prefix: s3Prefix, + s3Bucket: s3Bucket, + cacheGroup: &singleflight.Group{}, + requestsMetric: requestsMetric, + partialTiles: partialTiles, + singleFlightShared: singleFlightShared, + fullRequestTimeout: fullRequestTimeout, + }, nil +} + func (tch *tileCachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // For non-get-entries requests, pass them along to the backend if !strings.HasSuffix(r.URL.Path, "/ct/v1/get-entries") { @@ -527,39 +591,9 @@ func main() { promRegistry := newStatsRegistry(*metricsAddress) - requestsMetric := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "ctile_requests", - Help: "total number of requests, by result and source", - }, - []string{"result", "source"}, - ) - promRegistry.MustRegister(requestsMetric) - - partialTiles := prometheus.NewCounter( - prometheus.CounterOpts{ - Name: "ctile_partial_tiles", - Help: "number of requests not cached due to partial tile returned from CT log", - }) - promRegistry.MustRegister(partialTiles) - - singleFlightShared := prometheus.NewCounter( - prometheus.CounterOpts{ - Name: "ctile_single_flight_shared", - Help: "number of inbound requests coalesced into a single set of backend requests", - }) - promRegistry.MustRegister(singleFlightShared) - - handler := &tileCachingHandler{ - logURL: *logURL, - tileSize: *tileSize, - s3Service: svc, - s3Prefix: *s3prefix, - s3Bucket: *s3bucket, - cacheGroup: &singleflight.Group{}, - requestsMetric: requestsMetric, - partialTiles: partialTiles, - singleFlightShared: singleFlightShared, + handler, err := newTileCachingHandler(*logURL, *tileSize, svc, *s3prefix, *s3bucket, *fullRequestTimeout, promRegistry) + if err != nil { + log.Fatal(err) } srv := http.Server{