mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
http2: use (*tls.Dialer).DialContext in dialTLS
This lets us propagate the request context into the TLS handshake. Related to CL 295370 Updates golang/go#32406 Change-Id: Ie10c301be19b57b4b3e46ac31bbe87679e1eebc7 Reviewed-on: https://go-review.googlesource.com/c/net/+/295173 Trust: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com> Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com> TryBot-Result: Go Bot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
committed by
Brad Fitzpatrick
parent
7fd8e65b64
commit
bbd867fde5
@@ -7,7 +7,9 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
@@ -78,61 +80,69 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
|
||||
// It gets its own connection.
|
||||
traceGetConn(req, addr)
|
||||
const singleUse = true
|
||||
cc, err := p.t.dialClientConn(addr, singleUse)
|
||||
cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
for _, cc := range p.conns[addr] {
|
||||
if st := cc.idleState(); st.canTakeNewRequest {
|
||||
if p.shouldTraceGetConn(st) {
|
||||
traceGetConn(req, addr)
|
||||
for {
|
||||
p.mu.Lock()
|
||||
for _, cc := range p.conns[addr] {
|
||||
if st := cc.idleState(); st.canTakeNewRequest {
|
||||
if p.shouldTraceGetConn(st) {
|
||||
traceGetConn(req, addr)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
return cc, nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
if !dialOnMiss {
|
||||
if !dialOnMiss {
|
||||
p.mu.Unlock()
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
traceGetConn(req, addr)
|
||||
call := p.getStartDialLocked(req.Context(), addr)
|
||||
p.mu.Unlock()
|
||||
return nil, ErrNoCachedConn
|
||||
<-call.done
|
||||
if shouldRetryDial(call, req) {
|
||||
continue
|
||||
}
|
||||
return call.res, call.err
|
||||
}
|
||||
traceGetConn(req, addr)
|
||||
call := p.getStartDialLocked(addr)
|
||||
p.mu.Unlock()
|
||||
<-call.done
|
||||
return call.res, call.err
|
||||
}
|
||||
|
||||
// dialCall is an in-flight Transport dial call to a host.
|
||||
type dialCall struct {
|
||||
_ incomparable
|
||||
p *clientConnPool
|
||||
_ incomparable
|
||||
p *clientConnPool
|
||||
// the context associated with the request
|
||||
// that created this dialCall
|
||||
ctx context.Context
|
||||
done chan struct{} // closed when done
|
||||
res *ClientConn // valid after done is closed
|
||||
err error // valid after done is closed
|
||||
}
|
||||
|
||||
// requires p.mu is held.
|
||||
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
|
||||
func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
|
||||
if call, ok := p.dialing[addr]; ok {
|
||||
// A dial is already in-flight. Don't start another.
|
||||
return call
|
||||
}
|
||||
call := &dialCall{p: p, done: make(chan struct{})}
|
||||
call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
|
||||
if p.dialing == nil {
|
||||
p.dialing = make(map[string]*dialCall)
|
||||
}
|
||||
p.dialing[addr] = call
|
||||
go call.dial(addr)
|
||||
go call.dial(call.ctx, addr)
|
||||
return call
|
||||
}
|
||||
|
||||
// run in its own goroutine.
|
||||
func (c *dialCall) dial(addr string) {
|
||||
func (c *dialCall) dial(ctx context.Context, addr string) {
|
||||
const singleUse = false // shared conn
|
||||
c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
|
||||
c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
|
||||
close(c.done)
|
||||
|
||||
c.p.mu.Lock()
|
||||
@@ -276,3 +286,28 @@ type noDialClientConnPool struct{ *clientConnPool }
|
||||
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
|
||||
return p.getClientConn(req, addr, noDialOnMiss)
|
||||
}
|
||||
|
||||
// shouldRetryDial reports whether the current request should
|
||||
// retry dialing after the call finished unsuccessfully, for example
|
||||
// if the dial was canceled because of a context cancellation or
|
||||
// deadline expiry.
|
||||
func shouldRetryDial(call *dialCall, req *http.Request) bool {
|
||||
if call.err == nil {
|
||||
// No error, no need to retry
|
||||
return false
|
||||
}
|
||||
if call.ctx == req.Context() {
|
||||
// If the call has the same context as the request, the dial
|
||||
// should not be retried, since any cancellation will have come
|
||||
// from this request.
|
||||
return false
|
||||
}
|
||||
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
|
||||
// If the call error is not because of a context cancellation or a deadline expiry,
|
||||
// the dial should not be retried.
|
||||
return false
|
||||
}
|
||||
// Only retry if the error is a context cancellation error or deadline expiry
|
||||
// and the context associated with the call was canceled or expired.
|
||||
return call.ctx.Err() != nil
|
||||
}
|
||||
|
||||
@@ -564,12 +564,12 @@ func canRetryError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
|
||||
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
|
||||
tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -590,34 +590,28 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
|
||||
func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
|
||||
if t.DialTLS != nil {
|
||||
return t.DialTLS
|
||||
}
|
||||
return t.dialTLSDefault
|
||||
}
|
||||
|
||||
func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
cn, err := tls.Dial(network, addr, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := cn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
dialer := &tls.Dialer{
|
||||
Config: cfg,
|
||||
}
|
||||
cn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
|
||||
state := tlsCn.ConnectionState()
|
||||
if p := state.NegotiatedProtocol; p != NextProtoTLS {
|
||||
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
|
||||
}
|
||||
if !state.NegotiatedProtocolIsMutual {
|
||||
return nil, errors.New("http2: could not negotiate protocol mutually")
|
||||
}
|
||||
return cn, nil
|
||||
}
|
||||
state := cn.ConnectionState()
|
||||
if p := state.NegotiatedProtocol; p != NextProtoTLS {
|
||||
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
|
||||
}
|
||||
if !state.NegotiatedProtocolIsMutual {
|
||||
return nil, errors.New("http2: could not negotiate protocol mutually")
|
||||
}
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
// disableKeepAlives reports whether connections should be closed as
|
||||
|
||||
169
http2/transport_go117_test.go
Normal file
169
http2/transport_go117_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build go1.17
|
||||
// +build go1.17
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTransportDialTLSContext(t *testing.T) {
|
||||
blockCh := make(chan struct{})
|
||||
serverTLSConfigFunc := func(ts *httptest.Server) {
|
||||
ts.Config.TLSConfig = &tls.Config{
|
||||
// Triggers the server to request the clients certificate
|
||||
// during TLS handshake.
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
}
|
||||
}
|
||||
ts := newServerTester(t,
|
||||
func(w http.ResponseWriter, r *http.Request) {},
|
||||
optOnlyServer,
|
||||
serverTLSConfigFunc,
|
||||
)
|
||||
defer ts.Close()
|
||||
tr := &Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
// Tests that the context provided to `req` is
|
||||
// passed into this function.
|
||||
close(blockCh)
|
||||
<-cri.Context().Done()
|
||||
return nil, cri.Context().Err()
|
||||
},
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
defer tr.CloseIdleConnections()
|
||||
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
res.Body.Close()
|
||||
}()
|
||||
// Wait for GetClientCertificate handler to be called
|
||||
<-blockCh
|
||||
// Cancel the context
|
||||
cancel()
|
||||
// Expect the cancellation error here
|
||||
err = <-errCh
|
||||
if err == nil {
|
||||
t.Fatal("cancelling context during client certificate fetch did not error as expected")
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("unexpected error returned after cancellation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialRaceResumesDial tests that, given two concurrent requests
|
||||
// to the same address, when the first Dial is interrupted because
|
||||
// the first request's context is cancelled, the second request
|
||||
// resumes the dial automatically.
|
||||
func TestDialRaceResumesDial(t *testing.T) {
|
||||
blockCh := make(chan struct{})
|
||||
serverTLSConfigFunc := func(ts *httptest.Server) {
|
||||
ts.Config.TLSConfig = &tls.Config{
|
||||
// Triggers the server to request the clients certificate
|
||||
// during TLS handshake.
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
}
|
||||
}
|
||||
ts := newServerTester(t,
|
||||
func(w http.ResponseWriter, r *http.Request) {},
|
||||
optOnlyServer,
|
||||
serverTLSConfigFunc,
|
||||
)
|
||||
defer ts.Close()
|
||||
tr := &Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
select {
|
||||
case <-blockCh:
|
||||
// If we already errored, return without error.
|
||||
return &tls.Certificate{}, nil
|
||||
default:
|
||||
}
|
||||
close(blockCh)
|
||||
<-cri.Context().Done()
|
||||
return nil, cri.Context().Err()
|
||||
},
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
defer tr.CloseIdleConnections()
|
||||
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Create two requests with independent cancellation.
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
req1 := req.WithContext(ctx1)
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
req2 := req.WithContext(ctx2)
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
res, err := tr.RoundTrip(req1)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
res.Body.Close()
|
||||
}()
|
||||
successCh := make(chan struct{})
|
||||
go func() {
|
||||
// Don't start request until first request
|
||||
// has initiated the handshake.
|
||||
<-blockCh
|
||||
res, err := tr.RoundTrip(req2)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
res.Body.Close()
|
||||
// Close successCh to indicate that the second request
|
||||
// made it to the server successfully.
|
||||
close(successCh)
|
||||
}()
|
||||
// Wait for GetClientCertificate handler to be called
|
||||
<-blockCh
|
||||
// Cancel the context first
|
||||
cancel1()
|
||||
// Expect the cancellation error here
|
||||
err = <-errCh
|
||||
if err == nil {
|
||||
t.Fatal("cancelling context during client certificate fetch did not error as expected")
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("unexpected error returned after cancellation: %v", err)
|
||||
}
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("unexpected second error: %v", err)
|
||||
case <-successCh:
|
||||
}
|
||||
}
|
||||
@@ -3276,7 +3276,8 @@ func TestClientConnPing(t *testing.T) {
|
||||
defer st.Close()
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
|
||||
ctx := context.Background()
|
||||
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -4278,7 +4279,8 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
|
||||
defer st.Close()
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
|
||||
ctx := context.Background()
|
||||
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -4788,7 +4790,8 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
|
||||
ctx := context.Background()
|
||||
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user