diff --git a/http2/transport.go b/http2/transport.go index 9a874f7b..52991f32 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -345,8 +345,8 @@ type clientStream struct { readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser - reqBodyContentLength int64 // -1 means unknown - reqBodyClosed bool // body has been closed; guarded by cc.mu + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done // owned by writeRequest: sentEndStream bool // sent an END_STREAM flag to the peer @@ -376,46 +376,48 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error } func (cs *clientStream) abortStream(err error) { - var reqBody io.ReadCloser - defer func() { - if reqBody != nil { - reqBody.Close() - } - }() cs.cc.mu.Lock() defer cs.cc.mu.Unlock() - reqBody = cs.abortStreamLocked(err) + cs.abortStreamLocked(err) } -func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser { +func (cs *clientStream) abortStreamLocked(err error) { cs.abortOnce.Do(func() { cs.abortErr = err close(cs.abort) }) - var reqBody io.ReadCloser - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBodyClosed = true - reqBody = cs.reqBody + if cs.reqBody != nil { + cs.closeReqBodyLocked() } // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. cs.cc.cond.Broadcast() } - return reqBody } func (cs *clientStream) abortRequestBodyWrite() { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() - if cs.reqBody != nil && !cs.reqBodyClosed { - cs.reqBody.Close() - cs.reqBodyClosed = true + if cs.reqBody != nil && cs.reqBodyClosed == nil { + cs.closeReqBodyLocked() cc.cond.Broadcast() } } +func (cs *clientStream) closeReqBodyLocked() { + if cs.reqBodyClosed != nil { + return + } + cs.reqBodyClosed = make(chan struct{}) + reqBodyClosed := cs.reqBodyClosed + go func() { + cs.reqBody.Close() + close(reqBodyClosed) + }() +} + type stickyErrWriter struct { conn net.Conn timeout time.Duration @@ -771,12 +773,6 @@ func (cc *ClientConn) SetDoNotReuse() { } func (cc *ClientConn) setGoAway(f *GoAwayFrame) { - var reqBodiesToClose []io.ReadCloser - defer func() { - for _, reqBody := range reqBodiesToClose { - reqBody.Close() - } - }() cc.mu.Lock() defer cc.mu.Unlock() @@ -793,10 +789,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - reqBody := cs.abortStreamLocked(errClientConnGotGoAway) - if reqBody != nil { - reqBodiesToClose = append(reqBodiesToClose, reqBody) - } + cs.abortStreamLocked(errClientConnGotGoAway) } } } @@ -1049,19 +1042,11 @@ func (cc *ClientConn) sendGoAway() error { func (cc *ClientConn) closeForError(err error) { cc.mu.Lock() cc.closed = true - - var reqBodiesToClose []io.ReadCloser for _, cs := range cc.streams { - reqBody := cs.abortStreamLocked(err) - if reqBody != nil { - reqBodiesToClose = append(reqBodiesToClose, reqBody) - } + cs.abortStreamLocked(err) } cc.cond.Broadcast() cc.mu.Unlock() - for _, reqBody := range reqBodiesToClose { - reqBody.Close() - } cc.closeConn() } @@ -1458,11 +1443,19 @@ func (cs *clientStream) cleanupWriteRequest(err error) { // and in multiple cases: server replies <=299 and >299 // while still writing request body cc.mu.Lock() + mustCloseBody := false + if cs.reqBody != nil && cs.reqBodyClosed == nil { + mustCloseBody = true + cs.reqBodyClosed = make(chan struct{}) + } bodyClosed := cs.reqBodyClosed - cs.reqBodyClosed = true cc.mu.Unlock() - if !bodyClosed && cs.reqBody != nil { + if mustCloseBody { cs.reqBody.Close() + close(bodyClosed) + } + if bodyClosed != nil { + <-bodyClosed } if err != nil && cs.sentEndStream { @@ -1642,7 +1635,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) { } if err != nil { cc.mu.Lock() - bodyClosed := cs.reqBodyClosed + bodyClosed := cs.reqBodyClosed != nil cc.mu.Unlock() switch { case bodyClosed: @@ -1737,7 +1730,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) if cc.closed { return 0, errClientConnClosed } - if cs.reqBodyClosed { + if cs.reqBodyClosed != nil { return 0, errStopReqBodyWrite } select { @@ -2110,24 +2103,17 @@ func (rl *clientConnReadLoop) cleanup() { } cc.closed = true - var reqBodiesToClose []io.ReadCloser for _, cs := range cc.streams { select { case <-cs.peerClosed: // The server closed the stream before closing the conn, // so no need to interrupt it. default: - reqBody := cs.abortStreamLocked(err) - if reqBody != nil { - reqBodiesToClose = append(reqBodiesToClose, reqBody) - } + cs.abortStreamLocked(err) } } cc.cond.Broadcast() cc.mu.Unlock() - for _, reqBody := range reqBodiesToClose { - reqBody.Close() - } } // countReadFrameError calls Transport.CountError with a string