diff --git a/golink.go b/golink.go index 7c499a0..757c45b 100644 --- a/golink.go +++ b/golink.go @@ -139,15 +139,6 @@ func Run() error { // flush stats periodically go flushStatsLoop() - http.HandleFunc("/", serveGo) - http.HandleFunc("/.detail/", serveDetail) - http.HandleFunc("/.export", serveExport) - http.HandleFunc("/.help", serveHelp) - http.HandleFunc("/.opensearch", serveOpenSearch) - http.HandleFunc("/.all", serveAll) - http.HandleFunc("/.delete/", serveDelete) - http.Handle("/.static/", http.StripPrefix("/.", http.FileServer(http.FS(embeddedFS)))) - if *dev != "" { // override default hostname for dev mode if *hostname == defaultHostname { @@ -160,7 +151,7 @@ func Run() error { } log.Printf("Running in dev mode on %s ...", *dev) - log.Fatal(http.ListenAndServe(*dev, nil)) + log.Fatal(http.ListenAndServe(*dev, serveHandler())) } if *hostname == "" { @@ -295,6 +286,29 @@ func deleteLinkStats(link *Link) { db.DeleteStats(link.Short) } +// serverHandler returns the main http.Handler for serving all requests. +func serveHandler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/.detail/", serveDetail) + mux.HandleFunc("/.export", serveExport) + mux.HandleFunc("/.help", serveHelp) + mux.HandleFunc("/.opensearch", serveOpenSearch) + mux.HandleFunc("/.all", serveAll) + mux.HandleFunc("/.delete/", serveDelete) + mux.Handle("/.static/", http.StripPrefix("/.", http.FileServer(http.FS(embeddedFS)))) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // all internal URLs begin with a leading "."; any other URL is treated as a go link. + // Serve go links directly without passing through the ServeMux, + // which sometimes modifies the request URL path, which we don't want. + if !strings.HasPrefix(r.URL.Path, "/.") { + serveGo(w, r) + return + } + mux.ServeHTTP(w, r) + }) +} + func serveHome(w http.ResponseWriter, short string) { var clicks []visitData @@ -408,7 +422,11 @@ func serveGo(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - http.Redirect(w, r, target.String(), http.StatusFound) + + // http.Redirect always cleans the redirect URL, which we don't always want. + // Instead, manually set status and Location header. + w.WriteHeader(http.StatusFound) + w.Header().Set("Location", target.String()) } // acceptHTML returns whether the request can accept a text/html response. diff --git a/golink_test.go b/golink_test.go index e67d502..9506add 100644 --- a/golink_test.go +++ b/golink_test.go @@ -69,6 +69,12 @@ func TestServeGo(t *testing.T) { wantStatus: http.StatusFound, wantLink: "http://who/p?q=1", }, + { + name: "simple link with double slash in path", + link: "/who/http://host", + wantStatus: http.StatusFound, + wantLink: "http://who/http://host", + }, { name: "user link", link: "/me", @@ -105,7 +111,7 @@ func TestServeGo(t *testing.T) { r := httptest.NewRequest("GET", tt.link, nil) w := httptest.NewRecorder() - serveGo(w, r) + serveHandler().ServeHTTP(w, r) if w.Code != tt.wantStatus { t.Errorf("serveGo(%q) = %d; want %d", tt.link, w.Code, tt.wantStatus)