Skip to content

Commit

Permalink
refactor around how upstream connections are made
Browse files Browse the repository at this point in the history
Signed-off-by: lucasew <lucas59356@gmail.com>
  • Loading branch information
lucasew committed Oct 17, 2024
1 parent 97120bd commit f557c06
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 56 deletions.
14 changes: 7 additions & 7 deletions cmd/ts-proxyd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@ var options tsproxy.TailscaleProxyServerOptions
func init() {
var err error
flag.StringVar(&options.Network, "net", "tcp", "Network, for net.Dial")
flag.StringVar(&options.Upstream, "h", "", "Where to forward the connection")
flag.StringVar(&options.Address, "address", "", "Where to forward the connection")
flag.StringVar(&options.Hostname, "n", "", "Hostname in tailscale devices list")
flag.BoolVar(&options.EnableFunnel, "f", false, "Enable tailscale funnel")
flag.BoolVar(&options.EnableTLS, "t", false, "Enable HTTPS/TLS")
flag.StringVar(&options.StateDir, "s", "", "State directory")
flag.StringVar(&options.Addr, "addr", "", "Port to listen")
flag.StringVar(&options.Listen, "listen", "", "Port to listen")
flag.BoolVar(&options.EnableHTTP, "raw", false, "Disable HTTP handling")
flag.Parse()
options.EnableHTTP = !options.EnableHTTP
if options.Addr == "" && options.EnableHTTP {
if options.Listen == "" && options.EnableHTTP {
if options.EnableFunnel || options.EnableTLS {
options.Addr = ":443"
options.Listen = ":443"
} else {
options.Addr = ":80"
options.Listen = ":80"
}
}
spew.Dump(options)
if options.Addr == "" {
panic("-addr not defined")
if options.Listen == "" {
panic("-listen not defined")
}
if err != nil {
log.Fatal(err)
Expand Down
62 changes: 25 additions & 37 deletions http.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package tsproxy

import (
"context"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"

"github.com/davecgh/go-spew/spew"
Expand All @@ -17,25 +16,21 @@ func init() {

type TailscaleHTTPProxyServer struct {
server *TailscaleProxyServer
client *http.Client
scheme string
proxy *httputil.ReverseProxy
}

func NewTailscaleHTTPProxyServer(server *TailscaleProxyServer) (Server, error) {
transport := &http.Transport{
Dial: server.Dial,
}
client := &http.Client{
Transport: transport,
u := &url.URL{
Scheme: "http",
Host: server.Hostname(),
}
parsedURL, err := url.Parse(server.options.Upstream)
if err != nil {
return nil, err
proxy := httputil.NewSingleHostReverseProxy(u)
proxy.Transport = &http.Transport{
Dial: server.Dial,
}
return &TailscaleHTTPProxyServer{
server: server,
client: client,
scheme: parsedURL.Scheme,
proxy: proxy,
}, nil
}

Expand All @@ -51,29 +46,22 @@ func (tps *TailscaleHTTPProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Re
w.WriteHeader(500)
return
}
log.Printf("got http conn")
defer log.Printf("http conn end")
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
req := r.Clone(ctx)
req.URL.Scheme = tps.scheme
req.URL.Host = "whatever-would-be-ignored-anyway"
req.RequestURI = ""
req.Header.Set("Tailscale-User-Login", userInfo.UserProfile.LoginName)
req.Header.Set("Tailscale-User-Name", userInfo.UserProfile.DisplayName)
req.Header.Set("Tailscale-User-Profile-Pic", userInfo.UserProfile.ProfilePicURL)
req.Header.Set("Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers")
resp, err := tps.client.Do(req)
if err != nil {
log.Printf("error/http/proxy: %s", err.Error())
w.WriteHeader(500)
if r.Host != tps.server.Hostname() {
destinationURL := new(url.URL)
*destinationURL = *r.URL
destinationURL.Host = tps.server.Hostname() + tps.server.options.Listen
if tps.server.options.EnableTLS {
destinationURL.Scheme = "https"
} else {
destinationURL.Scheme = "http"
}
http.Redirect(w, r, destinationURL.String(), http.StatusMovedPermanently)
return
}
for k, v := range resp.Header {
w.Header()[k] = v
}
w.WriteHeader(resp.StatusCode)
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
io.CopyBuffer(w, resp.Body, buf)
log.Printf("%s %s %s %s", r.Method, userInfo.UserProfile.LoginName, r.Host, r.URL.String())
r.Header.Set("Tailscale-User-Login", userInfo.UserProfile.LoginName)
r.Header.Set("Tailscale-User-Name", userInfo.UserProfile.DisplayName)
r.Header.Set("Tailscale-User-Profile-Pic", userInfo.UserProfile.ProfilePicURL)
r.Header.Set("Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers")
tps.proxy.ServeHTTP(w, r)
}
28 changes: 16 additions & 12 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"log"
"net"
"net/url"

"os"

Expand Down Expand Up @@ -46,9 +45,9 @@ type TailscaleProxyServerOptions struct {
// protocol to listen, passed to net.Dial
Network string
// where to forward requests
Upstream string
Address string
// address to bind the server, passed to net.Dial
Addr string
Listen string
}

func NewTailscaleProxyServer(options TailscaleProxyServerOptions) (*TailscaleProxyServer, error) {
Expand All @@ -57,11 +56,11 @@ func NewTailscaleProxyServer(options TailscaleProxyServerOptions) (*TailscalePro
}
ctx, cancel := context.WithCancel(options.Context)
s := new(tsnet.Server)
s.Hostname = options.Hostname
if options.Hostname == "" {
s.Hostname = "tsproxy"
options.Hostname = "tsproxy"
}
if options.Upstream == "" {
s.Hostname = options.Hostname
if options.Address == "" {
return nil, ErrInvalidUpstream
}
if options.StateDir != "" {
Expand All @@ -83,6 +82,13 @@ func (tps *TailscaleProxyServer) listenFunnel(network string, addr string) (net.
return tps.server.ListenFunnel(network, addr)
}

func (tps *TailscaleProxyServer) Hostname() string {
for _, domain := range tps.server.CertDomains() {
return domain
}
return tps.options.Hostname
}

func (tps *TailscaleProxyServer) GetListenerFunction() ListenerFunction {
if tps.options.EnableFunnel {
return tps.listenFunnel
Expand All @@ -94,15 +100,13 @@ func (tps *TailscaleProxyServer) GetListenerFunction() ListenerFunction {
}

func (tps *TailscaleProxyServer) GetListener() (net.Listener, error) {
return tps.GetListenerFunction()("tcp", tps.options.Addr)
return tps.GetListenerFunction()("tcp", tps.options.Listen)
}

func (tps *TailscaleProxyServer) Dial(network string, addr string) (net.Conn, error) {
u, err := url.Parse(tps.options.Upstream)
if err != nil {
return nil, err
}
return net.Dial(tps.options.Network, u.Host)
dialNetwork := tps.options.Network
dialHost := tps.options.Address
return net.Dial(dialNetwork, dialHost)
}

func (tps *TailscaleProxyServer) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
Expand Down

0 comments on commit f557c06

Please sign in to comment.