Skip to content

Commit

Permalink
support websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Nov 23, 2021
1 parent 6493c4c commit ee2961f
Show file tree
Hide file tree
Showing 13 changed files with 860 additions and 422 deletions.
74 changes: 74 additions & 0 deletions core/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2020-2021, IrineSistiana
//
// This file is part of simple-tls.
//
// simple-tls is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// simple-tls is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package core

import (
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"net"
"time"
)

type AuthTransport struct {
nextTransport Transport
auth [md5.Size]byte
}

func (t *AuthTransport) Dial(ctx context.Context) (net.Conn, error) {
conn, err := t.nextTransport.Dial(ctx)
if err != nil {
return nil, err
}
if _, err := conn.Write(t.auth[:]); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to write auth: %w", err)
}
return conn, nil
}

func NewAuthTransport(nextTransport Transport, auth string) *AuthTransport {
return &AuthTransport{nextTransport: nextTransport, auth: md5.Sum([]byte(auth))}
}

type AuthTransportHandler struct {
nextHandler TransportHandler
auth [md5.Size]byte
}

func NewAuthTransportHandler(nextHandler TransportHandler, auth string) *AuthTransportHandler {
return &AuthTransportHandler{nextHandler: nextHandler, auth: md5.Sum([]byte(auth))}
}

var errAuthFailed = errors.New("auth failed")

func (h *AuthTransportHandler) Handle(conn net.Conn) error {
var auth [md5.Size]byte
if _, err := io.ReadFull(conn, auth[:]); err != nil {
return fmt.Errorf("failed to read auth header: %w", err)
}

if auth != h.auth {
discardRead(conn, time.Second*15)
return errAuthFailed
}

return h.nextHandler.Handle(conn)
}
184 changes: 103 additions & 81 deletions core/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,124 +18,146 @@
package core

import (
"crypto/md5"
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"github.com/IrineSistiana/ctunnel"
"log"
"io/ioutil"
"net"
"strings"
"time"
)

type Client struct {
Listener net.Listener
ServerAddr string
NoTLS bool
Auth string
BindAddr string
DstAddr string
Websocket bool
WebsocketPath string
Mux int
Auth string

ServerName string
CertPool *x509.CertPool
CA string
CertHash string
InsecureSkipVerify bool
Timeout time.Duration
AndroidVPNMode bool
TFO bool
Mux int

dialer net.Dialer
auth [16]byte
tlsConfig *tls.Config
muxPool *muxPool // not nil if Mux > 0

IdleTimeout time.Duration
AndroidVPNMode bool
TFO bool

testListener net.Listener
}

var errEmptyCAFile = errors.New("no valid certificate was found in the ca file")

func (c *Client) ActiveAndServe() error {
c.dialer = net.Dialer{
Timeout: time.Second * 5,
Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}),

var l net.Listener
if c.testListener != nil {
l = c.testListener
} else {
var err error
lc := net.ListenConfig{}
l, err = lc.Listen(context.Background(), "tcp", c.BindAddr)
if err != nil {
return err
}
}

if !c.NoTLS {
c.tlsConfig = new(tls.Config)
c.tlsConfig.NextProtos = []string{"http/1.1", "h2"}
c.tlsConfig.ServerName = c.ServerName
c.tlsConfig.RootCAs = c.CertPool
c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify
if len(c.ServerName) == 0 {
c.ServerName = strings.SplitN(c.DstAddr, ":", 2)[0]
}

if len(c.Auth) > 0 {
c.auth = md5.Sum([]byte(c.Auth))
var rootCAs *x509.CertPool
if len(c.CA) != 0 {
rootCAs = x509.NewCertPool()
certPEMBlock, err := ioutil.ReadFile(c.CA)
if err != nil {
return fmt.Errorf("cannot read ca file: %w", err)
}
if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok {
return errEmptyCAFile
}
}

if c.Mux > 0 {
c.muxPool = newMuxPool(c.dialServerConn, c.Mux)
dialer := &net.Dialer{
Timeout: time.Second * 5,
Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}),
}

for {
localConn, err := c.Listener.Accept()
var chb []byte
if len(c.CertHash) != 0 {
b, err := hex.DecodeString(c.CertHash)
if err != nil {
return fmt.Errorf("l.Accept(): %w", err)
return fmt.Errorf("invalid cert hash: %w", err)
}
reduceTCPLoopbackSocketBuf(localConn)
chb = b
}

go func() {
defer localConn.Close()

var serverConn net.Conn
if c.Mux > 0 {
stream, _, err := c.muxPool.GetStream()
if err != nil {
log.Printf("ERROR: muxPool.GetStream: %v", err)
return
tlsConfig := &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
ServerName: c.ServerName,
RootCAs: rootCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
VerifyConnection: func(state tls.ConnectionState) error {
if len(chb) != 0 {
cert := state.PeerCertificates[0]
h := sha256.Sum256(cert.RawTBSCertificate)
if bytes.Equal(h[:len(chb)], chb) {
return nil
}
serverConn = stream
} else {
conn, err := c.dialServerConn()
if err != nil {
log.Printf("ERROR: dialServerConn: %v", err)
return
}
serverConn = conn
return fmt.Errorf("cert hash mismatch, recieved cert hash is [%s]", hex.EncodeToString(h[:]))
}
defer serverConn.Close()

if err := ctunnel.OpenTunnel(localConn, serverConn, c.Timeout); err != nil {
log.Printf("ERROR: ActiveAndServe: openTunnel: %v", err)
if state.Version != tls.VersionTLS13 {
return fmt.Errorf("unsafe tls version %d", state.Version)
}
}()
return nil
},
}
}

func (c *Client) dialServerConn() (net.Conn, error) {
serverConn, err := c.dialer.Dial("tcp", c.ServerAddr)
if err != nil {
return nil, err
var transport Transport
if c.Websocket {
transport = NewWebsocketTransport(c.DstAddr, c.ServerName, c.WebsocketPath, tlsConfig, dialer)
} else {
transport = NewRawConnTransport(c.DstAddr, dialer)
transport = NewTLSTransport(transport, tlsConfig)
}

if !c.NoTLS {
serverTLSConn := tls.Client(serverConn, c.tlsConfig)
if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil {
serverTLSConn.Close()
return nil, err
}
serverConn = serverTLSConn
if len(c.Auth) > 0 {
transport = NewAuthTransport(transport, c.Auth)
}

// write auth
if len(c.Auth) > 0 {
if _, err := serverConn.Write(c.auth[:]); err != nil {
serverConn.Close()
return nil, fmt.Errorf("failed to write auth: %w", err)
transport = NewMuxTransport(transport, c.Mux)

for {
clientConn, err := l.Accept()
if err != nil {
return err
}
}

// write mode
mode := modePlain
if c.Mux > 0 {
mode = modeMux
}
if _, err := serverConn.Write([]byte{mode}); err != nil {
serverConn.Close()
return nil, fmt.Errorf("failed to write mode: %w", err)
go func() {
defer clientConn.Close()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
serverConn, err := transport.Dial(ctx)
if err != nil {
errLogger.Printf("failed to dial server connection: %v", err)
return
}
defer serverConn.Close()

err = ctunnel.OpenTunnel(clientConn, serverConn, c.IdleTimeout)
if err != nil {
logConnErr(clientConn, fmt.Errorf("tunnel closed: %w", err))
}
}()
}

return serverConn, nil
}
Loading

0 comments on commit ee2961f

Please sign in to comment.