diff --git a/http2/frame.go b/http2/frame.go index 0178647..8a0f30d 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1478,7 +1478,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error { pf := mh.PseudoFields() for i, hf := range pf { switch hf.Name { - case ":method", ":path", ":scheme", ":authority": + case ":method", ":path", ":scheme", ":authority", ":protocol": isRequest = true case ":status": isResponse = true diff --git a/http2/http2.go b/http2/http2.go index 479ba4b..8d66420 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -130,6 +130,10 @@ func (s Setting) Valid() error { if s.Val != 1 && s.Val != 0 { return ConnectionError(ErrCodeProtocol) } + case SettingEnableConnectProtocol: + if s.Val != 1 && s.Val != 0 { + return ConnectionError(ErrCodeProtocol) + } case SettingInitialWindowSize: if s.Val > 1<<31-1 { return ConnectionError(ErrCodeFlowControl) @@ -147,21 +151,23 @@ func (s Setting) Valid() error { type SettingID uint16 const ( - SettingHeaderTableSize SettingID = 0x1 - SettingEnablePush SettingID = 0x2 - SettingMaxConcurrentStreams SettingID = 0x3 - SettingInitialWindowSize SettingID = 0x4 - SettingMaxFrameSize SettingID = 0x5 - SettingMaxHeaderListSize SettingID = 0x6 + SettingHeaderTableSize SettingID = 0x1 + SettingEnablePush SettingID = 0x2 + SettingMaxConcurrentStreams SettingID = 0x3 + SettingInitialWindowSize SettingID = 0x4 + SettingMaxFrameSize SettingID = 0x5 + SettingMaxHeaderListSize SettingID = 0x6 + SettingEnableConnectProtocol SettingID = 0x8 ) var settingName = map[SettingID]string{ - SettingHeaderTableSize: "HEADER_TABLE_SIZE", - SettingEnablePush: "ENABLE_PUSH", - SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", - SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", - SettingMaxFrameSize: "MAX_FRAME_SIZE", - SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", + SettingHeaderTableSize: "HEADER_TABLE_SIZE", + SettingEnablePush: "ENABLE_PUSH", + SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + SettingMaxFrameSize: "MAX_FRAME_SIZE", + SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", + SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL", } func (s SettingID) String() string { diff --git a/http2/server.go b/http2/server.go index 2d859af..146a9bc 100644 --- a/http2/server.go +++ b/http2/server.go @@ -829,6 +829,7 @@ func (sc *serverConn) serve() { {SettingMaxConcurrentStreams, sc.advMaxStreams}, {SettingMaxHeaderListSize, sc.maxHeaderListSize()}, {SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, + {SettingEnableConnectProtocol, 1}, }, }) sc.unackedSettings++ @@ -2012,12 +2013,23 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res scheme: f.PseudoValue("scheme"), authority: f.PseudoValue("authority"), path: f.PseudoValue("path"), + protocol: f.PseudoValue("protocol"), } isConnect := rp.method == "CONNECT" if isConnect { - if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) + if rp.protocol == "" { + // This is an ordinary CONNECT. It should only have a host (authority). + if rp.path != "" || rp.scheme != "" || rp.authority == "" { + return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) + } + } else { + // This is an extended CONNECT (https://datatracker.ietf.org/doc/html/rfc8441#section-4) + + // we MUST have a scheme and path + if rp.path == "" || rp.scheme == "" { + return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) + } } } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: @@ -2071,6 +2083,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res type requestParam struct { method string scheme, authority, path string + protocol string header http.Header } @@ -2112,7 +2125,7 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*r var url_ *url.URL var requestURI string - if rp.method == "CONNECT" { + if rp.method == "CONNECT" && rp.protocol == "" { url_ = &url.URL{Host: rp.authority} requestURI = rp.authority // mimic HTTP/1 server behavior } else { diff --git a/http2/stream_test.go b/http2/stream_test.go new file mode 100644 index 0000000..6766e47 --- /dev/null +++ b/http2/stream_test.go @@ -0,0 +1,146 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +var startHTTP2ServerOnce sync.Once +var http2ServerAddr string +var http2Server *httptest.Server +func startHTTP2Server() { + mux := http.NewServeMux() + + mux.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) { + writeFlusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "writer cannot be flushed", http.StatusInternalServerError) + return + } + + // Before begining any sort of streaming type behavior, we + // need to push some response headers so the client knows + // it is ok to start streaming. + w.WriteHeader(http.StatusOK) + writeFlusher.Flush() + + buf := make([]byte, 1024) + for { + nbytes, err := r.Body.Read(buf) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + _, err = w.Write(buf[:nbytes]) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + writeFlusher.Flush() + } + }) + + http2Server = httptest.NewUnstartedServer(mux) + + // Force the http server to use our patch http2 server rather than + // the one bundled in the stdlib. + ConfigureServer(http2Server.Config, nil) + + // tell the server to support HTTP/2 in the ALPN negotiation + http2Server.TLS = &tls.Config{ + NextProtos: []string{NextProtoTLS}, + } + + http2Server.StartTLS() + + http2ServerAddr = http2Server.Listener.Addr().String() +} + +func TestHTTP2Stream(t *testing.T) { + startHTTP2ServerOnce.Do(startHTTP2Server) + + client := makeClient(t) + + // NOTE: using this idiom will mean writes are not context + // safe. For the real websocket code, we need to make + // a wrapper that allows us to cancel the writes if + // our context gets canceled. This is fine for a POC + // though. + sr, sw := io.Pipe() + req, err := http.NewRequest("CONNECT", endpoint("/stream"), sr) + if err != nil { + t.Fatal(err) + } + + // TODO(ethan): This is a gross hack. Users shouldn't be setting + // psudo headers by setting things in the headers hashmap. + // I think the real solution here is to add a new `Protocol` + // field to the `http.Request` struct. + req.Header.Add("HACK-HTTP2-Protocol", "websocket") + + resp, err := client.Transport.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + + defer func() { + err = resp.Body.Close() + if err != nil { + t.Errorf("close resp body err: %s", err) + } + + err = sw.Close() + if err != nil { + t.Errorf("close stream writer err: %s", err) + } + }() + + for i := 0; i < 2; i++ { + _, err = sw.Write([]byte("ping")) + if err != nil { + t.Fatalf("write err: %s", err) + } + + buf := make([]byte, 64) + nbytes, err := resp.Body.Read(buf) + if err != nil { + t.Fatalf("read err: %s", err) + } + + if string(buf[:nbytes]) != "ping" { + t.Errorf("buf = %q, want 'ping'", string(buf[:nbytes])) + } + } +} + +func makeClient(t *testing.T) *http.Client { + t.Helper() + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(http2Server.TLS.Certificates[0].Certificate[0]) + + conf := &tls.Config{ + InsecureSkipVerify: true, + } + + return &http.Client{ + Transport: &Transport{ + TLSClientConfig: conf, + }, + } +} + +func endpoint(path string) string { + return "https://" + http2ServerAddr + path +} diff --git a/http2/transport.go b/http2/transport.go index 4ded4df..b148c49 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -291,6 +291,9 @@ type ClientConn struct { // Lock reqmu BEFORE mu or wmu. reqHeaderMu chan struct{} + // true if the server responded with SETTINGS_ENABLE_CONNECT_PROTOCOL=1 + serverAllowsExtendedConnect bool + // wmu is held while writing. // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. // Only acquire both at the same time when changing peer settings. @@ -1118,6 +1121,14 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Method == "CONNECT" && req.Header.Get("HACK-HTTP2-Protocol") != "" { + // This is an extended CONNECT https://datatracker.ietf.org/doc/html/rfc8441#section-4 + // We need to check if the server supports it. + if err := cc.checkServerSupportsExtendedConnect(); err != nil { + return nil, err + } + } + ctx := req.Context() cs := &clientStream{ cc: cc, @@ -1199,6 +1210,33 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { } } +func (cc *ClientConn) checkServerSupportsExtendedConnect() error { + if !cc.seenSettings { + // If we have not yet seen the server's settings frame, we + // are likely the first connection to this host. We should + // force the issue by sending a ping. Ping will block + // until we get the pong back or the connection's context gets + // canceled. + pingTimeout := cc.t.pingTimeout() + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + return fmt.Errorf("http2: fetching server settings: %w", err) + } + + if !cc.seenSettings { + return errors.New("http2: refused to send settings frame") + } + } + + if !cc.serverAllowsExtendedConnect { + return errors.New("http2: server does not support extended connect") + } + + return nil +} + // doRequest runs for the duration of the request lifetime. // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). @@ -1662,6 +1700,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) { return err } + cc.wmu.Lock() defer cc.wmu.Unlock() var trls []byte @@ -1744,8 +1783,10 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail return nil, err } + protocol := req.Header.Get("HACK-HTTP2-Protocol") + var path string - if req.Method != "CONNECT" { + if req.Method != "CONNECT" || (cc.serverAllowsExtendedConnect && protocol != "") { path = req.URL.RequestURI() if !validPseudoPath(path) { orig := path @@ -1787,10 +1828,15 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail m = http.MethodGet } f(":method", m) - if req.Method != "CONNECT" { + + if req.Method != "CONNECT" || (cc.serverAllowsExtendedConnect && protocol != "") { f(":path", path) f(":scheme", req.URL.Scheme) + if protocol != "" { + f(":protocol", protocol) + } } + if trailers != "" { f("trailer", trailers) } @@ -2709,6 +2755,8 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { seenMaxConcurrentStreams = true case SettingMaxHeaderListSize: cc.peerMaxHeaderListSize = uint64(s.Val) + case SettingEnableConnectProtocol: + cc.serverAllowsExtendedConnect = s.Val == 1 case SettingInitialWindowSize: // Values above the maximum flow-control // window size of 2^31-1 MUST be treated as a