diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go index 3a67636f..652bc11a 100644 --- a/http2/client_conn_pool.go +++ b/http2/client_conn_pool.go @@ -7,7 +7,9 @@ package http2 import ( + "context" "crypto/tls" + "errors" "net/http" "sync" ) @@ -78,61 +80,69 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis // It gets its own connection. traceGetConn(req, addr) const singleUse = true - cc, err := p.t.dialClientConn(addr, singleUse) + cc, err := p.t.dialClientConn(req.Context(), addr, singleUse) if err != nil { return nil, err } return cc, nil } - p.mu.Lock() - for _, cc := range p.conns[addr] { - if st := cc.idleState(); st.canTakeNewRequest { - if p.shouldTraceGetConn(st) { - traceGetConn(req, addr) + for { + p.mu.Lock() + for _, cc := range p.conns[addr] { + if st := cc.idleState(); st.canTakeNewRequest { + if p.shouldTraceGetConn(st) { + traceGetConn(req, addr) + } + p.mu.Unlock() + return cc, nil } - p.mu.Unlock() - return cc, nil } - } - if !dialOnMiss { + if !dialOnMiss { + p.mu.Unlock() + return nil, ErrNoCachedConn + } + traceGetConn(req, addr) + call := p.getStartDialLocked(req.Context(), addr) p.mu.Unlock() - return nil, ErrNoCachedConn + <-call.done + if shouldRetryDial(call, req) { + continue + } + return call.res, call.err } - traceGetConn(req, addr) - call := p.getStartDialLocked(addr) - p.mu.Unlock() - <-call.done - return call.res, call.err } // dialCall is an in-flight Transport dial call to a host. type dialCall struct { - _ incomparable - p *clientConnPool + _ incomparable + p *clientConnPool + // the context associated with the request + // that created this dialCall + ctx context.Context done chan struct{} // closed when done res *ClientConn // valid after done is closed err error // valid after done is closed } // requires p.mu is held. -func (p *clientConnPool) getStartDialLocked(addr string) *dialCall { +func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall { if call, ok := p.dialing[addr]; ok { // A dial is already in-flight. Don't start another. return call } - call := &dialCall{p: p, done: make(chan struct{})} + call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx} if p.dialing == nil { p.dialing = make(map[string]*dialCall) } p.dialing[addr] = call - go call.dial(addr) + go call.dial(call.ctx, addr) return call } // run in its own goroutine. -func (c *dialCall) dial(addr string) { +func (c *dialCall) dial(ctx context.Context, addr string) { const singleUse = false // shared conn - c.res, c.err = c.p.t.dialClientConn(addr, singleUse) + c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse) close(c.done) c.p.mu.Lock() @@ -276,3 +286,28 @@ type noDialClientConnPool struct{ *clientConnPool } func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { return p.getClientConn(req, addr, noDialOnMiss) } + +// shouldRetryDial reports whether the current request should +// retry dialing after the call finished unsuccessfully, for example +// if the dial was canceled because of a context cancellation or +// deadline expiry. +func shouldRetryDial(call *dialCall, req *http.Request) bool { + if call.err == nil { + // No error, no need to retry + return false + } + if call.ctx == req.Context() { + // If the call has the same context as the request, the dial + // should not be retried, since any cancellation will have come + // from this request. + return false + } + if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) { + // If the call error is not because of a context cancellation or a deadline expiry, + // the dial should not be retried. + return false + } + // Only retry if the error is a context cancellation error or deadline expiry + // and the context associated with the call was canceled or expired. + return call.ctx.Err() != nil +} diff --git a/http2/transport.go b/http2/transport.go index 7688d72c..5ae89cfc 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -564,12 +564,12 @@ func canRetryError(err error) bool { return false } -func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) { +func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } - tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host)) + tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host)) if err != nil { return nil, err } @@ -590,34 +590,28 @@ func (t *Transport) newTLSConfig(host string) *tls.Config { return cfg } -func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) { +func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { if t.DialTLS != nil { return t.DialTLS } - return t.dialTLSDefault -} - -func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) { - cn, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - if err := cn.Handshake(); err != nil { - return nil, err - } - if !cfg.InsecureSkipVerify { - if err := cn.VerifyHostname(cfg.ServerName); err != nil { + return func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialer := &tls.Dialer{ + Config: cfg, + } + cn, err := dialer.DialContext(ctx, network, addr) + if err != nil { return nil, err } + tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed + state := tlsCn.ConnectionState() + if p := state.NegotiatedProtocol; p != NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS) + } + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("http2: could not negotiate protocol mutually") + } + return cn, nil } - state := cn.ConnectionState() - if p := state.NegotiatedProtocol; p != NextProtoTLS { - return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS) - } - if !state.NegotiatedProtocolIsMutual { - return nil, errors.New("http2: could not negotiate protocol mutually") - } - return cn, nil } // disableKeepAlives reports whether connections should be closed as diff --git a/http2/transport_go117_test.go b/http2/transport_go117_test.go new file mode 100644 index 00000000..f5d4e0c1 --- /dev/null +++ b/http2/transport_go117_test.go @@ -0,0 +1,169 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.17 +// +build go1.17 + +package http2 + +import ( + "context" + "crypto/tls" + "errors" + "net/http" + "net/http/httptest" + + "testing" +) + +func TestTransportDialTLSContext(t *testing.T) { + blockCh := make(chan struct{}) + serverTLSConfigFunc := func(ts *httptest.Server) { + ts.Config.TLSConfig = &tls.Config{ + // Triggers the server to request the clients certificate + // during TLS handshake. + ClientAuth: tls.RequestClientCert, + } + } + ts := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + serverTLSConfigFunc, + ) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Tests that the context provided to `req` is + // passed into this function. + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() + }, + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + req = req.WithContext(ctx) + errCh := make(chan error) + go func() { + defer close(errCh) + res, err := tr.RoundTrip(req) + if err != nil { + errCh <- err + return + } + res.Body.Close() + }() + // Wait for GetClientCertificate handler to be called + <-blockCh + // Cancel the context + cancel() + // Expect the cancellation error here + err = <-errCh + if err == nil { + t.Fatal("cancelling context during client certificate fetch did not error as expected") + return + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected error returned after cancellation: %v", err) + } +} + +// TestDialRaceResumesDial tests that, given two concurrent requests +// to the same address, when the first Dial is interrupted because +// the first request's context is cancelled, the second request +// resumes the dial automatically. +func TestDialRaceResumesDial(t *testing.T) { + blockCh := make(chan struct{}) + serverTLSConfigFunc := func(ts *httptest.Server) { + ts.Config.TLSConfig = &tls.Config{ + // Triggers the server to request the clients certificate + // during TLS handshake. + ClientAuth: tls.RequestClientCert, + } + } + ts := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) {}, + optOnlyServer, + serverTLSConfigFunc, + ) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + select { + case <-blockCh: + // If we already errored, return without error. + return &tls.Certificate{}, nil + default: + } + close(blockCh) + <-cri.Context().Done() + return nil, cri.Context().Err() + }, + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + // Create two requests with independent cancellation. + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + req1 := req.WithContext(ctx1) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + req2 := req.WithContext(ctx2) + errCh := make(chan error) + go func() { + res, err := tr.RoundTrip(req1) + if err != nil { + errCh <- err + return + } + res.Body.Close() + }() + successCh := make(chan struct{}) + go func() { + // Don't start request until first request + // has initiated the handshake. + <-blockCh + res, err := tr.RoundTrip(req2) + if err != nil { + errCh <- err + return + } + res.Body.Close() + // Close successCh to indicate that the second request + // made it to the server successfully. + close(successCh) + }() + // Wait for GetClientCertificate handler to be called + <-blockCh + // Cancel the context first + cancel1() + // Expect the cancellation error here + err = <-errCh + if err == nil { + t.Fatal("cancelling context during client certificate fetch did not error as expected") + return + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected error returned after cancellation: %v", err) + } + select { + case err := <-errCh: + t.Fatalf("unexpected second error: %v", err) + case <-successCh: + } +} diff --git a/http2/transport_test.go b/http2/transport_test.go index c9c948c2..7b139285 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3276,7 +3276,8 @@ func TestClientConnPing(t *testing.T) { defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) } @@ -4278,7 +4279,8 @@ func testClientConnClose(t *testing.T, closeMode closeMode) { defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) @@ -4788,7 +4790,8 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + ctx := context.Background() + cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) }