From 195180cfebf7362bd243a52477697895128c8777 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 30 Nov 2015 18:41:58 +0000 Subject: [PATCH] http2: merge duplicate Transport dials Fixes golang/go#13397 Updates golang/go#6891 Change-Id: I1e4c7bfe60c6abf9a03f2888aa6abc3891c309e7 Reviewed-on: https://go-review.googlesource.com/17134 Reviewed-by: Ian Lance Taylor --- http2/client_conn_pool.go | 68 +++++++++++++++++++++++++------- http2/transport_test.go | 81 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 14 deletions(-) diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go index eeac8384..e59c800f 100644 --- a/http2/client_conn_pool.go +++ b/http2/client_conn_pool.go @@ -19,11 +19,12 @@ type ClientConnPool interface { type clientConnPool struct { t *Transport - mu sync.Mutex // TODO: switch to RWMutex + mu sync.Mutex // TODO: maybe switch to RWMutex // TODO: add support for sharing conns based on cert names // (e.g. share conn for googleapis.com and appspot.com) - conns map[string][]*ClientConn // key is host:port - keys map[*ClientConn][]string + conns map[string][]*ClientConn // key is host:port + dialing map[string]*dialCall // currently in-flight dials + keys map[*ClientConn][]string } func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { @@ -38,26 +39,65 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis return cc, nil } } - p.mu.Unlock() if !dialOnMiss { + p.mu.Unlock() return nil, ErrNoCachedConn } + call := p.getStartDialLocked(addr) + p.mu.Unlock() + <-call.done + return call.res, call.err +} - // TODO(bradfitz): use a singleflight.Group to only lock once per 'key'. - // Probably need to vendor it in as github.com/golang/sync/singleflight - // though, since the net package already uses it? Also lines up with - // sameer, bcmills, et al wanting to open source some sync stuff. - cc, err := p.t.dialClientConn(addr) - if err != nil { - return nil, err +// dialCall is an in-flight Transport dial call to a host. +type dialCall struct { + p *clientConnPool + done chan struct{} // closed when done + res *ClientConn // valid after done is closed + err error // valid after done is closed +} + +// requires p.mu is held. +func (p *clientConnPool) getStartDialLocked(addr string) *dialCall { + if call, ok := p.dialing[addr]; ok { + // A dial is already in-flight. Don't start another. + return call } - p.addConn(addr, cc) - return cc, nil + call := &dialCall{p: p, done: make(chan struct{})} + if p.dialing == nil { + p.dialing = make(map[string]*dialCall) + } + p.dialing[addr] = call + go call.dial(addr) + return call +} + +// run in its own goroutine. +func (c *dialCall) dial(addr string) { + c.res, c.err = c.p.t.dialClientConn(addr) + close(c.done) + + c.p.mu.Lock() + delete(c.p.dialing, addr) + if c.err == nil { + c.p.addConnLocked(addr, c.res) + } + c.p.mu.Unlock() } func (p *clientConnPool) addConn(key string, cc *ClientConn) { p.mu.Lock() - defer p.mu.Unlock() + p.addConnLocked(key, cc) + p.mu.Unlock() +} + +// p.mu must be held +func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) { + for _, v := range p.conns[key] { + if v == cc { + return + } + } if p.conns == nil { p.conns = make(map[string][]*ClientConn) } diff --git a/http2/transport_test.go b/http2/transport_test.go index 31b6459c..83791575 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -128,6 +128,87 @@ func TestTransportReusesConns(t *testing.T) { } } +// Tests that the Transport only keeps one pending dial open per destination address. +// https://golang.org/issue/13397 +func TestTransportGroupsPendingDials(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer) + defer st.Close() + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + } + defer tr.CloseIdleConnections() + var ( + mu sync.Mutex + dials = map[string]int{} + ) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Error(err) + return + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Error(err) + return + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("Body read: %v", err) + } + addr := strings.TrimSpace(string(slurp)) + if addr == "" { + t.Errorf("didn't get an addr in response") + } + mu.Lock() + dials[addr]++ + mu.Unlock() + }() + } + wg.Wait() + if len(dials) != 1 { + t.Errorf("saw %d dials; want 1: %v", len(dials), dials) + } + tr.CloseIdleConnections() + if err := retry(50, 10*time.Millisecond, func() error { + cp, ok := tr.connPool().(*clientConnPool) + if !ok { + return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool()) + } + if len(cp.dialing) != 0 { + return fmt.Errorf("dialing map = %v; want empty", cp.dialing) + } + if len(cp.conns) != 0 { + return fmt.Errorf("conns = %v; want empty", cp.conns) + } + if len(cp.keys) != 0 { + return fmt.Errorf("keys = %v; want empty", cp.keys) + } + return nil + }); err != nil { + t.Error("State of pool after CloseIdleConnections: %v", err) + } +} + +func retry(tries int, delay time.Duration, fn func() error) error { + var err error + for i := 0; i < tries; i++ { + err = fn() + if err == nil { + return nil + } + time.Sleep(delay) + } + return err +} + func TestTransportAbortClosesPipes(t *testing.T) { shutdown := make(chan struct{}) st := newServerTester(t,