From e898025ed96aa6d08e98132b8dca210e9e7a0cd2 Mon Sep 17 00:00:00 2001 From: Michael Fraenkel Date: Sun, 30 May 2021 12:26:36 -0600 Subject: [PATCH] http2: close the request body if needed As per client.Do and Request.Body, the transport is responsible to close the request Body. If there was an error or non 1xx/2xx status code, the transport will wait for the body writer to complete. If there is no data available to read, the body writer will block indefinitely. To prevent this, the body will be closed if it hasn't already. If there was a 1xx/2xx status code, the body will be closed eventually. Updates golang/go#43989 Change-Id: I9a4a5f13658122c562baf915e2c0c8992a023278 Reviewed-on: https://go-review.googlesource.com/c/net/+/323689 Reviewed-by: Damien Neil Trust: Damien Neil Trust: Alexander Rakoczy Run-TryBot: Damien Neil TryBot-Result: Go Bot --- http2/transport.go | 59 ++++++++++++++++++++--------------------- http2/transport_test.go | 45 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/http2/transport.go b/http2/transport.go index b97adff7..b261beb1 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -385,8 +385,13 @@ func (cs *clientStream) abortRequestBodyWrite(err error) { } cc := cs.cc cc.mu.Lock() - cs.stopReqBody = err - cc.cond.Broadcast() + if cs.stopReqBody == nil { + cs.stopReqBody = err + if cs.req.Body != nil { + cs.req.Body.Close() + } + cc.cond.Broadcast() + } cc.mu.Unlock() } @@ -1110,40 +1115,28 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf return res, false, nil } + handleError := func(err error) (*http.Response, bool, error) { + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) + <-bodyWriter.resc + } + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), err + } + for { select { case re := <-readLoopResCh: return handleReadLoopResponse(re) case <-respHeaderTimer: - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), errTimeout + return handleError(errTimeout) case <-ctx.Done(): - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), ctx.Err() + return handleError(ctx.Err()) case <-req.Cancel: - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), errRequestCanceled + return handleError(errRequestCanceled) case <-cs.peerReset: // processResetStream already removed the // stream from the streams map; no need for @@ -1290,7 +1283,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( // Request.Body is closed by the Transport, // and in multiple cases: server replies <=299 and >299 // while still writing request body - cerr := bodyCloser.Close() + var cerr error + cc.mu.Lock() + if cs.stopReqBody == nil { + cs.stopReqBody = errStopReqBodyWrite + cerr = bodyCloser.Close() + } + cc.mu.Unlock() if err == nil { err = cerr } diff --git a/http2/transport_test.go b/http2/transport_test.go index 750813b2..2da7d9de 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -4899,3 +4899,48 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) { } res.Body.Close() } + +type closeChecker struct { + io.ReadCloser + closed chan struct{} +} + +func (rc *closeChecker) Close() error { + close(rc.closed) + return rc.ReadCloser.Close() +} + +func TestTransportCloseRequestBody(t *testing.T) { + var statusCode int + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + + for _, status := range []int{200, 401} { + t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) { + statusCode = status + pr, pw := io.Pipe() + pipeClosed := make(chan struct{}) + req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed}) + if err != nil { + t.Fatal(err) + } + res, err := cc.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + pw.Close() + <-pipeClosed + }) + } +}