From f557c061ed43dde2fbfb683ecba967ff8100baa0 Mon Sep 17 00:00:00 2001 From: lucasew Date: Wed, 16 Oct 2024 23:19:52 -0300 Subject: [PATCH] refactor around how upstream connections are made Signed-off-by: lucasew --- cmd/ts-proxyd/main.go | 14 +++++----- http.go | 62 +++++++++++++++++-------------------------- proxy.go | 28 ++++++++++--------- 3 files changed, 48 insertions(+), 56 deletions(-) diff --git a/cmd/ts-proxyd/main.go b/cmd/ts-proxyd/main.go index d5bfb0b..2f4d047 100644 --- a/cmd/ts-proxyd/main.go +++ b/cmd/ts-proxyd/main.go @@ -13,25 +13,25 @@ var options tsproxy.TailscaleProxyServerOptions func init() { var err error flag.StringVar(&options.Network, "net", "tcp", "Network, for net.Dial") - flag.StringVar(&options.Upstream, "h", "", "Where to forward the connection") + flag.StringVar(&options.Address, "address", "", "Where to forward the connection") flag.StringVar(&options.Hostname, "n", "", "Hostname in tailscale devices list") flag.BoolVar(&options.EnableFunnel, "f", false, "Enable tailscale funnel") flag.BoolVar(&options.EnableTLS, "t", false, "Enable HTTPS/TLS") flag.StringVar(&options.StateDir, "s", "", "State directory") - flag.StringVar(&options.Addr, "addr", "", "Port to listen") + flag.StringVar(&options.Listen, "listen", "", "Port to listen") flag.BoolVar(&options.EnableHTTP, "raw", false, "Disable HTTP handling") flag.Parse() options.EnableHTTP = !options.EnableHTTP - if options.Addr == "" && options.EnableHTTP { + if options.Listen == "" && options.EnableHTTP { if options.EnableFunnel || options.EnableTLS { - options.Addr = ":443" + options.Listen = ":443" } else { - options.Addr = ":80" + options.Listen = ":80" } } spew.Dump(options) - if options.Addr == "" { - panic("-addr not defined") + if options.Listen == "" { + panic("-listen not defined") } if err != nil { log.Fatal(err) diff --git a/http.go b/http.go index 25ff5c1..b8a7c23 100644 --- a/http.go +++ b/http.go @@ -1,11 +1,10 @@ package tsproxy import ( - "context" - "io" "log" "net" "net/http" + "net/http/httputil" "net/url" "github.com/davecgh/go-spew/spew" @@ -17,25 +16,21 @@ func init() { type TailscaleHTTPProxyServer struct { server *TailscaleProxyServer - client *http.Client - scheme string + proxy *httputil.ReverseProxy } func NewTailscaleHTTPProxyServer(server *TailscaleProxyServer) (Server, error) { - transport := &http.Transport{ - Dial: server.Dial, - } - client := &http.Client{ - Transport: transport, + u := &url.URL{ + Scheme: "http", + Host: server.Hostname(), } - parsedURL, err := url.Parse(server.options.Upstream) - if err != nil { - return nil, err + proxy := httputil.NewSingleHostReverseProxy(u) + proxy.Transport = &http.Transport{ + Dial: server.Dial, } return &TailscaleHTTPProxyServer{ server: server, - client: client, - scheme: parsedURL.Scheme, + proxy: proxy, }, nil } @@ -51,29 +46,22 @@ func (tps *TailscaleHTTPProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Re w.WriteHeader(500) return } - log.Printf("got http conn") - defer log.Printf("http conn end") - ctx, cancel := context.WithCancel(r.Context()) - defer cancel() - req := r.Clone(ctx) - req.URL.Scheme = tps.scheme - req.URL.Host = "whatever-would-be-ignored-anyway" - req.RequestURI = "" - req.Header.Set("Tailscale-User-Login", userInfo.UserProfile.LoginName) - req.Header.Set("Tailscale-User-Name", userInfo.UserProfile.DisplayName) - req.Header.Set("Tailscale-User-Profile-Pic", userInfo.UserProfile.ProfilePicURL) - req.Header.Set("Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers") - resp, err := tps.client.Do(req) - if err != nil { - log.Printf("error/http/proxy: %s", err.Error()) - w.WriteHeader(500) + if r.Host != tps.server.Hostname() { + destinationURL := new(url.URL) + *destinationURL = *r.URL + destinationURL.Host = tps.server.Hostname() + tps.server.options.Listen + if tps.server.options.EnableTLS { + destinationURL.Scheme = "https" + } else { + destinationURL.Scheme = "http" + } + http.Redirect(w, r, destinationURL.String(), http.StatusMovedPermanently) return } - for k, v := range resp.Header { - w.Header()[k] = v - } - w.WriteHeader(resp.StatusCode) - buf := bufferPool.Get().([]byte) - defer bufferPool.Put(buf) - io.CopyBuffer(w, resp.Body, buf) + log.Printf("%s %s %s %s", r.Method, userInfo.UserProfile.LoginName, r.Host, r.URL.String()) + r.Header.Set("Tailscale-User-Login", userInfo.UserProfile.LoginName) + r.Header.Set("Tailscale-User-Name", userInfo.UserProfile.DisplayName) + r.Header.Set("Tailscale-User-Profile-Pic", userInfo.UserProfile.ProfilePicURL) + r.Header.Set("Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers") + tps.proxy.ServeHTTP(w, r) } diff --git a/proxy.go b/proxy.go index cb6f7ca..6381fe7 100644 --- a/proxy.go +++ b/proxy.go @@ -5,7 +5,6 @@ import ( "errors" "log" "net" - "net/url" "os" @@ -46,9 +45,9 @@ type TailscaleProxyServerOptions struct { // protocol to listen, passed to net.Dial Network string // where to forward requests - Upstream string + Address string // address to bind the server, passed to net.Dial - Addr string + Listen string } func NewTailscaleProxyServer(options TailscaleProxyServerOptions) (*TailscaleProxyServer, error) { @@ -57,11 +56,11 @@ func NewTailscaleProxyServer(options TailscaleProxyServerOptions) (*TailscalePro } ctx, cancel := context.WithCancel(options.Context) s := new(tsnet.Server) - s.Hostname = options.Hostname if options.Hostname == "" { - s.Hostname = "tsproxy" + options.Hostname = "tsproxy" } - if options.Upstream == "" { + s.Hostname = options.Hostname + if options.Address == "" { return nil, ErrInvalidUpstream } if options.StateDir != "" { @@ -83,6 +82,13 @@ func (tps *TailscaleProxyServer) listenFunnel(network string, addr string) (net. return tps.server.ListenFunnel(network, addr) } +func (tps *TailscaleProxyServer) Hostname() string { + for _, domain := range tps.server.CertDomains() { + return domain + } + return tps.options.Hostname +} + func (tps *TailscaleProxyServer) GetListenerFunction() ListenerFunction { if tps.options.EnableFunnel { return tps.listenFunnel @@ -94,15 +100,13 @@ func (tps *TailscaleProxyServer) GetListenerFunction() ListenerFunction { } func (tps *TailscaleProxyServer) GetListener() (net.Listener, error) { - return tps.GetListenerFunction()("tcp", tps.options.Addr) + return tps.GetListenerFunction()("tcp", tps.options.Listen) } func (tps *TailscaleProxyServer) Dial(network string, addr string) (net.Conn, error) { - u, err := url.Parse(tps.options.Upstream) - if err != nil { - return nil, err - } - return net.Dial(tps.options.Network, u.Host) + dialNetwork := tps.options.Network + dialHost := tps.options.Address + return net.Dial(dialNetwork, dialHost) } func (tps *TailscaleProxyServer) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {