http2: don't reuse connections that are experiencing errors

When a request on a connection fails to complete successfully,
mark the conn as doNotReuse. It's possible for requests to
fail for reasons unrelated to connection health,
but opening a new connection unnecessarily is less of an
impact than reusing a dead connection.

Fixes golang/go#59690

Change-Id: I40bf6cefae602ead70c3bcf2fe573cc13f34a385
Reviewed-on: https://go-review.googlesource.com/c/net/+/486156
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
This commit is contained in:
Damien Neil
2023-04-18 11:18:57 -07:00
parent 0bfab66a03
commit 82780d606d
2 changed files with 237 additions and 132 deletions

View File

@@ -1266,6 +1266,27 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return res, nil
}
cancelRequest := func(cs *clientStream, err error) error {
cs.cc.mu.Lock()
defer cs.cc.mu.Unlock()
cs.abortStreamLocked(err)
if cs.ID != 0 {
// This request may have failed because of a problem with the connection,
// or for some unrelated reason. (For example, the user might have canceled
// the request without waiting for a response.) Mark the connection as
// not reusable, since trying to reuse a dead connection is worse than
// unnecessarily creating a new one.
//
// If cs.ID is 0, then the request was never allocated a stream ID and
// whatever went wrong was unrelated to the connection. We might have
// timed out waiting for a stream slot when StrictMaxConcurrentStreams
// is set, for example, in which case retrying on a different connection
// will not help.
cs.cc.doNotReuse = true
}
return err
}
for {
select {
case <-cs.respHeaderRecv:
@@ -1280,15 +1301,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return handleResponseHeaders()
default:
waitDone()
return nil, cs.abortErr
return nil, cancelRequest(cs, cs.abortErr)
}
case <-ctx.Done():
err := ctx.Err()
cs.abortStream(err)
return nil, err
return nil, cancelRequest(cs, ctx.Err())
case <-cs.reqCancel:
cs.abortStream(errRequestCanceled)
return nil, errRequestCanceled
return nil, cancelRequest(cs, errRequestCanceled)
}
}
}

View File

