Skip to content

Commit

Permalink
Add passthrough handler (#16)
Browse files Browse the repository at this point in the history
This passes through GET requests other than get-entries to the CT
backend. This is particularly useful for some log scanning tools that
call `/ct/v1/get-sth` before beginning their scan.

This is mainly intended as a convenience for small scale testing. In
production we plan to bypass this tool for all request paths other than
/ct/v1/get-entries.
  • Loading branch information
jsha authored Sep 7, 2023
1 parent 6d33b73 commit 7c4dc15
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
6 changes: 3 additions & 3 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ func TestIntegration(t *testing.T) {
singleFlightShared: prometheus.NewCounter(prometheus.CounterOpts{}),
}

// Invalid URL; should 404
// Invalid URL; should be passed through to backend and 400
resp := getResp(ctile, "/foo")
if resp.StatusCode != 404 {
t.Errorf("expected 404 got %d", resp.StatusCode)
if resp.StatusCode != 400 {
t.Errorf("expected 400 got %d", resp.StatusCode)
}

// Malformed queries; should 400
Expand Down
46 changes: 38 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"net/url"
"os"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -293,12 +292,6 @@ type tileCachingHandler struct {
}

func (tch *tileCachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasSuffix(r.URL.Path, "/ct/v1/get-entries") {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "invalid path %q\n", r.URL.Path)
return
}

start, end, err := parseQueryParams(r.URL.Query())
if err != nil {
w.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -436,6 +429,39 @@ func singleflightDo[V any](group *singleflight.Group, key string, fn func() (V,
return out.(V), err, shared
}

// passthroughHandler is an HTTP handler that passes through GET requests to the CT log.
type passthroughHandler struct {
logURL string
}

func (p passthroughHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
fmt.Fprintln(w, "only GET is supported")
return
}
url := fmt.Sprintf("%s%s", p.logURL, r.URL.Path)
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, url, nil)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "creating request: %s\n", err)
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "fetching %s: %s\n", url, err)
return
}
defer resp.Body.Close()

w.WriteHeader(resp.StatusCode)
_, err = io.Copy(w, resp.Body)
if err != nil {
log.Printf("copying response body to client: %s\n", err)
}
}

func main() {
logURL := flag.String("log-url", "", "CT log URL. e.g. https://oak.ct.letsencrypt.org/2023")
tileSize := flag.Int("tile-size", 0, "tile size. Must match the value used by the backend")
Expand Down Expand Up @@ -512,13 +538,17 @@ func main() {
singleFlightShared: singleFlightShared,
}

mux := http.NewServeMux()
mux.Handle("/ct/v1/get-entries", handler)
mux.Handle("/ct/v1/", passthroughHandler{logURL: *logURL})

srv := http.Server{
Addr: *listenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: *fullRequestTimeout + 1*time.Second, // must be a bit larger than the max time spent in the HTTP handler
IdleTimeout: 5 * time.Minute,
ReadHeaderTimeout: 2 * time.Second,
Handler: http.TimeoutHandler(handler, *fullRequestTimeout, "full request timeout"),
Handler: http.TimeoutHandler(mux, *fullRequestTimeout, "full request timeout"),
}

log.Fatal(srv.ListenAndServe())
Expand Down

0 comments on commit 7c4dc15

Please sign in to comment.