From 98ff5ea2d99228defacd70e23f1127475e03302b Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Wed, 11 Sep 2024 13:22:10 +0800 Subject: [PATCH] WIP(x/net/http/client): Implement BodyChunk --- x/net/http/client.go | 2 +- x/net/http/request.go | 11 +- x/net/http/response.go | 72 ++- x/net/http/transfer.go | 28 +- x/net/http/transport.go | 948 ++++++++++++++++++---------------------- 5 files changed, 504 insertions(+), 557 deletions(-) diff --git a/x/net/http/client.go b/x/net/http/client.go index 002397a..fd6a705 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -307,7 +307,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) timeout(send) + // TODO(spongehah) tmp timeout(send) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) req.timeoutch = make(chan struct{}, 1) req.deadline = deadline diff --git a/x/net/http/request.go b/x/net/http/request.go index c5146ed..e9279fc 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -294,7 +294,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype //} // Prepare the hyper.Request - hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra) + hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra, taskData.req) if err != nil { return err } @@ -308,7 +308,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype return err } -func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.Request, error) { +func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header, treq *transportRequest) (*hyper.Request, error) { // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -401,11 +401,6 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Process Body,ContentLength,Close,Trailer - //tw, err := newTransferWriter(r) - //if err != nil { - // return err - //} - //err = tw.writeHeader(w, trace) err = r.writeHeader(reqHeaders) if err != nil { return nil, err @@ -433,7 +428,7 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Write body and trailer - err = r.writeBody(hyperReq) + err = r.writeBody(hyperReq, treq) if err != nil { return nil, err } diff --git a/x/net/http/response.go b/x/net/http/response.go index 6ff5b3d..7ad452b 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -32,7 +32,7 @@ func (r *Response) closeBody() { } } -func ReadResponse(r *io.PipeReader, req *Request, hyperResp *hyper.Response) (*Response, error) { +func ReadResponse(r *bodyChunk, req *Request, hyperResp *hyper.Response) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), @@ -117,3 +117,73 @@ func (r *Response) bodyIsWritable() bool { _, ok := r.Body.(io.Writer) return ok } + +func (resp *Response) checkRespBody(taskData *taskData) (needContinue bool) { + pc := taskData.pc + bodyWritable := resp.bodyIsWritable() + hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 + + if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + pc.alive = false + } + + if !hasBody || bodyWritable { + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) + + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() + + if bodyWritable { + pc.closeErr = errCallerOwnsConn + } + + select { + case taskData.resc <- responseAndError{res: resp}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return true + } + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + readLoopDefer(pc, false) + return true + } + return false +} + +func (r *Response) wrapBodyEOFSignalAndGzip(taskData *taskData) { + body := &bodyEOFSignal{ + body: r.Body, + earlyCloseFn: func() error { + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + if !isEOF { + if cerr := taskData.pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + r.Body = body + // TODO(spongehah) gzip(wrapBodyEOFSignal) + //if taskData.addedGzip && EqualFold(r.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // r.Body = &gzipReader{body: body} + // r.Header.Del("Content-Encoding") + // r.Header.Del("Content-Length") + // r.ContentLength = -1 + // r.Uncompressed = true + //} +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 103200c..bd19979 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -98,7 +98,7 @@ func (uste *unsupportedTEError) Error() string { } // msg is *Request or *Response. -func readTransfer(msg any, r *io.PipeReader) (err error) { +func readTransfer(msg any, r *bodyChunk) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -173,8 +173,6 @@ func readTransfer(msg any, r *io.PipeReader) (err error) { if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { t.Body = NoBody } else { - // TODO(spongehah) ChunkReader(readTransfer) - //t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} t.Body = &body{src: r, hdr: msg, r: r, closing: t.Close} } case realLength == 0: @@ -665,15 +663,15 @@ func (req *Request) unwrapBody() io.Reader { return req.Body } -func (r *Request) writeBody(hyperReq *hyper.Request) error { +func (r *Request) writeBody(hyperReq *hyper.Request, treq *transportRequest) error { if r.Body != nil { var body = r.unwrapBody() hyperReqBody := hyper.NewBody() buf := make([]byte, defaultChunkSize) reqData := &bodyReq{ - body: body, - buf: buf, - closeBody: r.closeBody, + body: body, + buf: buf, + treq: treq, } hyperReqBody.SetUserdata(c.Pointer(reqData)) hyperReqBody.SetDataFunc(setPostData) @@ -683,9 +681,9 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { } type bodyReq struct { - body io.Reader - buf []byte - closeBody func() error + body io.Reader + buf []byte + treq *transportRequest } func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { @@ -694,10 +692,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In if err != nil { if err == io.EOF { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } fmt.Println("error reading request body: ", err) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } if n > 0 { @@ -706,10 +705,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In } if n == 0 { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } - req.closeBody() - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.closeBody() + err = fmt.Errorf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 44d721d..a10a434 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -38,6 +38,7 @@ var DefaultTransport RoundTripper = &Transport{ // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 const debugSwitch = true +const debugReadWriteLoop = true // Debug switch provided for developers type Transport struct { idleMu sync.Mutex @@ -46,53 +47,24 @@ type Transport struct { idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns idleLRU connLRU - altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - Proxy func(*Request) (*url.URL, error) + + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns - // DisableKeepAlives, if true, disables HTTP keep-alives and - // will only use the connection to the server for a single - // HTTP request. - // - // This is unrelated to the similarly named TCP keep-alives. - DisableKeepAlives bool - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool + Proxy func(*Request) (*url.URL, error) - // MaxIdleConns controls the maximum number of idle (keep-alive) - // connections across all hosts. Zero means no limit. - MaxIdleConns int + DisableKeepAlives bool + DisableCompression bool - // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) connections to keep per-host. If zero, - // DefaultMaxIdleConnsPerHost is used. + MaxIdleConns int MaxIdleConnsPerHost int - - // MaxConnsPerHost optionally limits the total number of - // connections per host, including connections in the dialing, - // active, and idle states. On limit violation, dials will block. - // - // Zero means no limit. - MaxConnsPerHost int - - // IdleConnTimeout is the maximum amount of time an idle - // (keep-alive) connection will remain idle before closing - // itself. - // Zero means no limit. - IdleConnTimeout time.Duration + MaxConnsPerHost int + IdleConnTimeout time.Duration // libuv and hyper related loopInitOnce sync.Once @@ -516,14 +488,11 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { // useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false - } - return true + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return !(req.URL.Scheme == "https" && req.requiresHTTP1()) } // CancelRequest cancels an in-flight request by closing its connection. @@ -586,11 +555,11 @@ func getMilliseconds(deadline time.Time) uint64 { func (t *Transport) RoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("RoundTrip start") - defer println("RoundTrip end") + println("############### RoundTrip start") + defer println("############### RoundTrip end") } t.loopInitOnce.Do(func() { - println("init loop") + println("############### init loop") t.loop = libuv.LoopNew() t.async = &libuv.Async{} t.exec = hyper.NewExecutor() @@ -620,7 +589,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) if debugSwitch { - println("timer start") + println("############### timer start") } didTimeout = func() bool { return req.timer.GetDueIn() == 0 } stopTimer = func() { @@ -628,7 +597,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Stop() (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) if debugSwitch { - println("timer close") + println("############### timer close") } } } else { @@ -654,8 +623,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { func (t *Transport) doRoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("doRoundTrip start") - defer println("doRoundTrip end") + println("############### doRoundTrip start") + defer println("############### doRoundTrip end") } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() @@ -715,7 +684,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout(t.doRoundTrip) //select { //case <-ctx.Done(): // req.closeBody() @@ -766,7 +734,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) ConnPool(t.doRoundTrip) if http2isNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) @@ -800,8 +767,8 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { if debugSwitch { - println("getConn start") - defer println("getConn end") + println("############### getConn start") + defer println("############### getConn end") } req := treq.Request //trace := treq.trace @@ -824,7 +791,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } }() - // TODO(spongehah) ConnPool(t.getConn) // Queue for idle connection. if delivered := t.queueForIdleConn(w); delivered { pc := w.pc @@ -853,28 +819,28 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) //} if w.err != nil { - // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the - // cancellation error (see golang.org/issue/16049). - select { - // TODO(spongehah) timeout(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() - case <-req.timeoutch: - if debugSwitch { - println("getConn: timeoutch") - } - return nil, errors.New("timeout: req.Context().Err()") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err - default: - // return below + return nil, w.err + } + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case <-req.timeoutch: + if debugSwitch { + println("############### getConn: timeoutch") + } + return nil, errors.New("timeout: req.Context().Err()") + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn } + return nil, err + default: + // return below } return w.pc, w.err } @@ -883,8 +849,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { if debugSwitch { - println("queueForDial start") - defer println("queueForDial end") + println("############### queueForDial start") + defer println("############### queueForDial end") } w.beforeDial() @@ -919,13 +885,12 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { if debugSwitch { - println("dialConnFor start") - defer println("dialConnFor end") + println("############### dialConnFor start") + defer println("############### dialConnFor end") } defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) - // TODO(spongehah) ConnPool(t.dialConnFor) delivered := w.tryDeliver(pc, err) // If the connection was successfully established but was not passed to w, // or is a shareable HTTP/2 connection @@ -994,8 +959,8 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { if debugSwitch { - println("dialConn start") - defer println("dialConn end") + println("############### dialConn start") + defer println("############### dialConn end") } select { case <-timeoutch: @@ -1102,6 +1067,22 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // } //} + pconn.closeErr = errReadLoopExiting + pconn.tryPutIdleConn = func() bool { + if err := pconn.t.tryPutIdleConn(pconn); err != nil { + pconn.closeErr = err + // TODO(spongehah) trace(dialConn) + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") @@ -1114,8 +1095,8 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * func (t *Transport) dial(addr string) (*connData, error) { if debugSwitch { - println("dial start") - defer println("dial end") + println("############### dial start") + defer println("############### dial end") } host, port, err := net.SplitHostPort(addr) if err != nil { @@ -1150,12 +1131,11 @@ func (t *Transport) dial(addr string) (*connData, error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { if debugSwitch { - println("roundTrip start") - defer println("roundTrip end") + println("############### roundTrip start") + defer println("############### roundTrip end") } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - // TODO(spongehah) ConnPool(pc.roundTrip) pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -1168,40 +1148,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err headerFn(req.extraHeaders()) } - // Ask for a compressed version if the caller didn't set their - // own value for Accept-Encoding. We only attempt to - // uncompress the gzip stream if we were the layer that - // requested it. - requestedGzip := false - // TODO(spongehah) gzip(pc.roundTrip) - //if !pc.t.DisableCompression && - // req.Header.Get("Accept-Encoding") == "" && - // req.Header.Get("Range") == "" && - // req.Method != "HEAD" { - // // Request gzip only, not deflate. Deflate is ambiguous and - // // not as universally supported anyway. - // // See: https://zlib.net/zlib_faq.html#faq39 - // // - // // Note that we don't request this for HEAD requests, - // // due to a bug in nginx: - // // https://trac.nginx.org/nginx/ticket/358 - // // https://golang.org/issue/5522 - // // - // // We don't request gzip if the request is for a range, since - // // auto-decoding a portion of a gzipped document will just fail - // // anyway. See https://golang.org/issue/8923 - // requestedGzip = true - // req.extraHeaders().Set("Accept-Encoding", "gzip") - //} - - // The 100-continue operation in Hyper is handled in the newHyperRequest function. - - // Keep-Alive - if pc.t.DisableKeepAlives && - !req.wantsClose() && - !isProtocolSwitchHeader(req.Header) { - req.extraHeaders().Set("Connection", "close") - } + // Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). + requestedGzip := pc.setExtraHeaders(req) gone := make(chan struct{}, 1) defer close(gone) @@ -1229,7 +1177,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } if pc.client == nil && !pc.isReused() { - println("first") // Hookup the IO hyperIo := newIoWithConnReadWrite(pc.conn) // We need an executor generally to poll futures @@ -1243,7 +1190,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // Send the request to readWriteLoop(). pc.t.exec.Push(handshakeTask) } else { - println("second") taskData.taskId = read err = req.write(pc.client, taskData, pc.t.exec) if err != nil { @@ -1264,12 +1210,12 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() if debugSwitch { - println("roundTrip for") + println("############### roundTrip for") } select { case err := <-writeErrCh: if debugSwitch { - println("roundTrip: writeErrch") + println("############### roundTrip: writeErrch") } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) @@ -1278,17 +1224,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return nil, pc.mapRoundTripError(req, startBytesWritten, err) } - //if d := pc.t.ResponseHeaderTimeout; d > 0 { - // if debugRoundTrip { - // //req.logf("starting timer for %v", d) - // } - // timer := time.NewTimer(d) - // defer timer.Stop() // prevent leaks - // respHeaderTimer = timer.C - //} case <-pcClosed: if debugSwitch { - println("roundTrip: pcClosed") + println("############### roundTrip: pcClosed") } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { @@ -1297,7 +1235,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err //case <-respHeaderTimer: case re := <-resc: if debugSwitch { - println("roundTrip: resc") + println("############### roundTrip: resc") } if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) @@ -1306,7 +1244,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - // TODO(spongehah) timeout(pc.roundTrip) //case <-cancelChan: // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) // cancelChan = nil @@ -1316,7 +1253,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // ctxDoneChan = nil case <-timeoutch: if debugSwitch { - println("roundTrip: timeoutch") + println("############### roundTrip: timeoutch") } canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) timeoutch = nil @@ -1330,361 +1267,277 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err func readWriteLoop(checker *libuv.Check) { t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) - // Read this once, before loop starts. (to avoid races in tests) - //testHookMu.Lock() - //testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - //testHookMu.Unlock() - - const debugReadWriteLoop = true // Debug switch provided for developers - - // The polling state machine! - // Poll all ready tasks and act on them... - for { - task := t.exec.Poll() + // The polling state machine! Poll all ready tasks and act on them... + task := t.exec.Poll() + for task != nil { if debugSwitch { - println("polling") - } - if task == nil { - return - } - taskData := (*taskData)(task.Userdata()) - var taskId taskId - if taskData != nil { - taskId = taskData.taskId - } else { - taskId = notSet + println("############### polling") } + t.handleTask(task) + task = t.exec.Poll() + } +} + +func (t *Transport) handleTask(task *hyper.Task) { + taskData := (*taskData)(task.Userdata()) + if taskData == nil { + // A background task for hyper_client completed... + task.Free() + return + } + var err error + pc := taskData.pc + // If original taskId is set, we need to check it + err = checkTaskType(task, taskData) + if err != nil { + readLoopDefer(pc, true) + return + } + switch taskData.taskId { + case handshake: if debugReadWriteLoop { - println("taskId: ", taskId) + println("############### write") } - switch taskId { - case handshake: - if debugReadWriteLoop { - println("write") - } - - err := checkTaskType(task, handshake) - if err != nil { - taskData.writeErrCh <- err - task.Free() - continue - } - - pc := taskData.pc - select { - case <-pc.closech: - task.Free() - continue - default: - } - pc.client = (*hyper.ClientConn)(task.Value()) + // Check if the connection is closed + select { + case <-pc.closech: task.Free() + return + default: + } - // TODO(spongehah) Proxy(writeLoop) - taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, t.exec) - - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } - - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } + pc.client = (*hyper.ClientConn)(task.Value()) + task.Free() - pc := taskData.pc - - err := checkTaskType(task, read) - if bre, ok := err.(requestBodyReadError); ok { - err = bre.error - // Errors reading from the user's - // Request.Body are high priority. - // Set it here before sending on the - // channels below or calling - // pc.close() which tears down - // connections and causes other - // errors. - taskData.req.setError(err) - } - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } + // TODO(spongehah) Proxy(writeLoop) + taskData.taskId = read + err = taskData.req.Request.write(pc.client, taskData, t.exec) - if pc.closeErr == nil { - pc.closeErr = errReadLoopExiting - } - // TODO(spongehah) ConnPool(readWriteLoop) - if pc.tryPutIdleConn == nil { - pc.tryPutIdleConn = func() bool { - if err := pc.t.tryPutIdleConn(pc); err != nil { - pc.closeErr = err - // TODO(spongehah) trace(readWriteLoop) - //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - //} - return false - } - //if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - //} - return true - } - } + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + pc.close(err) + return + } - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() + if debugReadWriteLoop { + println("############### write end") + } + case read: + if debugReadWriteLoop { + println("############### read") + } - pc.mu.Lock() - if pc.numExpectedResponses == 0 { - pc.readLoopPeekFailLocked(hyperResp, err) - pc.mu.Unlock() + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() - // defer - readLoopDefer(pc, t) - continue - } + //pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() + readLoopDefer(pc, true) + return + } + //pc.mu.Unlock() - //trace := httptrace.ContextClientTrace(rc.req.Context()) - - var resp *Response - var respBody *hyper.Body - if err == nil { - var pr *io.PipeReader - pr, taskData.bodyWriter = io.Pipe() - resp, err = ReadResponse(pr, taskData.req.Request, hyperResp) - respBody = hyperResp.Body() - } else { - err = transportReadFromServerError{err} - pc.closeErr = err + var resp *Response + if err == nil { + async := &libuv.Async{} + async.SetData(c.Pointer(taskData)) + t.loop.Async(async, readyToRead) + bc := &bodyChunk{ + readCh: make(chan []byte, 1), + done: make(chan struct{}, 1), + async: async, } + pc.bodyChunk = bc + resp, err = ReadResponse(bc, taskData.req.Request, hyperResp) + taskData.hyperBody = hyperResp.Body() + } else { + err = transportReadFromServerError{err} + pc.closeErr = err + } - // No longer need the response - hyperResp.Free() + // No longer need the response + hyperResp.Free() - if err != nil { - select { - case taskData.resc <- responseAndError{err: err}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // defer - readLoopDefer(pc, t) - continue + if err != nil { + select { + case taskData.resc <- responseAndError{err: err}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return } + readLoopDefer(pc, true) + return + } - pc.mu.Lock() - pc.numExpectedResponses-- - pc.mu.Unlock() + dataTask := taskData.hyperBody.Data() + taskData.taskId = readBodyChunk + dataTask.SetUserdata(c.Pointer(taskData)) + t.exec.Push(dataTask) - bodyWritable := resp.bodyIsWritable() - hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 + if !taskData.req.deadline.IsZero() { + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + } - if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { - // Don't do keep-alive on error if either party requested a close - // or we get an unexpected informational (1xx) response. - // StatusCode 100 is already handled above. - pc.alive = false - } + //pc.mu.Lock() + pc.numExpectedResponses-- + //pc.mu.Unlock() - if !hasBody || bodyWritable { - replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) - - // TODO(spongehah) ConnPool(readWriteLoop) - // Put the idle conn back into the pool before we send the response - // so if they process it quickly and make another request, they'll - // get this same conn. But we use the unbuffered channel 'rc' - // to guarantee that persistConn.roundTrip got out of its select - // potentially waiting for this persistConn to close. - pc.alive = pc.alive && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && pc.tryPutIdleConn() - - if bodyWritable { - pc.closeErr = errCallerOwnsConn - } + needContinue := resp.checkRespBody(taskData) + if needContinue { + return + } - select { - case taskData.resc <- responseAndError{res: resp}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // Now that they've read from the unbuffered channel, they're safely - // out of the select that also waits on this goroutine to die, so - // we're allowed to exit now if needed (if alive is false) - //testHookReadLoopBeforeNextRead() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } - continue - } + resp.wrapBodyEOFSignalAndGzip(taskData) - body := &bodyEOFSignal{ - body: resp.Body, - earlyCloseFn: func() error { - taskData.bodyWriter.Close() - return nil - }, - fn: func(err error) error { - isEOF := err == io.EOF - if !isEOF { - if cerr := pc.canceled(); cerr != nil { - return cerr - } - } - return err - }, - } - resp.Body = body - - // TODO(spongehah) gzip(pc.readWriteLoop) - //if taskData.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - // println("gzip reader") - // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") - // resp.Header.Del("Content-Length") - // resp.ContentLength = -1 - // resp.Uncompressed = true - //} + // TODO(spongehah) select blocking(readWriteLoop) + //select { + //case taskData.resc <- responseAndError{res: resp}: + //case <-taskData.callerGone: + // // defer + // readLoopDefer(pc, true) + // continue + //} + select { + case <-taskData.callerGone: + readLoopDefer(pc, true) + return + default: + } + taskData.resc <- responseAndError{res: resp} - bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(taskData.bodyWriter)) - taskData.taskId = readDone - bodyForeachTask.SetUserdata(c.Pointer(taskData)) - t.exec.Push(bodyForeachTask) - if taskData.req.timer != nil { - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData - } + if debugReadWriteLoop { + println("############### read end") + } + case readBodyChunk: + if debugReadWriteLoop { + println("############### readBodyChunk") + } - // TODO(spongehah) select blocking(readWriteLoop) - //select { - //case taskData.resc <- responseAndError{res: resp}: - //case <-taskData.callerGone: - // // defer - // readLoopDefer(pc, t) - // continue - //} - select { - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - default: - } - taskData.resc <- responseAndError{res: resp} + taskType := task.Type() + if taskType == hyper.TaskBuf { + chunk := (*hyper.Buf)(task.Value()) + chunkLen := chunk.Len() + bytes := unsafe.Slice(chunk.Bytes(), chunkLen) + // Free chunk and task + chunk.Free() + task.Free() + // Write to the channel + pc.bodyChunk.readCh <- bytes + println("############### send bytes") + println("############### readBodyChunk end 1") + return + } - if debugReadWriteLoop { - println("read end") - } - case readDone: - // A background task of reading the response body is completed - if debugReadWriteLoop { - println("readDone") - } - if taskData.bodyWriter != nil { - taskData.bodyWriter.Close() - } - checkTaskType(task, readDone) + // taskType == taskEmpty (check in checkTaskType) + task.Free() + taskData.hyperBody.Free() + taskData.hyperBody = nil + pc.bodyChunk.async.Close(nil) + close(pc.bodyChunk.readCh) + close(pc.bodyChunk.done) + replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() - bodyEOF := task.Type() == hyper.TaskEmpty - // free the task - task.Free() + readLoopDefer(pc, false) - pc := taskData.pc - - replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool - // TODO(spongehah) ConnPool(readWriteLoop) - pc.alive = pc.alive && - bodyEOF && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // bodyEOF && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - // TODO(spongehah) timeout(t.readWriteLoop) - //case <-rw.rc.req.Cancel: - // pc.alive = false - // pc.t.CancelRequest(rw.rc.req) - //case <-rw.rc.req.Context().Done(): - // pc.alive = false - // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - //case <-pc.closech: - // pc.alive = false - //} + if debugReadWriteLoop { + println("############### readBodyChunk end 2") + } + } +} - //select { - //case <-taskData.req.timeoutch: - // continue - //case <-pc.closech: - // pc.alive = false - //default: - //} +func readyToRead(aysnc *libuv.Async) { + println("############### AsyncCb: readyToRead") + taskData := (*taskData)(aysnc.GetData()) + dataTask := taskData.hyperBody.Data() + dataTask.SetUserdata(c.Pointer(taskData)) + taskData.pc.t.exec.Push(dataTask) +} - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } +type bodyChunk struct { + chunk []byte + readCh chan []byte + async *libuv.Async + done chan struct{} +} - //testHookReadLoopBeforeNextRead() - if debugReadWriteLoop { - println("readDone end") - } - case notSet: - // A background task for hyper_client completed... - task.Free() +func (bc *bodyChunk) Read(p []byte) (n int, err error) { + // If there are still unread chunks, read them first + if len(bc.chunk) > 0 { + n = copy(p, bc.chunk) + bc.chunk = bc.chunk[n:] + if len(bc.chunk) == 0 { + bc.async.Send() } + return n, nil + } + + // Attempt to read a new chunk from a channel + select { + case chunk, ok := <-bc.readCh: + if !ok { + // The channel has been closed, indicating that all data has been read + return 0, io.EOF + } + n = copy(p, chunk) + if n < len(chunk) { + // If the capacity of p is insufficient to hold the whole chunk, save the rest of the chunk + bc.chunk = chunk[n:] + } + bc.async.Send() + return n, nil + case <-bc.done: + // If the done channel is closed, the read needs to be terminated + return 0, io.EOF } } -func readLoopDefer(pc *persistConn, t *Transport) { +// readLoopDefer Replace the defer function of readLoop in stdlib +func readLoopDefer(pc *persistConn, force bool) { + if pc.alive == true && !force { + return + } pc.close(pc.closeErr) - // TODO(spongehah) ConnPool(readLoopDefer) - t.removeIdleConn(pc) + pc.t.removeIdleConn(pc) } // ---------------------------------------------------------- +type connData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + nwrite int64 // bytes written(Replaced from persistConn's nwrite) + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + type taskData struct { taskId taskId - bodyWriter *io.PipeWriter req *transportRequest pc *persistConn addedGzip bool writeErrCh chan error callerGone chan struct{} resc chan responseAndError + hyperBody *hyper.Body } -type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr - nwrite int64 // bytes written(Replaced from persistConn's nwrite) - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int + +const ( + handshake taskId = iota + 1 + read + readBodyChunk +) func (conn *connData) Close() error { if conn == nil { @@ -1709,8 +1562,8 @@ func (conn *connData) Close() error { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { if debugSwitch { - println("connect start") - defer println("connect end") + println("############### connect start") + defer println("############### connect end") } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) @@ -1723,8 +1576,6 @@ func onConnect(req *libuv.Connect, status c.Int) { // allocBuffer allocates a buffer for reading from a socket func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { - //conn := (*ConnData)(handle.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data conn := (*connData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) @@ -1738,9 +1589,7 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { // onRead is the libuv callback for reading from a socket // This callback function is called when data is available to be read func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { - // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - // If data was read (nread > 0) if nread > 0 { // Update the amount of filled buffer @@ -1757,9 +1606,7 @@ func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - // If there's data in the buffer if conn.ReadBufFilled > 0 { // Calculate how much data to copy (minimum of filled amount and requested amount) @@ -1793,9 +1640,7 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin // onWrite is the libuv callback for writing to a socket // Callback function called after a write operation completes func onWrite(req *libuv.Write, status c.Int) { - // Get the connection data associated with the write request conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - // If there's a pending write waker if conn.WriteWaker != nil { // Wake up the pending write operation @@ -1807,7 +1652,6 @@ func onWrite(req *libuv.Write, status c.Int) { // writeCallBack write callback function for Hyper library func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) @@ -1838,8 +1682,8 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui // onTimeout is the libuv callback for a timeout func onTimeout(timer *libuv.Timer) { if debugSwitch { - println("onTimeout start") - defer println("onTimeout end") + println("############### onTimeout start") + defer println("############### onTimeout end") } data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) close(data.timeoutch) @@ -1850,8 +1694,7 @@ func onTimeout(timer *libuv.Timer) { pc := taskData.pc pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) - // defer - readLoopDefer(pc, pc.t) + readLoopDefer(pc, true) } } @@ -1864,62 +1707,53 @@ func newIoWithConnReadWrite(connData *connData) *hyper.Io { return hyperIo } -// taskId The unique identifier of the next task polled from the executor -type taskId c.Int - -const ( - notSet taskId = iota - handshake - read - readDone -) - // checkTaskType checks the task type -func checkTaskType(task *hyper.Task, curTaskId taskId) error { - switch curTaskId { - case handshake: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::handshake]handshake task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("[readWriteLoop::handshake]unexpected task type\n") - } - return nil - case read: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::read]write task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("[readWriteLoop::read]unexpected task type\n")) - return errors.New("[readWriteLoop::read]unexpected task type\n") +func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { + curTaskId := taskData.taskId + taskType := task.Type() + if taskType == hyper.TaskError { + err = fail((*hyper.Error)(task.Value()), curTaskId) + } + if err == nil { + switch curTaskId { + case handshake: + if taskType != hyper.TaskClientConn { + err = errors.New("[readWriteLoop::handshake]unexpected task type\n") + } + case read: + if taskType != hyper.TaskResponse { + err = errors.New("[readWriteLoop::read]unexpected task type\n") + } + case readBodyChunk: + if taskType != hyper.TaskBuf && taskType != hyper.TaskEmpty { + err = errors.New("[readWriteLoop::readBodyChunk]unexpected task type\n") + } } - return nil - case readDone: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::readDone]read response body error!\n") - return fail((*hyper.Error)(task.Value())) + } + if err != nil { + task.Free() + if curTaskId == handshake || curTaskId == read { + taskData.writeErrCh <- err + taskData.pc.close(err) } - return nil - case notSet: + taskData.pc.alive = false } - return errors.New("[readWriteLoop]unexpected task type\n") + return } // fail prints the error details and panics -func fail(err *hyper.Error) error { +func fail(err *hyper.Error, taskId taskId) error { if err != nil { - c.Printf(c.Str("[readWriteLoop]error code: %d\n"), err.Code()) + c.Printf(c.Str("[readWriteLoop(taskId: %d)]error code: %d\n"), taskId, err.Code()) // grab the error details var errBuf [256]c.Char errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - c.Printf(c.Str("[readWriteLoop]details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + c.Printf(c.Str("[readWriteLoop(taskId: %d)]details: %.*s\n"), taskId, c.Int(errLen), c.Pointer(&errBuf[:][0])) // clean up the error err.Free() - return fmt.Errorf("[readWriteLoop]hyper request error, error code: %d\n", int(err.Code())) + return fmt.Errorf("[readWriteLoop(taskId: %d)]hyper request error, The two lines above show the error code and error details", taskId) } return nil } @@ -1928,15 +1762,14 @@ func fail(err *hyper.Error) error { // error values for debugging and testing, not seen by users. var ( - errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") - errConnBroken = errors.New("http: putIdleConn: connection is in bad state") - errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") - errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") - errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") - errCloseIdleConns = errors.New("http: CloseIdleConnections called") - errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") - errReadWriteLoopExiting = errors.New("http: Transport.readWriteLoop exiting") - errIdleConnTimeout = errors.New("http: idle connection timeout") + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -1971,14 +1804,6 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } -// fakeLocker is a sync.Locker which does nothing. It's used to guard -// test-only fields when not under test, to avoid runtime atomic -// overhead. -type fakeLocker struct{} - -func (fakeLocker) Lock() {} -func (fakeLocker) Unlock() {} - // nothingWrittenError wraps a write errors which ended up writing zero bytes. type nothingWrittenError struct { error @@ -2014,9 +1839,6 @@ var ( testHookRoundTripRetried = nop testHookPrePendingDial = nop testHookPostPendingDial = nop - - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead = nop ) var portMap = map[string]string{ @@ -2076,10 +1898,34 @@ type persistConn struct { mutateHeaderFunc func(Header) // other - alive bool // Replace the alive in readLoop - closeErr error // Replace the closeErr in readLoop - tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop - client *hyper.ClientConn + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop + tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop + client *hyper.ClientConn // http long connection client handle + bodyChunk *bodyChunk // Implement non-blocking consumption of each responseBody chunk +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + // TODO(spongehah) http2 + //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.closeIdle = true // close newly idle connections + t.idleLRU = connLRU{} + t.idleMu.Unlock() + for _, conns := range m { + for _, pconn := range conns { + pconn.close(errCloseIdleConns) + } + } + //if t2 := t.h2transport; t2 != nil { + // t2.CloseIdleConnections() + //} } func (pc *persistConn) cancelRequest(err error) { @@ -2110,7 +1956,7 @@ func (pc *persistConn) markReused() { func (pc *persistConn) closeLocked(err error) { if debugSwitch { - println("pc closed") + println("############### pc closed") } if err == nil { panic("nil error") @@ -2256,13 +2102,11 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // the 1st response byte from the server. return true } - if err == errServerClosedIdle { - // The server replied with io.EOF while we were trying to - // read the response. Probably an unfortunately keep-alive - // timeout, just as the client was writing a request. - return true - } - return false // conservatively + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + // conservatively return false. + return err == errServerClosedIdle } // closeConnIfStillIdle closes the connection if it's still sitting idle. @@ -2300,6 +2144,45 @@ func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", err)) } +// setExtraHeaders Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). +func (pc *persistConn) setExtraHeaders(req *transportRequest) bool { + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + // TODO(spongehah) gzip(pc.roundTrip) + //if !pc.t.DisableCompression && + // req.Header.Get("Accept-Encoding") == "" && + // req.Header.Get("Range") == "" && + // req.Method != "HEAD" { + // // Request gzip only, not deflate. Deflate is ambiguous and + // // not as universally supported anyway. + // // See: https://zlib.net/zlib_faq.html#faq39 + // // + // // Note that we don't request this for HEAD requests, + // // due to a bug in nginx: + // // https://trac.nginx.org/nginx/ticket/358 + // // https://golang.org/issue/5522 + // // + // // We don't request gzip if the request is for a range, since + // // auto-decoding a portion of a gzipped document will just fail + // // anyway. See https://golang.org/issue/8923 + // requestedGzip = true + // req.extraHeaders().Set("Accept-Encoding", "gzip") + //} + + // The 100-continue operation in Hyper is handled in the newHyperRequest function. + + // Keep-Alive + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + return requestedGzip +} + func is408Message(resp *hyper.Response) bool { httpVersion := int(resp.Version()) if httpVersion != 10 && httpVersion != 11 { @@ -2435,7 +2318,6 @@ func (w *wantConn) cancel(t *Transport, err error) { w.err = err w.mu.Unlock() - // TODO(spongehah) ConnPool(w.cancel) if pc != nil { t.putOrCloseIdleConn(pc) }