@@ -775,7 +775,6 @@ func newClientTester(t *testing.T) *clientTester {
cc, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
sc, err := ln.Accept()
if err != nil {
@@ -1765,6 +1764,18 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
defer tr.CloseIdleConnections()
checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
// Make an arbitrary request to ensure we get the server's
// settings frame and initialize peerMaxHeaderListSize.
req0, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatalf("newRequest: NewRequest: %v", err)
}
res0, err := tr.RoundTrip(req0)
if err != nil {
t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
}
res0.Body.Close()
res, err := tr.RoundTrip(req)
if err != wantErr {
if res != nil {
@@ -1825,13 +1836,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
return req
}
// Make an arbitrary request to ensure we get the server's
// settings frame and initialize peerMaxHeaderListSize.
// Validate peerMaxHeaderListSize.
req := newRequest()
checkRoundTrip(req, nil, "Initial request")
// Get the ClientConn associated with the request and validate
// peerMaxHeaderListSize.
addr := authorityAddr(req.URL.Scheme, req.URL.Host)
cc, err := tr.connPool().GetClientConn(req, addr)
if err != nil {
@@ -3738,35 +3745,33 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D
ct.run()
}
func TestTransportRetryAfterGOAWAY(t *testing.T) {
var dialer struct {
sync.Mutex
count int
}
ct1 := make(chan *clientTester)
ct2 := make(chan *clientTester)
func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
ln := newLocalListener(t)
defer ln.Close()
var (
mu sync.Mutex
count int
conns []net.Conn
)
var wg sync.WaitGroup
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
}
tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
dialer.Lock()
defer dialer.Unlock()
dialer.count++
if dialer.count == 3 {
return nil, errors.New("unexpected number of dials")
}
mu.Lock()
defer mu.Unlock()
count++
cc, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
return nil, fmt.Errorf("dial error: %v", err)
}
conns = append(conns, cc)
sc, err := ln.Accept()
if err != nil {
return nil, fmt.Errorf("accept error: %v", err)
}
conns = append(conns, sc)
ct := &clientTester{
t: t,
tr: tr,
@@ -3774,19 +3779,26 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
sc: sc,
fr: NewFramer(sc, sc),
}
switch dialer.count {
case 1:
ct1 <- ct
case 2:
ct2 <- ct
}
wg.Add(1)
go func(count int) {
defer wg.Done()
server(count, ct)
sc.Close()
}(count)
return cc, nil
}
errs := make(chan error, 3)
client(tr)
tr.CloseIdleConnections()
ln.Close()
for _, c := range conns {
c.Close()
}
wg.Wait()
}
// Client.
go func() {
func TestTransportRetryAfterGOAWAY(t *testing.T) {
client := func(tr *Transport) {
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
res, err := tr.RoundTrip(req)
if res != nil {
@@ -3796,102 +3808,76 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
}
}
if err != nil {
err = fmt.Errorf("RoundTrip: %v", err)
}
errs <- err
}()
connToClose := make(chan io.Closer, 2)
// Server for the first request.
go func() {
ct := <-ct1
connToClose <- ct.cc
ct.greet()
hf, err := ct.firstHeaders()
if err != nil {
errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
return
}
t.Logf("server1 got %v", hf)
if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
return
}
errs <- nil
}()
// Server for the second request.
go func() {
ct := <-ct2
connToClose <- ct.cc
ct.greet()
hf, err := ct.firstHeaders()
if err != nil {
errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
return
}
t.Logf("server2 got %v", hf)
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
err = ct.fr.WriteHeaders(HeadersFrameParam{
StreamID: hf.StreamID,
EndHeaders: true,
EndStream: false,
BlockFragment: buf.Bytes(),
})
if err != nil {
errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
} else {
errs <- nil
}
}()
for k := 0; k < 3; k++ {
err := <-errs
if err != nil {
t.Error(err)
t.Errorf("RoundTrip: %v", err)
}
}
close(connToClose)
for c := range connToClose {
c.Close()
server := func(count int, ct *clientTester) {
switch count {
case 1:
ct.greet()
hf, err := ct.firstHeaders()
if err != nil {
t.Errorf("server1 failed reading HEADERS: %v", err)
return
}
t.Logf("server1 got %v", hf)
if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
t.Errorf("server1 failed writing GOAWAY: %v", err)
return
}
case 2:
ct.greet()
hf, err := ct.firstHeaders()
if err != nil {
t.Errorf("server2 failed reading HEADERS: %v", err)
return
}
t.Logf("server2 got %v", hf)
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
err = ct.fr.WriteHeaders(HeadersFrameParam{
StreamID: hf.StreamID,
EndHeaders: true,
EndStream: false,
BlockFragment: buf.Bytes(),
})
if err != nil {
t.Errorf("server2 failed writing response HEADERS: %v", err)
}
default:
t.Errorf("unexpected number of dials")
return
}
}
testClientMultipleDials(t, client, server)
}
func TestTransportRetryAfterRefusedStream(t *testing.T) {
clientDone := make(chan struct{})
ct := newClientTester(t)
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
if runtime.GOOS == "plan9" {
// CloseWrite not supported on Plan 9; Issue 17906
defer ct.cc.(*net.TCPConn).Close()
}
client := func(tr *Transport) {
defer close(clientDone)
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
resp, err := ct.tr.RoundTrip(req)
resp, err := tr.RoundTrip(req)
if err != nil {
return fmt.Errorf("RoundTrip: %v", err)
t.Errorf("RoundTrip: %v", err)
return
}
resp.Body.Close()
if resp.StatusCode != 204 {
return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
t.Errorf("Status = %v; want 204", resp.StatusCode)
return
}
return nil
}
ct.server = func() error {
server := func(count int, ct *clientTester) {
ct.greet()
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
nreq := 0
for {
f, err := ct.fr.ReadFrame()
if err != nil {
@@ -3900,19 +3886,19 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) {
// If the client's done, it
// will have reported any
// errors on its side.
return nil
default:
return err
t.Error(err)
}
return
}
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
if !f.HeadersEnded() {
return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
t.Errorf("headers should have END_HEADERS be ended: %v", f)
return
}
nreq++
if nreq == 1 {
if count == 1 {
ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
} else {
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
@@ -3924,11 +3910,13 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) {
})
}
default:
return fmt.Errorf("Unexpected client frame %v", f)
t.Errorf("Unexpected client frame %v", f)
return
}
}
}
ct.run()
testClientMultipleDials(t, client, server)
}
func TestTransportRetryHasLimit(t *testing.T) {
@@ -4143,6 +4131,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) {
greet := make(chan struct{}) // server sends initial SETTINGS frame
gotRequest := make(chan struct{}) // server received a request
clientDone := make(chan struct{})
cancelClientRequest := make(chan struct{})
// Collect errors from goroutines.
var wg sync.WaitGroup
@@ -4221,9 +4210,8 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) {
req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
if k == maxConcurrent {
// This request will be canceled.
cancel := make(chan struct{})
req.Cancel = cancel
close(cancel)
req.Cancel = cancelClientRequest
close(cancelClientRequest)
_, err := ct.tr.RoundTrip(req)
close(clientRequestCancelled)
if err == nil {
@@ -5986,14 +5974,21 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
}
func TestClientConnReservations(t *testing.T) {
cc := &ClientConn{
reqHeaderMu: make(chan struct{}, 1),
streams: make(map[uint32]*clientStream),
maxConcurrentStreams: initialMaxConcurrentStreams,
nextStreamID: 1,
t: &Transport{},
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
}, func(s *Server) {
s.MaxConcurrentStreams = initialMaxConcurrentStreams
})
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.newClientConn(st.cc, false)
if err != nil {
t.Fatal(err)
}
cc.cond = sync.NewCond(&cc.mu)
req, _ := http.NewRequest("GET", st.ts.URL, nil)
n := 0
for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
n++
@@ -6001,8 +5996,8 @@ func TestClientConnReservations(t *testing.T) {
if n != initialMaxConcurrentStreams {
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
}
if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
if _, err := cc.RoundTrip(req); err != nil {
t.Fatalf("RoundTrip error = %v", err)
}
n2 := 0
for n2 <= 5 && cc.ReserveNewRequest() {
@@ -6014,7 +6009,7 @@ func TestClientConnReservations(t *testing.T) {
// Use up all the reservations
for i := 0; i < n; i++ {
cc.RoundTrip(new(http.Request))
cc.RoundTrip(req)
}
n2 = 0
@@ -6370,3 +6365,95 @@ func TestTransportSlowClose(t *testing.T) {
}
res.Body.Close()
}
// blockReadConn is a net.Conn wrapper whose Read calls are held back until a
// receive on blockc succeeds. All other net.Conn methods are promoted
// unchanged from the embedded Conn.
type blockReadConn struct {
	net.Conn

	// blockc gates Read: every call first receives from this channel, so
	// closing it releases all pending and future reads.
	blockc chan struct{}
}

// Read waits for a receive on blockc to complete and then reads from the
// embedded Conn into b, returning that read's byte count and error.
func (bc *blockReadConn) Read(b []byte) (n int, err error) {
	<-bc.blockc
	n, err = bc.Conn.Read(b)
	return n, err
}
// TestTransportReuseAfterError checks that once a request on a connection
// fails (here: its context deadline expires while the connection's reads are
// blocked), the Transport sends the next request on a freshly dialed
// connection rather than on the stalled one.
//
// The test issues three requests through the same Transport:
//   - request 1 runs in a goroutine on conn 1 and cannot read its response
//     until blockc is closed;
//   - request 2 uses a 1ms deadline and is expected to fail;
//   - request 3 is expected to succeed, which requires a second connection
//     because reads on conn 1 are still blocked at that point.
func TestTransportReuseAfterError(t *testing.T) {
// The handler sends one value on serverReqc each time it runs. The buffer
// size of 3 matches the number of requests this test issues.
serverReqc := make(chan struct{}, 3)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
serverReqc <- struct{}{}
}, optOnlyServer)
defer st.Close()
// unblockOnce makes sure blockc is closed exactly once, whether by the
// deferred call below or by the explicit call at the end of the test.
// connCountMu guards connCount, which counts DialTLS invocations.
var (
unblockOnce sync.Once
blockc = make(chan struct{})
connCountMu sync.Mutex
connCount int
)
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
// The first connection dialed will block on reads until blockc is closed.
connCountMu.Lock()
defer connCountMu.Unlock()
connCount++
conn, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
// Only the first dialed connection is wrapped in blockReadConn; any
// later connection is returned as-is.
if connCount == 1 {
return &blockReadConn{
Conn: conn,
blockc: blockc,
}, nil
}
return conn, nil
},
}
defer tr.CloseIdleConnections()
defer unblockOnce.Do(func() {
// Ensure that reads on blockc are unblocked if we return early.
close(blockc)
})
// The NewRequest error is discarded. Each RoundTrip below gets its own
// clone of req, carrying its own context.
req, _ := http.NewRequest("GET", st.ts.URL, nil)
// Request 1 is made on conn 1.
// Reading the response will block.
// Wait until the server receives the request, and continue.
// req1c is closed when the request 1 goroutine finishes.
req1c := make(chan struct{})
go func() {
defer close(req1c)
res1, err := tr.RoundTrip(req.Clone(context.Background()))
if err != nil {
t.Errorf("request 1: %v", err)
} else {
res1.Body.Close()
}
}()
<-serverReqc
// Request 2 is also made on conn 1.
// Reading the response will block.
// The request fails when the context deadline expires.
// Conn 1 should now be flagged as unfit for reuse.
timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
_, err := tr.RoundTrip(req.Clone(timeoutCtx))
if err == nil {
t.Errorf("request 2 unexpectedly succeeded (want timeout)")
}
// NOTE(review): the purpose of this short pause is not stated in the code;
// presumably it gives the transport a moment to finish handling the failed
// request before request 3 starts. Confirm before changing or removing it.
time.Sleep(1 * time.Millisecond)
// Request 3 is made on a new conn, and succeeds.
res3, err := tr.RoundTrip(req.Clone(context.Background()))
if err != nil {
t.Fatalf("request 3: %v", err)
}
res3.Body.Close()
// Unblock conn 1, and verify that request 1 completes.
unblockOnce.Do(func() {
close(blockc)
})
// Wait for the request 1 goroutine to finish.
<-req1c
}