diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cfaad76 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.pem diff --git a/cert.go b/cert.go index 6f22e1a..947c763 100644 --- a/cert.go +++ b/cert.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "math/big" + "net" "time" ) @@ -27,46 +28,8 @@ const ( leafUsage = caUsage ) -func genCert(ca *tls.Certificate, names []string) (*tls.Certificate, error) { - now := time.Now().Add(-1 * time.Hour).UTC() - if !ca.Leaf.IsCA { - return nil, errors.New("CA cert is not a CA") - } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return nil, fmt.Errorf("failed to generate serial number: %s", err) - } - tmpl := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{CommonName: names[0]}, - NotBefore: now, - NotAfter: now.Add(leafMaxAge), - KeyUsage: leafUsage, - BasicConstraintsValid: true, - DNSNames: names, - SignatureAlgorithm: x509.ECDSAWithSHA512, - } - key, err := genKeyPair() - if err != nil { - return nil, err - } - x, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Leaf, key.Public(), ca.PrivateKey) - if err != nil { - return nil, err - } - cert := new(tls.Certificate) - cert.Certificate = append(cert.Certificate, x) - cert.PrivateKey = key - cert.Leaf, _ = x509.ParseCertificate(x) - return cert, nil -} - -func genKeyPair() (*ecdsa.PrivateKey, error) { - return ecdsa.GenerateKey(elliptic.P521(), rand.Reader) -} - -func GenCA(name string) (certPEM, keyPEM []byte, err error) { +// GenerateCA generates a CA cert and key pair. +func GenerateCA(name string) (certPEM, keyPEM []byte, err error) { now := time.Now().UTC() tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), @@ -101,3 +64,51 @@ func GenCA(name string) (certPEM, keyPEM []byte, err error) { }) return } + +// GenerateCert generates a leaf cert from ca. +func GenerateCert(ca *tls.Certificate, hosts ...string) (*tls.Certificate, error) { + now := time.Now().Add(-1 * time.Hour).UTC() + if !ca.Leaf.IsCA { + return nil, errors.New("CA cert is not a CA") + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %s", err) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: hosts[0]}, + NotBefore: now, + NotAfter: now.Add(leafMaxAge), + KeyUsage: leafUsage, + BasicConstraintsValid: true, + SignatureAlgorithm: x509.ECDSAWithSHA512, + } + + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + key, err := genKeyPair() + if err != nil { + return nil, err + } + x, err := x509.CreateCertificate(rand.Reader, template, ca.Leaf, key.Public(), ca.PrivateKey) + if err != nil { + return nil, err + } + cert := new(tls.Certificate) + cert.Certificate = append(cert.Certificate, x) + cert.PrivateKey = key + cert.Leaf, _ = x509.ParseCertificate(x) + return cert, nil +} + +func genKeyPair() (*ecdsa.PrivateKey, error) { + return ecdsa.GenerateKey(elliptic.P521(), rand.Reader) +} diff --git a/cert_test.go b/cert_test.go new file mode 100644 index 0000000..70466ad --- /dev/null +++ b/cert_test.go @@ -0,0 +1,261 @@ +// Code generated by go-bindata. +// sources: +// cert.pem +// key.pem +// DO NOT EDIT! + +package mitm + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "strings" + "os" + "time" + "io/ioutil" + "path" + "path/filepath" +) + +func bindataRead(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + clErr := gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + if clErr != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +type asset struct { + bytes []byte + info os.FileInfo +} + +type bindataFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time +} + +func (fi bindataFileInfo) Name() string { + return fi.name +} +func (fi bindataFileInfo) Size() int64 { + return fi.size +} +func (fi bindataFileInfo) Mode() os.FileMode { + return fi.mode +} +func (fi bindataFileInfo) ModTime() time.Time { + return fi.modTime +} +func (fi bindataFileInfo) IsDir() bool { + return false +} +func (fi bindataFileInfo) Sys() interface{} { + return nil +} + +var _certPem = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x64\x92\x4f\x93\x9a\x4c\x10\xc6\xef\x7c\x8a\xf7\x6e\xbd\xa5\x20\x24\x7a\xd8\x43\xf7\x30\xe0\x88\x8c\x0e\x3b\xfc\xdb\xdb\x4a\x8d\xa3\x80\x6b\x40\x77\x87\xf0\xe9\xe3\x9a\xad\xe4\x90\x3e\x75\xfd\x9e\xae\xae\x7e\xba\x9e\xff\x3f\x0b\x69\xc8\xf8\x7f\x84\x26\x92\x05\x8c\x80\xa4\x0f\x6a\xc5\x8c\xa1\xaa\x09\x81\x64\xa2\xc1\x30\x04\xcd\x12\x58\xbf\x7f\x30\x75\xa3\xa9\x57\xf7\x17\xbe\x31\xea\x3a\xca\x42\x80\x21\xba\x64\xd1\xe5\x85\x8d\xf5\x8c\xde\x67\x0d\xad\xa9\xb0\x62\x70\x43\xb0\x53\x4a\x8e\xf1\x4a\xe4\xfc\xf6\xf2\x8c\xfe\x7e\x04\x15\x98\xd9\xc0\x25\x38\xb1\xdf\x0c\x5b\xc9\x0c\x97\xc1\xeb\x83\xd5\x9f\x4c\xff\x61\x56\x8c\x6c\xa0\x3e\x6c\x51\xf3\x0c\xe1\x22\x71\x16\xd4\xfb\x3c\xd5\xc2\x59\x98\x5c\x82\x44\x5d\x75\xc7\xa6\xde\xee\x04\x43\xd4\xfa\xab\x8f\x11\x0d\x27\x00\xc9\x4f\x8b\x0d\x67\x2f\xe9\x76\x65\x53\x40\xc4\x26\xee\x8f\x79\xa6\x28\x9a\xa9\x28\xaa\x77\x67\xd9\x1d\x23\xa5\xa7\xc6\x35\xf5\xf2\xbb\x97\x38\xf3\x6f\xf3\xac\xf4\x76\x83\xdd\xcf\xa3\x8e\xf4\xe2\xea\x84\x9e\x75\x66\xc5\xe8\xaa\xc6\xcb\x8b\xa9\x3d\x2e\x33\x7b\x86\x3d\x5e\xec\xd2\x64\x5f\x57\xad\xc4\x02\xe1\xb0\xa0\x08\x31\x81\x4e\x18\x6a\x4a\x3f\x4b\x66\x2d\x82\x79\xfc\xc4\x8a\x4c\x89\x28\xd2\x15\xdc\x25\xff\xb7\x28\x41\xac\xa6\x08\xa9\xb9\x33\x9c\x8e\x50\x3d\x16\x25\x94\x06\x77\x4b\x5a\xdf\x5a\x15\x06\xb7\x2a\x1c\xda\xcd\x99\x7f\x58\xfb\xb2\xa2\x07\x03\x00\x12\xa2\xbf\x1e\x85\x0f\x9a\xaf\xc1\xc7\x10\x4e\x14\x7a\x7f\xc3\x03\x91\x1e\xd2\x6e\x73\xd8\xda\x92\xec\xc5\x64\xdf\xb9\xea\xad\x59\x13\xc7\x3a\x2d\xf4\xb5\x28\x65\xaf\x7d\x2f\x2d\xd7\x2e\x61\xc2\x47\xaa\x36\x19\x23\x97\x21\xac\x5b\xde\x48\xcf\x7e\x3e\xe5\x59\xa5\x64\x7e\x1d\x61\xa4\x75\xfb\xc6\x9c\xf6\x75\x4c\x89\x52\xf0\xf4\x64\x3d\xd2\x40\xb9\xff\x6f\x42\x7e\x05\x00\x00\xff\xff\xbb\xc0\x85\xc9\x3e\x02\x00\x00") + +func certPemBytes() ([]byte, error) { + return bindataRead( + _certPem, + "cert.pem", + ) +} + +func certPem() (*asset, error) { + bytes, err := certPemBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "cert.pem", size: 574, mode: os.FileMode(420), modTime: time.Unix(1433877651, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _keyPem = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x6c\x8f\xbd\x6e\x83\x30\x18\x00\x77\x9e\x82\x1d\x55\x40\x53\xb5\x30\x64\xf8\x70\x3e\x8c\x13\xcc\x4f\xf9\x49\x60\x43\x36\x4a\x68\xaa\x1a\xab\x50\x37\x6f\x5f\x35\x73\x6e\xbd\xe5\xee\xe9\x9f\x08\x29\xcb\x6c\x24\x76\xf1\xce\x5a\xa8\xd1\x3e\x60\x77\x17\x16\x4f\x04\x81\x12\x91\xa5\x15\x1d\x4f\x26\xf6\x8f\xd7\x5f\xff\x3b\x82\xd9\x8b\xc4\x7e\xf3\xda\x5f\x71\xbe\x31\x97\xe6\xa3\xee\x34\x6f\xdf\xe6\xb1\xee\xb8\x02\x45\x09\xd1\xb4\xe2\x2f\xa1\x05\x06\x13\xd5\x94\xbb\xf2\x0c\x28\x26\xde\xaf\xcd\xf0\xf1\xbc\x8f\x0d\x99\x8a\xf5\x20\xfd\xd3\x05\x44\xe1\xc5\x9b\x74\xf9\x19\x54\xf5\xa5\x8a\x20\xe7\xcc\x1d\x9d\x46\x2e\xce\xe2\x1f\xf3\xda\x0a\xe4\x10\xde\xf4\x59\x7b\x69\x76\x59\xfb\x29\x0e\x9c\x64\xce\xdb\x4f\x37\x14\x6e\x2b\x25\x0c\xa6\xdc\x6e\xad\x7b\x2c\x66\xbb\x87\x0f\x7f\x01\x00\x00\xff\xff\x3d\x7d\x75\xfb\xe3\x00\x00\x00") + +func keyPemBytes() ([]byte, error) { + return bindataRead( + _keyPem, + "key.pem", + ) +} + +func keyPem() (*asset, error) { + bytes, err := keyPemBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "key.pem", size: 227, mode: os.FileMode(384), modTime: time.Unix(1433877651, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) + } + return a.bytes, nil + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// MustAsset is like Asset but panics when Asset would return an error. +// It simplifies safe initialization of global variables. +func MustAsset(name string) []byte { + a, err := Asset(name) + if (err != nil) { + panic("asset: Asset(" + name + "): " + err.Error()) + } + + return a +} + +// AssetInfo loads and returns the asset info for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func AssetInfo(name string) (os.FileInfo, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) + } + return a.info, nil + } + return nil, fmt.Errorf("AssetInfo %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. +var _bindata = map[string]func() (*asset, error){ + "cert.pem": certPem, + "key.pem": keyPem, +} + +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. +func AssetDir(name string) ([]string, error) { + node := _bintree + if len(name) != 0 { + cannonicalName := strings.Replace(name, "\\", "/", -1) + pathList := strings.Split(cannonicalName, "/") + for _, p := range pathList { + node = node.Children[p] + if node == nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + } + } + if node.Func != nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + rv := make([]string, 0, len(node.Children)) + for childName := range node.Children { + rv = append(rv, childName) + } + return rv, nil +} + +type bintree struct { + Func func() (*asset, error) + Children map[string]*bintree +} +var _bintree = &bintree{nil, map[string]*bintree{ + "cert.pem": &bintree{certPem, map[string]*bintree{ + }}, + "key.pem": &bintree{keyPem, map[string]*bintree{ + }}, +}} + +// RestoreAsset restores an asset under the given directory +func RestoreAsset(dir, name string) error { + data, err := Asset(name) + if err != nil { + return err + } + info, err := AssetInfo(name) + if err != nil { + return err + } + err = os.MkdirAll(_filePath(dir, path.Dir(name)), os.FileMode(0755)) + if err != nil { + return err + } + err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode()) + if err != nil { + return err + } + err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime()) + if err != nil { + return err + } + return nil +} + +// RestoreAssets restores an asset under the given directory recursively +func RestoreAssets(dir, name string) error { + children, err := AssetDir(name) + // File + if err != nil { + return RestoreAsset(dir, name) + } + // Dir + for _, child := range children { + err = RestoreAssets(dir, path.Join(name, child)) + if err != nil { + return err + } + } + return nil +} + +func _filePath(dir, name string) string { + cannonicalName := strings.Replace(name, "\\", "/", -1) + return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) +} + diff --git a/mitm.go b/mitm.go index 78ac2d7..0bace1d 100644 --- a/mitm.go +++ b/mitm.go @@ -3,6 +3,7 @@ package mitm import ( "crypto/tls" "errors" + "io" "log" "net" "net/http" @@ -11,6 +12,67 @@ import ( "time" ) +type ServerParam struct { + CA *tls.Certificate // the Root CA for generatng on the fly MITM certificates + TLSConfig *tls.Config // a template TLS config for the server. +} + +// A ServerConn is a net.Conn that holds its clients SNI header in ServerName +// after the handshake. +type ServerConn struct { + *tls.Conn + + // ServerName is set during Conn's handshake to the client's requested + // server name set in the SNI header. It is not safe to access across + // multiple goroutines while Conn is performing the handshake. + ServerName string +} + +// Server wraps cn with a ServerConn configured with p so that during its +// Handshake, it will generate a new certificate using p.CA. After a successful +// Handshake, its ServerName field will be set to the clients requested +// ServerName in the SNI header. +func Server(cn net.Conn, p ServerParam) *ServerConn { + conf := new(tls.Config) + if p.TLSConfig != nil { + *conf = *p.TLSConfig + } + sc := new(ServerConn) + conf.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + sc.ServerName = hello.ServerName + return getCert(p.CA, hello.ServerName) + } + sc.Conn = tls.Server(cn, conf) + return sc +} + +type listener struct { + net.Listener + ca *tls.Certificate + conf *tls.Config +} + +// NewListener returns a net.Listener that generates a new cert from ca for +// each new Accept. It uses SNI to generate the cert, and herefore only +// works with clients that send SNI headers. +// +// This is useful for building transparent MITM proxies. +func NewListener(inner net.Listener, ca *tls.Certificate, conf *tls.Config) net.Listener { + return &listener{inner, ca, conf} +} + +func (l *listener) Accept() (net.Conn, error) { + cn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + sc := Server(cn, ServerParam{ + CA: l.ca, + TLSConfig: l.conf, + }) + return sc, nil +} + // Proxy is a forward proxy that substitutes its own certificate // for incoming TLS connections in place of the upstream server's // certificate. @@ -36,162 +98,120 @@ type Proxy struct { // response body. // If zero, no periodic flushing is done. FlushInterval time.Duration -} -func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == "CONNECT" { - p.serveConnect(w, r) - return - } - rp := &httputil.ReverseProxy{ - Director: httpDirector, - FlushInterval: p.FlushInterval, - } - p.Wrap(rp).ServeHTTP(w, r) + // Director is function which modifies the request into a new + // request to be sent using Transport. See the documentation for + // httputil.ReverseProxy for more details. For mitm proxies, the + // director defaults to HTTPDirector, but for transparent TLS + // proxies it should be set to HTTPSDirector. + Director func(*http.Request) } -func (p *Proxy) serveConnect(w http.ResponseWriter, r *http.Request) { - var ( - err error - sconn *tls.Conn - name = dnsName(r.Host) - ) +var ( + okHeader = "HTTP/1.1 200 OK\r\n\r\n" + noUpstreamHeader = "HTTP/1.1 503 No Upstream\r\n\r\n" + noDownstreamHeader = "HTTP/1.1 503 No Downstream\r\n\r\n" + errHeader = "HTTP/1.1 500 Internal Server Error\r\n\r\n" +) - if name == "" { - log.Println("cannot determine cert name for " + r.Host) - http.Error(w, "no upstream", 503) +func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + rp := &httputil.ReverseProxy{ + Director: p.Director, + FlushInterval: p.FlushInterval, + } + if rp.Director == nil { + rp.Director = HTTPDirector + } + p.Wrap(rp).ServeHTTP(w, req) return } - provisionalCert, err := p.cert(name) + cn, _, err := w.(http.Hijacker).Hijack() if err != nil { - log.Println("cert", err) - http.Error(w, "no upstream", 503) + log.Println("Hijack:", err) + http.Error(w, "No Upstream", 503) return } + defer cn.Close() - sConfig := new(tls.Config) - if p.TLSServerConfig != nil { - *sConfig = *p.TLSServerConfig + _, err = io.WriteString(cn, okHeader) + if err != nil { + log.Println("Write:", err) + return } - sConfig.Certificates = []tls.Certificate{*provisionalCert} - sConfig.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cConfig := new(tls.Config) - if p.TLSClientConfig != nil { - *cConfig = *p.TLSClientConfig + + sc, ok := cn.(*ServerConn) + if !ok { + name := dnsName(req.Host) + if name == "" { + log.Println("cannot determine cert name for " + req.Host) + io.WriteString(cn, noDownstreamHeader) + return } - cConfig.ServerName = hello.ServerName - sconn, err = tls.Dial("tcp", r.Host, cConfig) - if err != nil { - log.Println("dial", r.Host, err) - return nil, err + sc = Server(cn, ServerParam{ + CA: p.CA, + TLSConfig: p.TLSServerConfig, + }) + if err := sc.Handshake(); err != nil { + log.Println("Server Handshake:", err) + return } - return p.cert(hello.ServerName) } - cconn, err := handshake(w, sConfig) + cc, err := p.tlsDial(req.Host, sc.ServerName) if err != nil { - log.Println("handshake", r.Host, err) + log.Println("tlsDial:", err) + io.WriteString(cn, noUpstreamHeader) return } - defer cconn.Close() - if sconn == nil { - log.Println("could not determine cert name for " + r.Host) - return + p.proxyMITM(sc, cc) +} + +func (p *Proxy) tlsDial(addr, serverName string) (net.Conn, error) { + conf := new(tls.Config) + if p.TLSClientConfig != nil { + *conf = *p.TLSClientConfig } - defer sconn.Close() + conf.ServerName = serverName + return tls.Dial("tcp", addr, conf) +} - od := &oneShotDialer{c: sconn} +func (p *Proxy) proxyMITM(upstream, downstream net.Conn) { + var mu sync.Mutex + dial := func(network, addr string) (net.Conn, error) { + mu.Lock() + defer mu.Unlock() + if downstream == nil { + return nil, io.EOF + } + cn := downstream + downstream = nil + return cn, nil + } rp := &httputil.ReverseProxy{ - Director: httpsDirector, - Transport: &http.Transport{DialTLS: od.Dial}, + Director: HTTPSDirector, + Transport: &http.Transport{DialTLS: dial}, FlushInterval: p.FlushInterval, } - - ch := make(chan int) - wc := &onCloseConn{cconn, func() { ch <- 0 }} + ch := make(chan struct{}) + wc := &onCloseConn{upstream, func() { ch <- struct{}{} }} http.Serve(&oneShotListener{wc}, p.Wrap(rp)) <-ch } -func (p *Proxy) cert(names ...string) (*tls.Certificate, error) { - return genCert(p.CA, names) -} - -var okHeader = []byte("HTTP/1.1 200 OK\r\n\r\n") - -// handshake hijacks w's underlying net.Conn, responds to the CONNECT request -// and manually performs the TLS handshake. It returns the net.Conn or and -// error if any. -func handshake(w http.ResponseWriter, config *tls.Config) (net.Conn, error) { - raw, _, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, "no upstream", 503) - return nil, err - } - if _, err = raw.Write(okHeader); err != nil { - raw.Close() - return nil, err - } - conn := tls.Server(raw, config) - err = conn.Handshake() - if err != nil { - conn.Close() - raw.Close() - return nil, err - } - return conn, nil -} - -func httpDirector(r *http.Request) { +// HTTPDirector is director designed for use in Proxy for http +// proxies. +func HTTPDirector(r *http.Request) { r.URL.Host = r.Host r.URL.Scheme = "http" } -func httpsDirector(r *http.Request) { - r.URL.Host = r.Host - r.URL.Scheme = "https" -} - -// dnsName returns the DNS name in addr, if any. -func dnsName(addr string) string { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return "" - } - return host -} - -// namesOnCert returns the dns names -// in the peer's presented cert. -func namesOnCert(conn *tls.Conn) []string { - // TODO(kr): handle IP addr SANs. - c := conn.ConnectionState().PeerCertificates[0] - if len(c.DNSNames) > 0 { - // If Subject Alt Name is given, - // we ignore the common name. - // This matches behavior of crypto/x509. - return c.DNSNames - } - return []string{c.Subject.CommonName} -} - -// A oneShotDialer implements net.Dialer whos Dial only returns a -// net.Conn as specified by c followed by an error for each subsequent Dial. -type oneShotDialer struct { - c net.Conn - mu sync.Mutex -} - -func (d *oneShotDialer) Dial(network, addr string) (net.Conn, error) { - d.mu.Lock() - defer d.mu.Unlock() - if d.c == nil { - return nil, errors.New("closed") - } - c := d.c - d.c = nil - return c, nil +// HTTPSDirector is a director designed for use in Proxy for +// transparent TLS proxies. +func HTTPSDirector(req *http.Request) { + req.URL.Host = req.Host + req.URL.Scheme = "https" } // A oneShotListener implements net.Listener whos Accept only returns a @@ -230,3 +250,57 @@ func (c *onCloseConn) Close() error { } return c.Conn.Close() } + +// dnsName returns the DNS name in addr, if any. +func dnsName(addr string) string { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return "" + } + return host +} + +// Certificates are cached locally to avoid unnecessary regeneration +const certCacheMaxSize = 1000 + +var ( + certCache = make(map[*tls.Certificate]map[string]*tls.Certificate) + certCacheMutex sync.RWMutex +) + +func getCert(ca *tls.Certificate, host string) (*tls.Certificate, error) { + if c := getCachedCert(ca, host); c != nil { + return c, nil + } + cert, err := GenerateCert(ca, host) + if err != nil { + return nil, err + } + cacheCert(ca, host, cert) + return cert, nil +} + +func getCachedCert(ca *tls.Certificate, host string) *tls.Certificate { + certCacheMutex.RLock() + defer certCacheMutex.RUnlock() + + if certCache[ca] == nil { + return nil + } + cert := certCache[ca][host] + if cert == nil || cert.Leaf.NotAfter.Before(time.Now()) { + return nil + } else { + return cert + } +} + +func cacheCert(ca *tls.Certificate, host string, cert *tls.Certificate) { + certCacheMutex.Lock() + defer certCacheMutex.Unlock() + + if certCache[ca] == nil || len(certCache[ca]) > certCacheMaxSize { + certCache[ca] = make(map[string]*tls.Certificate) + } + certCache[ca][host] = cert +} diff --git a/mitm_test.go b/mitm_test.go index 4f1ac13..d146997 100644 --- a/mitm_test.go +++ b/mitm_test.go @@ -1,10 +1,15 @@ +//go:generate go run $GOROOT/src/crypto/tls/generate_cert.go -host "example.com,127.0.0.1" -ca -ecdsa-curve P256 +//go:generate sh -c "go-bindata -o cert_test.go -pkg mitm *.pem" package mitm import ( + "bufio" "crypto/tls" "crypto/x509" "flag" + "io" "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -14,6 +19,10 @@ import ( "testing" ) +func init() { + log.SetFlags(log.Lshortfile) +} + var hostname, _ = os.Hostname() var ( @@ -24,30 +33,40 @@ func init() { flag.Parse() } -func genCA() (cert tls.Certificate, err error) { - certPEM, keyPEM, err := GenCA(hostname) - if err != nil { - return tls.Certificate{}, err - } - cert, err = tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - return tls.Certificate{}, err - } - cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) - return cert, err -} +var ( + caCert = MustAsset("cert.pem") + caKey = MustAsset("key.pem") +) -func testProxy(t *testing.T, ca *tls.Certificate, setupReq func(req *http.Request), wrap func(http.Handler) http.Handler, downstream http.HandlerFunc, checkResp func(*http.Response)) { +func testProxy(t *testing.T, setupReq func(req *http.Request), wrap func(http.Handler) http.Handler, downstream http.HandlerFunc, checkResp func(*http.Response)) { ds := httptest.NewTLSServer(downstream) defer ds.Close() + rootCAs := x509.NewCertPool() + if !rootCAs.AppendCertsFromPEM(caCert) { + panic("can't add cert") + } + + ca, err := tls.X509KeyPair(caCert, caKey) + if err != nil { + panic(err) + } + ca.Leaf, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + panic(err) + } + cert, err := GenerateCert(&ca, "www.google.com") + if err != nil { + t.Fatal("GenerateCert:", err) + } p := &Proxy{ - CA: ca, + CA: &ca, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, TLSServerConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{*cert}, }, Wrap: wrap, } @@ -82,7 +101,7 @@ func testProxy(t *testing.T, ca *tls.Certificate, setupReq func(req *http.Reques return &u, nil }, TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, + RootCAs: rootCAs, }, }, } @@ -97,12 +116,8 @@ func testProxy(t *testing.T, ca *tls.Certificate, setupReq func(req *http.Reques func Test(t *testing.T) { const xHops = "X-Hops" - ca, err := genCA() - if err != nil { - t.Fatal("loadCA:", err) - } - - testProxy(t, &ca, func(req *http.Request) { + testProxy(t, func(req *http.Request) { + // req.Host = "example.com" req.Header.Set(xHops, "a") }, func(upstream http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -126,13 +141,8 @@ func TestNet(t *testing.T) { t.Skip() } - ca, err := genCA() - if err != nil { - t.Fatal("loadCA:", err) - } - var wrapped bool - testProxy(t, &ca, func(req *http.Request) { + testProxy(t, func(req *http.Request) { nreq, _ := http.NewRequest("GET", "https://mitmtest.herokuapp.com/", nil) *req = *nreq }, func(upstream http.Handler) http.Handler { @@ -158,3 +168,88 @@ func TestNet(t *testing.T) { } }) } + +func TestNewListener(t *testing.T) { + ca, err := tls.X509KeyPair(caCert, caKey) + if err != nil { + t.Fatal("X509KeyPair:", err) + } + ca.Leaf, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + t.Fatal("ParseCertificate:", err) + } + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Listen:", err) + } + defer l.Close() + + cert, err := GenerateCert(&ca, "www.google.com") + if err != nil { + t.Fatal("GenerateCert:", err) + } + l = NewListener(l, &ca, &tls.Config{ + MinVersion: tls.VersionSSL30, + Certificates: []tls.Certificate{*cert}, + }) + paddr := l.Addr().String() + + called := false + go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Host != "www.google.com" { + t.Errorf("want Host www.google.com, got %s", req.Host) + } + called = true + })) + + rootCAs := x509.NewCertPool() + if !rootCAs.AppendCertsFromPEM(caCert) { + t.Fatal("can't add cert") + } + cc, err := tls.Dial("tcp", paddr, &tls.Config{ + MinVersion: tls.VersionSSL30, + ServerName: "foo.com", + RootCAs: rootCAs, + }) + if err != nil { + t.Fatal("Dial:", err) + } + if err := cc.Handshake(); err != nil { + t.Fatal("Handshake:", err) + } + + bw := bufio.NewWriter(cc) + var w io.Writer = &stickyErrWriter{bw, &err} + io.WriteString(w, "GET / HTTP/1.1\r\n") + io.WriteString(w, "Host: www.google.com\r\n") + io.WriteString(w, "\r\n\r\n") + bw.Flush() + if err != nil { + t.Error("Write:", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(cc), nil) + if err != nil { + t.Fatal("ReadResponse:", err) + } + if !called { + t.Error("want downstream called") + } + if resp.StatusCode != 200 { + t.Errorf("want StatusCode 200, got %d", resp.StatusCode) + } +} + +type stickyErrWriter struct { + io.Writer + err *error +} + +func (w *stickyErrWriter) Write(b []byte) (int, error) { + n, err := w.Writer.Write(b) + if *w.err == nil { + *w.err = err + } + return n, *w.err +}