mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
http2: don't reuse connections that are experiencing errors
When a request on a connection fails to complete successfully, mark the conn as doNotReuse. It's possible for requests to fail for reasons unrelated to connection health, but opening a new connection unnecessarily is less of an impact than reusing a dead connection. Fixes golang/go#59690 Change-Id: I40bf6cefae602ead70c3bcf2fe573cc13f34a385 Reviewed-on: https://go-review.googlesource.com/c/net/+/486156 Run-TryBot: Damien Neil <dneil@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Bryan Mills <bcmills@google.com>
This commit is contained in:
@@ -1266,6 +1266,27 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
cancelRequest := func(cs *clientStream, err error) error {
|
||||
cs.cc.mu.Lock()
|
||||
defer cs.cc.mu.Unlock()
|
||||
cs.abortStreamLocked(err)
|
||||
if cs.ID != 0 {
|
||||
// This request may have failed because of a problem with the connection,
|
||||
// or for some unrelated reason. (For example, the user might have canceled
|
||||
// the request without waiting for a response.) Mark the connection as
|
||||
// not reusable, since trying to reuse a dead connection is worse than
|
||||
// unnecessarily creating a new one.
|
||||
//
|
||||
// If cs.ID is 0, then the request was never allocated a stream ID and
|
||||
// whatever went wrong was unrelated to the connection. We might have
|
||||
// timed out waiting for a stream slot when StrictMaxConcurrentStreams
|
||||
// is set, for example, in which case retrying on a different connection
|
||||
// will not help.
|
||||
cs.cc.doNotReuse = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cs.respHeaderRecv:
|
||||
@@ -1280,15 +1301,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return handleResponseHeaders()
|
||||
default:
|
||||
waitDone()
|
||||
return nil, cs.abortErr
|
||||
return nil, cancelRequest(cs, cs.abortErr)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
err := ctx.Err()
|
||||
cs.abortStream(err)
|
||||
return nil, err
|
||||
return nil, cancelRequest(cs, ctx.Err())
|
||||
case <-cs.reqCancel:
|
||||
cs.abortStream(errRequestCanceled)
|
||||
return nil, errRequestCanceled
|
||||
return nil, cancelRequest(cs, errRequestCanceled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -775,7 +775,6 @@ func newClientTester(t *testing.T) *clientTester {
|
||||
cc, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
}
|
||||
sc, err := ln.Accept()
|
||||
if err != nil {
|
||||
@@ -1765,6 +1764,18 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
|
||||
// Make an arbitrary request to ensure we get the server's
|
||||
// settings frame and initialize peerMaxHeaderListSize.
|
||||
req0, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("newRequest: NewRequest: %v", err)
|
||||
}
|
||||
res0, err := tr.RoundTrip(req0)
|
||||
if err != nil {
|
||||
t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
|
||||
}
|
||||
res0.Body.Close()
|
||||
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != wantErr {
|
||||
if res != nil {
|
||||
@@ -1825,13 +1836,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
|
||||
return req
|
||||
}
|
||||
|
||||
// Make an arbitrary request to ensure we get the server's
|
||||
// settings frame and initialize peerMaxHeaderListSize.
|
||||
// Validate peerMaxHeaderListSize.
|
||||
req := newRequest()
|
||||
checkRoundTrip(req, nil, "Initial request")
|
||||
|
||||
// Get the ClientConn associated with the request and validate
|
||||
// peerMaxHeaderListSize.
|
||||
addr := authorityAddr(req.URL.Scheme, req.URL.Host)
|
||||
cc, err := tr.connPool().GetClientConn(req, addr)
|
||||
if err != nil {
|
||||
@@ -3738,35 +3745,33 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D
|
||||
ct.run()
|
||||
}
|
||||
|
||||
func TestTransportRetryAfterGOAWAY(t *testing.T) {
|
||||
var dialer struct {
|
||||
sync.Mutex
|
||||
count int
|
||||
}
|
||||
ct1 := make(chan *clientTester)
|
||||
ct2 := make(chan *clientTester)
|
||||
|
||||
func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
|
||||
ln := newLocalListener(t)
|
||||
defer ln.Close()
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
count int
|
||||
conns []net.Conn
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
tr := &Transport{
|
||||
TLSClientConfig: tlsConfigInsecure,
|
||||
}
|
||||
tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
dialer.Lock()
|
||||
defer dialer.Unlock()
|
||||
dialer.count++
|
||||
if dialer.count == 3 {
|
||||
return nil, errors.New("unexpected number of dials")
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
count++
|
||||
cc, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial error: %v", err)
|
||||
}
|
||||
conns = append(conns, cc)
|
||||
sc, err := ln.Accept()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("accept error: %v", err)
|
||||
}
|
||||
conns = append(conns, sc)
|
||||
ct := &clientTester{
|
||||
t: t,
|
||||
tr: tr,
|
||||
@@ -3774,19 +3779,26 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
|
||||
sc: sc,
|
||||
fr: NewFramer(sc, sc),
|
||||
}
|
||||
switch dialer.count {
|
||||
case 1:
|
||||
ct1 <- ct
|
||||
case 2:
|
||||
ct2 <- ct
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(count int) {
|
||||
defer wg.Done()
|
||||
server(count, ct)
|
||||
sc.Close()
|
||||
}(count)
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
errs := make(chan error, 3)
|
||||
client(tr)
|
||||
tr.CloseIdleConnections()
|
||||
ln.Close()
|
||||
for _, c := range conns {
|
||||
c.Close()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Client.
|
||||
go func() {
|
||||
func TestTransportRetryAfterGOAWAY(t *testing.T) {
|
||||
client := func(tr *Transport) {
|
||||
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if res != nil {
|
||||
@@ -3796,102 +3808,76 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("RoundTrip: %v", err)
|
||||
}
|
||||
errs <- err
|
||||
}()
|
||||
|
||||
connToClose := make(chan io.Closer, 2)
|
||||
|
||||
// Server for the first request.
|
||||
go func() {
|
||||
ct := <-ct1
|
||||
|
||||
connToClose <- ct.cc
|
||||
ct.greet()
|
||||
hf, err := ct.firstHeaders()
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
|
||||
return
|
||||
}
|
||||
t.Logf("server1 got %v", hf)
|
||||
if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
|
||||
errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
|
||||
return
|
||||
}
|
||||
errs <- nil
|
||||
}()
|
||||
|
||||
// Server for the second request.
|
||||
go func() {
|
||||
ct := <-ct2
|
||||
|
||||
connToClose <- ct.cc
|
||||
ct.greet()
|
||||
hf, err := ct.firstHeaders()
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
|
||||
return
|
||||
}
|
||||
t.Logf("server2 got %v", hf)
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc := hpack.NewEncoder(&buf)
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
|
||||
err = ct.fr.WriteHeaders(HeadersFrameParam{
|
||||
StreamID: hf.StreamID,
|
||||
EndHeaders: true,
|
||||
EndStream: false,
|
||||
BlockFragment: buf.Bytes(),
|
||||
})
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
|
||||
} else {
|
||||
errs <- nil
|
||||
}
|
||||
}()
|
||||
|
||||
for k := 0; k < 3; k++ {
|
||||
err := <-errs
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.Errorf("RoundTrip: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
close(connToClose)
|
||||
for c := range connToClose {
|
||||
c.Close()
|
||||
server := func(count int, ct *clientTester) {
|
||||
switch count {
|
||||
case 1:
|
||||
ct.greet()
|
||||
hf, err := ct.firstHeaders()
|
||||
if err != nil {
|
||||
t.Errorf("server1 failed reading HEADERS: %v", err)
|
||||
return
|
||||
}
|
||||
t.Logf("server1 got %v", hf)
|
||||
if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
|
||||
t.Errorf("server1 failed writing GOAWAY: %v", err)
|
||||
return
|
||||
}
|
||||
case 2:
|
||||
ct.greet()
|
||||
hf, err := ct.firstHeaders()
|
||||
if err != nil {
|
||||
t.Errorf("server2 failed reading HEADERS: %v", err)
|
||||
return
|
||||
}
|
||||
t.Logf("server2 got %v", hf)
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc := hpack.NewEncoder(&buf)
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
|
||||
err = ct.fr.WriteHeaders(HeadersFrameParam{
|
||||
StreamID: hf.StreamID,
|
||||
EndHeaders: true,
|
||||
EndStream: false,
|
||||
BlockFragment: buf.Bytes(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("server2 failed writing response HEADERS: %v", err)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected number of dials")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
testClientMultipleDials(t, client, server)
|
||||
}
|
||||
|
||||
func TestTransportRetryAfterRefusedStream(t *testing.T) {
|
||||
clientDone := make(chan struct{})
|
||||
ct := newClientTester(t)
|
||||
ct.client = func() error {
|
||||
defer ct.cc.(*net.TCPConn).CloseWrite()
|
||||
if runtime.GOOS == "plan9" {
|
||||
// CloseWrite not supported on Plan 9; Issue 17906
|
||||
defer ct.cc.(*net.TCPConn).Close()
|
||||
}
|
||||
client := func(tr *Transport) {
|
||||
defer close(clientDone)
|
||||
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
||||
resp, err := ct.tr.RoundTrip(req)
|
||||
resp, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("RoundTrip: %v", err)
|
||||
t.Errorf("RoundTrip: %v", err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != 204 {
|
||||
return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
|
||||
t.Errorf("Status = %v; want 204", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ct.server = func() error {
|
||||
|
||||
server := func(count int, ct *clientTester) {
|
||||
ct.greet()
|
||||
var buf bytes.Buffer
|
||||
enc := hpack.NewEncoder(&buf)
|
||||
nreq := 0
|
||||
|
||||
for {
|
||||
f, err := ct.fr.ReadFrame()
|
||||
if err != nil {
|
||||
@@ -3900,19 +3886,19 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) {
|
||||
// If the client's done, it
|
||||
// will have reported any
|
||||
// errors on its side.
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
t.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
switch f := f.(type) {
|
||||
case *WindowUpdateFrame, *SettingsFrame:
|
||||
case *HeadersFrame:
|
||||
if !f.HeadersEnded() {
|
||||
return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
|
||||
t.Errorf("headers should have END_HEADERS be ended: %v", f)
|
||||
return
|
||||
}
|
||||
nreq++
|
||||
if nreq == 1 {
|
||||
if count == 1 {
|
||||
ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
|
||||
} else {
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
|
||||
@@ -3924,11 +3910,13 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) {
|
||||
})
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unexpected client frame %v", f)
|
||||
t.Errorf("Unexpected client frame %v", f)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
ct.run()
|
||||
|
||||
testClientMultipleDials(t, client, server)
|
||||
}
|
||||
|
||||
func TestTransportRetryHasLimit(t *testing.T) {
|
||||
@@ -4143,6 +4131,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) {
|
||||
greet := make(chan struct{}) // server sends initial SETTINGS frame
|
||||
gotRequest := make(chan struct{}) // server received a request
|
||||
clientDone := make(chan struct{})
|
||||
cancelClientRequest := make(chan struct{})
|
||||
|
||||
// Collect errors from goroutines.
|
||||
var wg sync.WaitGroup
|
||||
@@ -4221,9 +4210,8 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
|
||||
if k == maxConcurrent {
|
||||
// This request will be canceled.
|
||||
cancel := make(chan struct{})
|
||||
req.Cancel = cancel
|
||||
close(cancel)
|
||||
req.Cancel = cancelClientRequest
|
||||
close(cancelClientRequest)
|
||||
_, err := ct.tr.RoundTrip(req)
|
||||
close(clientRequestCancelled)
|
||||
if err == nil {
|
||||
@@ -5986,14 +5974,21 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientConnReservations(t *testing.T) {
|
||||
cc := &ClientConn{
|
||||
reqHeaderMu: make(chan struct{}, 1),
|
||||
streams: make(map[uint32]*clientStream),
|
||||
maxConcurrentStreams: initialMaxConcurrentStreams,
|
||||
nextStreamID: 1,
|
||||
t: &Transport{},
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
}, func(s *Server) {
|
||||
s.MaxConcurrentStreams = initialMaxConcurrentStreams
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
cc, err := tr.newClientConn(st.cc, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cc.cond = sync.NewCond(&cc.mu)
|
||||
|
||||
req, _ := http.NewRequest("GET", st.ts.URL, nil)
|
||||
n := 0
|
||||
for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
|
||||
n++
|
||||
@@ -6001,8 +5996,8 @@ func TestClientConnReservations(t *testing.T) {
|
||||
if n != initialMaxConcurrentStreams {
|
||||
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
|
||||
}
|
||||
if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
|
||||
t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
|
||||
if _, err := cc.RoundTrip(req); err != nil {
|
||||
t.Fatalf("RoundTrip error = %v", err)
|
||||
}
|
||||
n2 := 0
|
||||
for n2 <= 5 && cc.ReserveNewRequest() {
|
||||
@@ -6014,7 +6009,7 @@ func TestClientConnReservations(t *testing.T) {
|
||||
|
||||
// Use up all the reservations
|
||||
for i := 0; i < n; i++ {
|
||||
cc.RoundTrip(new(http.Request))
|
||||
cc.RoundTrip(req)
|
||||
}
|
||||
|
||||
n2 = 0
|
||||
@@ -6370,3 +6365,95 @@ func TestTransportSlowClose(t *testing.T) {
|
||||
}
|
||||
res.Body.Close()
|
||||
}
|
||||
|
||||
type blockReadConn struct {
|
||||
net.Conn
|
||||
blockc chan struct{}
|
||||
}
|
||||
|
||||
func (c *blockReadConn) Read(b []byte) (n int, err error) {
|
||||
<-c.blockc
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func TestTransportReuseAfterError(t *testing.T) {
|
||||
serverReqc := make(chan struct{}, 3)
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
serverReqc <- struct{}{}
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
|
||||
var (
|
||||
unblockOnce sync.Once
|
||||
blockc = make(chan struct{})
|
||||
connCountMu sync.Mutex
|
||||
connCount int
|
||||
)
|
||||
tr := &Transport{
|
||||
TLSClientConfig: tlsConfigInsecure,
|
||||
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
// The first connection dialed will block on reads until blockc is closed.
|
||||
connCountMu.Lock()
|
||||
defer connCountMu.Unlock()
|
||||
connCount++
|
||||
conn, err := tls.Dial(network, addr, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if connCount == 1 {
|
||||
return &blockReadConn{
|
||||
Conn: conn,
|
||||
blockc: blockc,
|
||||
}, nil
|
||||
}
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
defer tr.CloseIdleConnections()
|
||||
defer unblockOnce.Do(func() {
|
||||
// Ensure that reads on blockc are unblocked if we return early.
|
||||
close(blockc)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", st.ts.URL, nil)
|
||||
|
||||
// Request 1 is made on conn 1.
|
||||
// Reading the response will block.
|
||||
// Wait until the server receives the request, and continue.
|
||||
req1c := make(chan struct{})
|
||||
go func() {
|
||||
defer close(req1c)
|
||||
res1, err := tr.RoundTrip(req.Clone(context.Background()))
|
||||
if err != nil {
|
||||
t.Errorf("request 1: %v", err)
|
||||
} else {
|
||||
res1.Body.Close()
|
||||
}
|
||||
}()
|
||||
<-serverReqc
|
||||
|
||||
// Request 2 is also made on conn 1.
|
||||
// Reading the response will block.
|
||||
// The request fails when the context deadline expires.
|
||||
// Conn 1 should now be flagged as unfit for reuse.
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err := tr.RoundTrip(req.Clone(timeoutCtx))
|
||||
if err == nil {
|
||||
t.Errorf("request 2 unexpectedly succeeded (want timeout)")
|
||||
}
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
|
||||
// Request 3 is made on a new conn, and succeeds.
|
||||
res3, err := tr.RoundTrip(req.Clone(context.Background()))
|
||||
if err != nil {
|
||||
t.Fatalf("request 3: %v", err)
|
||||
}
|
||||
res3.Body.Close()
|
||||
|
||||
// Unblock conn 1, and verify that request 1 completes.
|
||||
unblockOnce.Do(func() {
|
||||
close(blockc)
|
||||
})
|
||||
<-req1c
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user