http2: don't return from RoundTrip until request body is closed

Moving the Request.Body.Close call out from the ClientConn mutex
results in some cases where RoundTrip returns while the Close is
still in progress. This should be legal (RoundTrip explicitly allows
for this), but net/http relies on Close never being called after
RoundTrip returns.

Add additional synchronization to ensure Close calls complete
before RoundTrip returns.

Fixes golang/go#55896

Change-Id: Ie3d4773966745e83987d219927929cb56ec1a7ad
Reviewed-on: https://go-review.googlesource.com/c/net/+/435535
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Damien Neil
2022-09-27 15:13:53 -07:00
parent f486391704
commit 107f3e3c3b

View File

@@ -345,8 +345,8 @@ type clientStream struct {
readErr error // sticky read error; owned by transportResponseBody.Read
reqBody io.ReadCloser
reqBodyContentLength int64 // -1 means unknown
reqBodyClosed bool // body has been closed; guarded by cc.mu
reqBodyContentLength int64 // -1 means unknown
reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
// owned by writeRequest:
sentEndStream bool // sent an END_STREAM flag to the peer
@@ -376,46 +376,48 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error
}
func (cs *clientStream) abortStream(err error) {
var reqBody io.ReadCloser
defer func() {
if reqBody != nil {
reqBody.Close()
}
}()
cs.cc.mu.Lock()
defer cs.cc.mu.Unlock()
reqBody = cs.abortStreamLocked(err)
cs.abortStreamLocked(err)
}
func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser {
func (cs *clientStream) abortStreamLocked(err error) {
cs.abortOnce.Do(func() {
cs.abortErr = err
close(cs.abort)
})
var reqBody io.ReadCloser
if cs.reqBody != nil && !cs.reqBodyClosed {
cs.reqBodyClosed = true
reqBody = cs.reqBody
if cs.reqBody != nil {
cs.closeReqBodyLocked()
}
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
cs.cc.cond.Broadcast()
}
return reqBody
}
func (cs *clientStream) abortRequestBodyWrite() {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
if cs.reqBody != nil && !cs.reqBodyClosed {
cs.reqBody.Close()
cs.reqBodyClosed = true
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
cc.cond.Broadcast()
}
}
func (cs *clientStream) closeReqBodyLocked() {
if cs.reqBodyClosed != nil {
return
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
go func() {
cs.reqBody.Close()
close(reqBodyClosed)
}()
}
type stickyErrWriter struct {
conn net.Conn
timeout time.Duration
@@ -771,12 +773,6 @@ func (cc *ClientConn) SetDoNotReuse() {
}
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
var reqBodiesToClose []io.ReadCloser
defer func() {
for _, reqBody := range reqBodiesToClose {
reqBody.Close()
}
}()
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -793,10 +789,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
last := f.LastStreamID
for streamID, cs := range cc.streams {
if streamID > last {
reqBody := cs.abortStreamLocked(errClientConnGotGoAway)
if reqBody != nil {
reqBodiesToClose = append(reqBodiesToClose, reqBody)
}
cs.abortStreamLocked(errClientConnGotGoAway)
}
}
}
@@ -1049,19 +1042,11 @@ func (cc *ClientConn) sendGoAway() error {
func (cc *ClientConn) closeForError(err error) {
cc.mu.Lock()
cc.closed = true
var reqBodiesToClose []io.ReadCloser
for _, cs := range cc.streams {
reqBody := cs.abortStreamLocked(err)
if reqBody != nil {
reqBodiesToClose = append(reqBodiesToClose, reqBody)
}
cs.abortStreamLocked(err)
}
cc.cond.Broadcast()
cc.mu.Unlock()
for _, reqBody := range reqBodiesToClose {
reqBody.Close()
}
cc.closeConn()
}
@@ -1458,11 +1443,19 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
// and in multiple cases: server replies <=299 and >299
// while still writing request body
cc.mu.Lock()
mustCloseBody := false
if cs.reqBody != nil && cs.reqBodyClosed == nil {
mustCloseBody = true
cs.reqBodyClosed = make(chan struct{})
}
bodyClosed := cs.reqBodyClosed
cs.reqBodyClosed = true
cc.mu.Unlock()
if !bodyClosed && cs.reqBody != nil {
if mustCloseBody {
cs.reqBody.Close()
close(bodyClosed)
}
if bodyClosed != nil {
<-bodyClosed
}
if err != nil && cs.sentEndStream {
@@ -1642,7 +1635,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
}
if err != nil {
cc.mu.Lock()
bodyClosed := cs.reqBodyClosed
bodyClosed := cs.reqBodyClosed != nil
cc.mu.Unlock()
switch {
case bodyClosed:
@@ -1737,7 +1730,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
if cc.closed {
return 0, errClientConnClosed
}
if cs.reqBodyClosed {
if cs.reqBodyClosed != nil {
return 0, errStopReqBodyWrite
}
select {
@@ -2110,24 +2103,17 @@ func (rl *clientConnReadLoop) cleanup() {
}
cc.closed = true
var reqBodiesToClose []io.ReadCloser
for _, cs := range cc.streams {
select {
case <-cs.peerClosed:
// The server closed the stream before closing the conn,
// so no need to interrupt it.
default:
reqBody := cs.abortStreamLocked(err)
if reqBody != nil {
reqBodiesToClose = append(reqBodiesToClose, reqBody)
}
cs.abortStreamLocked(err)
}
}
cc.cond.Broadcast()
cc.mu.Unlock()
for _, reqBody := range reqBodiesToClose {
reqBody.Close()
}
}
// countReadFrameError calls Transport.CountError with a string