Skip to content

Commit

Permalink
Merge pull request #2 from betalo-sweden/dev-4122-reverse-proxy
Browse files Browse the repository at this point in the history
[DEV-4122] Use reverse proxy for http destinations
  • Loading branch information
alesr authored Mar 13, 2018
2 parents 84fd6ea + e9dbdec commit b566d98
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 31 deletions.
32 changes: 21 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

func main() {
var (
flagPEMPath = flag.String("pem", "", "Filepath to certificate")
flagCertPath = flag.String("cert", "", "Filepath to certificate")
flagKeyPath = flag.String("key", "", "Filepath to private key")
flagAddr = flag.String("addr", "", "Server address")
flagAuthUser = flag.String("user", "", "Server authentication username")
Expand All @@ -32,11 +32,20 @@ func main() {
flagServerReadHeaderTimeout = flag.Duration("serverreadheadertimeout", 30*time.Second, "Server read header timeout")
flagServerWriteTimeout = flag.Duration("serverwritetimeout", 30*time.Second, "Server write timeout")
flagServerIdleTimeout = flag.Duration("serveridletimeout", 30*time.Second, "Server idle timeout")
flagVerbose = flag.Bool("verbose", false, "Set log level to DEBUG")
)

flag.Parse()

c := zap.NewProductionConfig()
c.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder

if *flagVerbose {
c.Level.SetLevel(zapcore.DebugLevel)
} else {
c.Level.SetLevel(zapcore.ErrorLevel)
}

logger, err := c.Build()
if err != nil {
log.Fatalln("Error: failed to initiate logger")
Expand All @@ -45,14 +54,15 @@ func main() {
stdLogger := zap.NewStdLog(logger)

p := &Proxy{
Logger: logger,
AuthUser: *flagAuthUser,
AuthPass: *flagAuthPass,
DestDialTimeout: *flagDestDialTimeout,
DestReadTimeout: *flagDestReadTimeout,
DestWriteTimeout: *flagDestWriteTimeout,
ClientReadTimeout: *flagClientReadTimeout,
ClientWriteTimeout: *flagClientWriteTimeout,
ForwardingHTTPProxy: NewForwardingHTTPProxy(stdLogger),
Logger: logger,
AuthUser: *flagAuthUser,
AuthPass: *flagAuthPass,
DestDialTimeout: *flagDestDialTimeout,
DestReadTimeout: *flagDestReadTimeout,
DestWriteTimeout: *flagDestWriteTimeout,
ClientReadTimeout: *flagClientReadTimeout,
ClientWriteTimeout: *flagClientWriteTimeout,
}

s := &http.Server{
Expand Down Expand Up @@ -82,8 +92,8 @@ func main() {
p.Logger.Info("Server starting", zap.String("address", s.Addr))

var svrErr error
if *flagPEMPath != "" && *flagKeyPath != "" {
svrErr = s.ListenAndServeTLS(*flagPEMPath, *flagKeyPath)
if *flagCertPath != "" && *flagKeyPath != "" {
svrErr = s.ListenAndServeTLS(*flagCertPath, *flagKeyPath)
} else {
svrErr = s.ListenAndServe()
}
Expand Down
72 changes: 52 additions & 20 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ package main
import (
"encoding/base64"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"strings"
"time"

Expand All @@ -19,39 +21,49 @@ import (

// Proxy is a HTTPS forward proxy.
type Proxy struct {
Logger *zap.Logger
AuthUser string
AuthPass string
DestDialTimeout time.Duration
DestReadTimeout time.Duration
DestWriteTimeout time.Duration
ClientReadTimeout time.Duration
ClientWriteTimeout time.Duration
Logger *zap.Logger
AuthUser string
AuthPass string
ForwardingHTTPProxy *httputil.ReverseProxy
DestDialTimeout time.Duration
DestReadTimeout time.Duration
DestWriteTimeout time.Duration
ClientReadTimeout time.Duration
ClientWriteTimeout time.Duration
}

func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.Logger.Info("Incoming request", zap.String("host", r.Host))

if r.Method != http.MethodConnect {
p.Logger.Info("Method not allowed:", zap.String("method", r.Method))
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}

if p.AuthUser != "" && p.AuthPass != "" {
user, pass, ok := parseBasicProxyAuth(r.Header.Get("Proxy-Authenticate"))
user, pass, ok := parseBasicProxyAuth(r.Header.Get("Proxy-Authorization"))
if !ok || user != p.AuthUser || pass != p.AuthPass {
p.Logger.Warn("Authentication attempt with invalid credentials")
p.Logger.Warn("Authorization attempt with invalid credentials")
http.Error(w, http.StatusText(http.StatusProxyAuthRequired), http.StatusProxyAuthRequired)
return
}
}

p.connect(w, r)
if r.URL.Scheme == "http" {
p.handleHTTP(w, r)
} else {
p.handleTunneling(w, r)
}
}

func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
p.Logger.Debug("Got HTTP request", zap.String("host", r.Host))
p.ForwardingHTTPProxy.ServeHTTP(w, r)
}

func (p *Proxy) connect(w http.ResponseWriter, r *http.Request) {
p.Logger.Debug("Connecting:", zap.String("host", r.Host))
func (p *Proxy) handleTunneling(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
p.Logger.Info("Method not allowed", zap.String("method", r.Method))
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}

p.Logger.Debug("Connecting", zap.String("host", r.Host))

destConn, err := net.DialTimeout("tcp", r.Host, p.DestDialTimeout)
if err != nil {
Expand All @@ -64,7 +76,7 @@ func (p *Proxy) connect(w http.ResponseWriter, r *http.Request) {

w.WriteHeader(http.StatusOK)

p.Logger.Debug("Hijacking:", zap.String("host", r.Host))
p.Logger.Debug("Hijacking", zap.String("host", r.Host))

hijacker, ok := w.(http.Hijacker)
if !ok {
Expand Down Expand Up @@ -115,3 +127,23 @@ func parseBasicProxyAuth(auth string) (username, password string, ok bool) {
}
return cs[:s], cs[s+1:], true
}

// NewForwardingHTTPProxy retuns a new reverse proxy that takes an incoming
// request and sends it to another server, proxying the response back to the
// client.
//
// See: https://golang.org/pkg/net/http/httputil/#ReverseProxy
func NewForwardingHTTPProxy(logger *log.Logger) *httputil.ReverseProxy {
director := func(req *http.Request) {
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
}
// TODO:(alesr) Use timeouts specified via flags to customize the default
// transport used by the reverse proxy.
return &httputil.ReverseProxy{
ErrorLog: logger,
Director: director,
}
}

0 comments on commit b566d98

Please sign in to comment.