diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 955a7225..95285739 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -4,6 +4,7 @@ import ( std_bufio "bufio" "context" "encoding/base64" + "io" "net" "net/http" "strings" @@ -37,7 +38,6 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re if err != nil { return E.Cause(err, "read http request") } - if authenticator != nil { var ( username string @@ -81,11 +81,15 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } if request.Method == "CONNECT" { - portStr := request.URL.Port() - if portStr == "" { - portStr = "80" + destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) + if destination.Port == 0 { + switch request.URL.Scheme { + case "https", "wss": + destination.Port = 443 + default: + destination.Port = 80 + } } - destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr) _, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) if err != nil { return E.Cause(err, "write http response") @@ -108,11 +112,48 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re handlerEx.NewConnectionEx(ctx, requestConn, source, destination, onClose) return nil } - } - - err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) - if err != nil { - return err + } else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" { + destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) + if destination.Port == 0 { + switch request.URL.Scheme { + case "https", "wss": + destination.Port = 443 + default: + destination.Port = 80 + } + } + serverConn, clientConn := pipe.Pipe() + go func() { + if handler != nil { + //nolint:staticcheck + err := handler.NewConnection(ctx, clientConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) + if err != nil { + common.Close(serverConn, clientConn) + } + } else { + handlerEx.NewConnectionEx(ctx, clientConn, source, destination, func(it error) { + if it != nil { + common.Close(serverConn, clientConn) + } + }) + } + }() + err = request.Write(serverConn) + if err != nil { + return E.Cause(err, "http: write upgrade request") + } + if reader.Buffered() > 0 { + _, err = io.CopyN(serverConn, reader, int64(reader.Buffered())) + if err != nil { + return err + } + } + return bufio.CopyConn(ctx, conn, serverConn) + } else { + err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) + if err != nil { + return err + } } } } @@ -198,7 +239,6 @@ func handleHTTPConnection( if !keepAlive { return conn.Close() } - return nil }