http2: merge duplicate Transport dials

Fixes golang/go#13397
Updates golang/go#6891

Change-Id: I1e4c7bfe60c6abf9a03f2888aa6abc3891c309e7
Reviewed-on: https://go-review.googlesource.com/17134
Reviewed-by: Ian Lance Taylor <iant@golang.org>
This commit is contained in:
Brad Fitzpatrick
2015-11-30 18:41:58 +00:00
parent 62ac18b461
commit 195180cfeb
2 changed files with 135 additions and 14 deletions

View File

@@ -19,11 +19,12 @@ type ClientConnPool interface {
type clientConnPool struct {
t *Transport
mu sync.Mutex // TODO: switch to RWMutex
mu sync.Mutex // TODO: maybe switch to RWMutex
// TODO: add support for sharing conns based on cert names
// (e.g. share conn for googleapis.com and appspot.com)
conns map[string][]*ClientConn // key is host:port
keys map[*ClientConn][]string
conns map[string][]*ClientConn // key is host:port
dialing map[string]*dialCall // currently in-flight dials
keys map[*ClientConn][]string
}
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@@ -38,26 +39,65 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
return cc, nil
}
}
p.mu.Unlock()
if !dialOnMiss {
p.mu.Unlock()
return nil, ErrNoCachedConn
}
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
return call.res, call.err
}
// TODO(bradfitz): use a singleflight.Group to only lock once per 'key'.
// Probably need to vendor it in as github.com/golang/sync/singleflight
// though, since the net package already uses it? Also lines up with
// sameer, bcmills, et al wanting to open source some sync stuff.
cc, err := p.t.dialClientConn(addr)
if err != nil {
return nil, err
// dialCall is an in-flight Transport dial call to a host.
type dialCall struct {
p *clientConnPool
done chan struct{} // closed when done
res *ClientConn // valid after done is closed
err error // valid after done is closed
}
// requires p.mu is held.
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another.
return call
}
p.addConn(addr, cc)
return cc, nil
call := &dialCall{p: p, done: make(chan struct{})}
if p.dialing == nil {
p.dialing = make(map[string]*dialCall)
}
p.dialing[addr] = call
go call.dial(addr)
return call
}
// run in its own goroutine.
func (c *dialCall) dial(addr string) {
c.res, c.err = c.p.t.dialClientConn(addr)
close(c.done)
c.p.mu.Lock()
delete(c.p.dialing, addr)
if c.err == nil {
c.p.addConnLocked(addr, c.res)
}
c.p.mu.Unlock()
}
func (p *clientConnPool) addConn(key string, cc *ClientConn) {
p.mu.Lock()
defer p.mu.Unlock()
p.addConnLocked(key, cc)
p.mu.Unlock()
}
// p.mu must be held
func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
for _, v := range p.conns[key] {
if v == cc {
return
}
}
if p.conns == nil {
p.conns = make(map[string][]*ClientConn)
}

View File

@@ -128,6 +128,87 @@ func TestTransportReusesConns(t *testing.T) {
}
}
// Tests that the Transport only keeps one pending dial open per destination address.
// https://golang.org/issue/13397
func TestTransportGroupsPendingDials(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
}, optOnlyServer)
defer st.Close()
tr := &Transport{
TLSClientConfig: tlsConfigInsecure,
}
defer tr.CloseIdleConnections()
var (
mu sync.Mutex
dials = map[string]int{}
)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Error(err)
return
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Error(err)
return
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Body read: %v", err)
}
addr := strings.TrimSpace(string(slurp))
if addr == "" {
t.Errorf("didn't get an addr in response")
}
mu.Lock()
dials[addr]++
mu.Unlock()
}()
}
wg.Wait()
if len(dials) != 1 {
t.Errorf("saw %d dials; want 1: %v", len(dials), dials)
}
tr.CloseIdleConnections()
if err := retry(50, 10*time.Millisecond, func() error {
cp, ok := tr.connPool().(*clientConnPool)
if !ok {
return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
}
if len(cp.dialing) != 0 {
return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
}
if len(cp.conns) != 0 {
return fmt.Errorf("conns = %v; want empty", cp.conns)
}
if len(cp.keys) != 0 {
return fmt.Errorf("keys = %v; want empty", cp.keys)
}
return nil
}); err != nil {
t.Error("State of pool after CloseIdleConnections: %v", err)
}
}
func retry(tries int, delay time.Duration, fn func() error) error {
var err error
for i := 0; i < tries; i++ {
err = fn()
if err == nil {
return nil
}
time.Sleep(delay)
}
return err
}
func TestTransportAbortClosesPipes(t *testing.T) {
shutdown := make(chan struct{})
st := newServerTester(t,