http2: use (*tls.Dialer).DialContext in dialTLS

This lets us propagate the request context into the TLS
handshake.

Related to CL 295370
Updates golang/go#32406

Change-Id: Ie10c301be19b57b4b3e46ac31bbe87679e1eebc7
Reviewed-on: https://go-review.googlesource.com/c/net/+/295173
Trust: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Johan Brandhorst
2021-02-23 10:49:51 +00:00
committed by Brad Fitzpatrick
parent 7fd8e65b64
commit bbd867fde5
4 changed files with 251 additions and 50 deletions

View File

@@ -7,7 +7,9 @@
package http2
import (
"context"
"crypto/tls"
"errors"
"net/http"
"sync"
)
@@ -78,61 +80,69 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
// It gets its own connection.
traceGetConn(req, addr)
const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse)
cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
if err != nil {
return nil, err
}
return cc, nil
}
p.mu.Lock()
for _, cc := range p.conns[addr] {
if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
traceGetConn(req, addr)
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
traceGetConn(req, addr)
}
p.mu.Unlock()
return cc, nil
}
p.mu.Unlock()
return cc, nil
}
}
if !dialOnMiss {
if !dialOnMiss {
p.mu.Unlock()
return nil, ErrNoCachedConn
}
traceGetConn(req, addr)
call := p.getStartDialLocked(req.Context(), addr)
p.mu.Unlock()
return nil, ErrNoCachedConn
<-call.done
if shouldRetryDial(call, req) {
continue
}
return call.res, call.err
}
traceGetConn(req, addr)
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
return call.res, call.err
}
// dialCall is an in-flight Transport dial call to a host.
type dialCall struct {
_ incomparable
p *clientConnPool
_ incomparable
p *clientConnPool
// the context associated with the request
// that created this dialCall
ctx context.Context
done chan struct{} // closed when done
res *ClientConn // valid after done is closed
err error // valid after done is closed
}
// requires p.mu is held.
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another.
return call
}
call := &dialCall{p: p, done: make(chan struct{})}
call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
if p.dialing == nil {
p.dialing = make(map[string]*dialCall)
}
p.dialing[addr] = call
go call.dial(addr)
go call.dial(call.ctx, addr)
return call
}
// run in its own goroutine.
func (c *dialCall) dial(addr string) {
func (c *dialCall) dial(ctx context.Context, addr string) {
const singleUse = false // shared conn
c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
close(c.done)
c.p.mu.Lock()
@@ -276,3 +286,28 @@ type noDialClientConnPool struct{ *clientConnPool }
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
return p.getClientConn(req, addr, noDialOnMiss)
}
// shouldRetryDial reports whether the current request should
// retry dialing after the call finished unsuccessfully, for example
// if the dial was canceled because of a context cancellation or
// deadline expiry.
func shouldRetryDial(call *dialCall, req *http.Request) bool {
if call.err == nil {
// No error, no need to retry
return false
}
if call.ctx == req.Context() {
// If the call has the same context as the request, the dial
// should not be retried, since any cancellation will have come
// from this request.
return false
}
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
// If the call error is not because of a context cancellation or a deadline expiry,
// the dial should not be retried.
return false
}
// Only retry if the error is a context cancellation error or deadline expiry
// and the context associated with the call was canceled or expired.
return call.ctx.Err() != nil
}

View File

@@ -564,12 +564,12 @@ func canRetryError(err error) bool {
return false
}
func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
if err != nil {
return nil, err
}
@@ -590,34 +590,28 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
return cfg
}
func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
return t.dialTLSDefault
}
func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
cn, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
if err := cn.Handshake(); err != nil {
return nil, err
}
if !cfg.InsecureSkipVerify {
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
state := tlsCn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("http2: could not negotiate protocol mutually")
}
return cn, nil
}
state := cn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("http2: could not negotiate protocol mutually")
}
return cn, nil
}
// disableKeepAlives reports whether connections should be closed as

View File

@@ -0,0 +1,169 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.17
// +build go1.17
package http2
import (
"context"
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func TestTransportDialTLSContext(t *testing.T) {
blockCh := make(chan struct{})
serverTLSConfigFunc := func(ts *httptest.Server) {
ts.Config.TLSConfig = &tls.Config{
// Triggers the server to request the clients certificate
// during TLS handshake.
ClientAuth: tls.RequestClientCert,
}
}
ts := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {},
optOnlyServer,
serverTLSConfigFunc,
)
defer ts.Close()
tr := &Transport{
TLSClientConfig: &tls.Config{
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Tests that the context provided to `req` is
// passed into this function.
close(blockCh)
<-cri.Context().Done()
return nil, cri.Context().Err()
},
InsecureSkipVerify: true,
},
}
defer tr.CloseIdleConnections()
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req = req.WithContext(ctx)
errCh := make(chan error)
go func() {
defer close(errCh)
res, err := tr.RoundTrip(req)
if err != nil {
errCh <- err
return
}
res.Body.Close()
}()
// Wait for GetClientCertificate handler to be called
<-blockCh
// Cancel the context
cancel()
// Expect the cancellation error here
err = <-errCh
if err == nil {
t.Fatal("cancelling context during client certificate fetch did not error as expected")
return
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected error returned after cancellation: %v", err)
}
}
// TestDialRaceResumesDial tests that, given two concurrent requests
// to the same address, when the first Dial is interrupted because
// the first request's context is cancelled, the second request
// resumes the dial automatically.
func TestDialRaceResumesDial(t *testing.T) {
blockCh := make(chan struct{})
serverTLSConfigFunc := func(ts *httptest.Server) {
ts.Config.TLSConfig = &tls.Config{
// Triggers the server to request the clients certificate
// during TLS handshake.
ClientAuth: tls.RequestClientCert,
}
}
ts := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {},
optOnlyServer,
serverTLSConfigFunc,
)
defer ts.Close()
tr := &Transport{
TLSClientConfig: &tls.Config{
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
select {
case <-blockCh:
// If we already errored, return without error.
return &tls.Certificate{}, nil
default:
}
close(blockCh)
<-cri.Context().Done()
return nil, cri.Context().Err()
},
InsecureSkipVerify: true,
},
}
defer tr.CloseIdleConnections()
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
// Create two requests with independent cancellation.
ctx1, cancel1 := context.WithCancel(context.Background())
defer cancel1()
req1 := req.WithContext(ctx1)
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
req2 := req.WithContext(ctx2)
errCh := make(chan error)
go func() {
res, err := tr.RoundTrip(req1)
if err != nil {
errCh <- err
return
}
res.Body.Close()
}()
successCh := make(chan struct{})
go func() {
// Don't start request until first request
// has initiated the handshake.
<-blockCh
res, err := tr.RoundTrip(req2)
if err != nil {
errCh <- err
return
}
res.Body.Close()
// Close successCh to indicate that the second request
// made it to the server successfully.
close(successCh)
}()
// Wait for GetClientCertificate handler to be called
<-blockCh
// Cancel the context first
cancel1()
// Expect the cancellation error here
err = <-errCh
if err == nil {
t.Fatal("cancelling context during client certificate fetch did not error as expected")
return
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected error returned after cancellation: %v", err)
}
select {
case err := <-errCh:
t.Fatalf("unexpected second error: %v", err)
case <-successCh:
}
}

View File

@@ -3276,7 +3276,8 @@ func TestClientConnPing(t *testing.T) {
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
@@ -4278,7 +4279,8 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
@@ -4788,7 +4790,8 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}