mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
http2: don't return from RoundTrip until request body is closed
Moving the Request.Body.Close call out from the ClientConn mutex results in some cases where RoundTrip returns while the Close is still in progress. This should be legal (RoundTrip explicitly allows for this), but net/http relies on Close never being called after RoundTrip returns. Add additional synchronization to ensure Close calls complete before RoundTrip returns. Fixes golang/go#55896 Change-Id: Ie3d4773966745e83987d219927929cb56ec1a7ad Reviewed-on: https://go-review.googlesource.com/c/net/+/435535 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Run-TryBot: Damien Neil <dneil@google.com> TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
@@ -345,8 +345,8 @@ type clientStream struct {
|
||||
readErr error // sticky read error; owned by transportResponseBody.Read
|
||||
|
||||
reqBody io.ReadCloser
|
||||
reqBodyContentLength int64 // -1 means unknown
|
||||
reqBodyClosed bool // body has been closed; guarded by cc.mu
|
||||
reqBodyContentLength int64 // -1 means unknown
|
||||
reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
|
||||
|
||||
// owned by writeRequest:
|
||||
sentEndStream bool // sent an END_STREAM flag to the peer
|
||||
@@ -376,46 +376,48 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error
|
||||
}
|
||||
|
||||
func (cs *clientStream) abortStream(err error) {
|
||||
var reqBody io.ReadCloser
|
||||
defer func() {
|
||||
if reqBody != nil {
|
||||
reqBody.Close()
|
||||
}
|
||||
}()
|
||||
cs.cc.mu.Lock()
|
||||
defer cs.cc.mu.Unlock()
|
||||
reqBody = cs.abortStreamLocked(err)
|
||||
cs.abortStreamLocked(err)
|
||||
}
|
||||
|
||||
func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser {
|
||||
func (cs *clientStream) abortStreamLocked(err error) {
|
||||
cs.abortOnce.Do(func() {
|
||||
cs.abortErr = err
|
||||
close(cs.abort)
|
||||
})
|
||||
var reqBody io.ReadCloser
|
||||
if cs.reqBody != nil && !cs.reqBodyClosed {
|
||||
cs.reqBodyClosed = true
|
||||
reqBody = cs.reqBody
|
||||
if cs.reqBody != nil {
|
||||
cs.closeReqBodyLocked()
|
||||
}
|
||||
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
|
||||
if cs.cc.cond != nil {
|
||||
// Wake up writeRequestBody if it is waiting on flow control.
|
||||
cs.cc.cond.Broadcast()
|
||||
}
|
||||
return reqBody
|
||||
}
|
||||
|
||||
func (cs *clientStream) abortRequestBodyWrite() {
|
||||
cc := cs.cc
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
if cs.reqBody != nil && !cs.reqBodyClosed {
|
||||
cs.reqBody.Close()
|
||||
cs.reqBodyClosed = true
|
||||
if cs.reqBody != nil && cs.reqBodyClosed == nil {
|
||||
cs.closeReqBodyLocked()
|
||||
cc.cond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *clientStream) closeReqBodyLocked() {
|
||||
if cs.reqBodyClosed != nil {
|
||||
return
|
||||
}
|
||||
cs.reqBodyClosed = make(chan struct{})
|
||||
reqBodyClosed := cs.reqBodyClosed
|
||||
go func() {
|
||||
cs.reqBody.Close()
|
||||
close(reqBodyClosed)
|
||||
}()
|
||||
}
|
||||
|
||||
type stickyErrWriter struct {
|
||||
conn net.Conn
|
||||
timeout time.Duration
|
||||
@@ -771,12 +773,6 @@ func (cc *ClientConn) SetDoNotReuse() {
|
||||
}
|
||||
|
||||
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
|
||||
var reqBodiesToClose []io.ReadCloser
|
||||
defer func() {
|
||||
for _, reqBody := range reqBodiesToClose {
|
||||
reqBody.Close()
|
||||
}
|
||||
}()
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
|
||||
@@ -793,10 +789,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
|
||||
last := f.LastStreamID
|
||||
for streamID, cs := range cc.streams {
|
||||
if streamID > last {
|
||||
reqBody := cs.abortStreamLocked(errClientConnGotGoAway)
|
||||
if reqBody != nil {
|
||||
reqBodiesToClose = append(reqBodiesToClose, reqBody)
|
||||
}
|
||||
cs.abortStreamLocked(errClientConnGotGoAway)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1049,19 +1042,11 @@ func (cc *ClientConn) sendGoAway() error {
|
||||
func (cc *ClientConn) closeForError(err error) {
|
||||
cc.mu.Lock()
|
||||
cc.closed = true
|
||||
|
||||
var reqBodiesToClose []io.ReadCloser
|
||||
for _, cs := range cc.streams {
|
||||
reqBody := cs.abortStreamLocked(err)
|
||||
if reqBody != nil {
|
||||
reqBodiesToClose = append(reqBodiesToClose, reqBody)
|
||||
}
|
||||
cs.abortStreamLocked(err)
|
||||
}
|
||||
cc.cond.Broadcast()
|
||||
cc.mu.Unlock()
|
||||
for _, reqBody := range reqBodiesToClose {
|
||||
reqBody.Close()
|
||||
}
|
||||
cc.closeConn()
|
||||
}
|
||||
|
||||
@@ -1458,11 +1443,19 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
|
||||
// and in multiple cases: server replies <=299 and >299
|
||||
// while still writing request body
|
||||
cc.mu.Lock()
|
||||
mustCloseBody := false
|
||||
if cs.reqBody != nil && cs.reqBodyClosed == nil {
|
||||
mustCloseBody = true
|
||||
cs.reqBodyClosed = make(chan struct{})
|
||||
}
|
||||
bodyClosed := cs.reqBodyClosed
|
||||
cs.reqBodyClosed = true
|
||||
cc.mu.Unlock()
|
||||
if !bodyClosed && cs.reqBody != nil {
|
||||
if mustCloseBody {
|
||||
cs.reqBody.Close()
|
||||
close(bodyClosed)
|
||||
}
|
||||
if bodyClosed != nil {
|
||||
<-bodyClosed
|
||||
}
|
||||
|
||||
if err != nil && cs.sentEndStream {
|
||||
@@ -1642,7 +1635,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
|
||||
}
|
||||
if err != nil {
|
||||
cc.mu.Lock()
|
||||
bodyClosed := cs.reqBodyClosed
|
||||
bodyClosed := cs.reqBodyClosed != nil
|
||||
cc.mu.Unlock()
|
||||
switch {
|
||||
case bodyClosed:
|
||||
@@ -1737,7 +1730,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
|
||||
if cc.closed {
|
||||
return 0, errClientConnClosed
|
||||
}
|
||||
if cs.reqBodyClosed {
|
||||
if cs.reqBodyClosed != nil {
|
||||
return 0, errStopReqBodyWrite
|
||||
}
|
||||
select {
|
||||
@@ -2110,24 +2103,17 @@ func (rl *clientConnReadLoop) cleanup() {
|
||||
}
|
||||
cc.closed = true
|
||||
|
||||
var reqBodiesToClose []io.ReadCloser
|
||||
for _, cs := range cc.streams {
|
||||
select {
|
||||
case <-cs.peerClosed:
|
||||
// The server closed the stream before closing the conn,
|
||||
// so no need to interrupt it.
|
||||
default:
|
||||
reqBody := cs.abortStreamLocked(err)
|
||||
if reqBody != nil {
|
||||
reqBodiesToClose = append(reqBodiesToClose, reqBody)
|
||||
}
|
||||
cs.abortStreamLocked(err)
|
||||
}
|
||||
}
|
||||
cc.cond.Broadcast()
|
||||
cc.mu.Unlock()
|
||||
for _, reqBody := range reqBodiesToClose {
|
||||
reqBody.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// countReadFrameError calls Transport.CountError with a string
|
||||
|
||||
Reference in New Issue
Block a user