mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
http2: client conn pool abstraction
Change-Id: Icbf40b26a25c7084efd062a0a66385450ec537aa Reviewed-on: https://go-review.googlesource.com/16699 Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
This commit is contained in:
118
http2/client_conn_pool.go
Normal file
118
http2/client_conn_pool.go
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Transport code's client connection pooling.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ClientConnPool manages a pool of HTTP/2 client connections.
|
||||
type ClientConnPool interface {
|
||||
GetClientConn(req *http.Request, addr string) (*ClientConn, error)
|
||||
MarkDead(*ClientConn)
|
||||
}
|
||||
|
||||
type clientConnPool struct {
|
||||
t *Transport
|
||||
mu sync.Mutex // TODO: switch to RWMutex
|
||||
// TODO: add support for sharing conns based on cert names
|
||||
// (e.g. share conn for googleapis.com and appspot.com)
|
||||
conns map[string][]*ClientConn // key is host:port
|
||||
keys map[*ClientConn][]string
|
||||
}
|
||||
|
||||
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
|
||||
return p.getClientConn(req, addr, true)
|
||||
}
|
||||
|
||||
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
|
||||
p.mu.Lock()
|
||||
for _, cc := range p.conns[addr] {
|
||||
if cc.CanTakeNewRequest() {
|
||||
p.mu.Unlock()
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if !dialOnMiss {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
|
||||
// TODO(bradfitz): use a singleflight.Group to only lock once per 'key'.
|
||||
// Probably need to vendor it in as github.com/golang/sync/singleflight
|
||||
// though, since the net package already uses it? Also lines up with
|
||||
// sameer, bcmills, et al wanting to open source some sync stuff.
|
||||
cc, err := p.t.dialClientConn(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.addConn(addr, cc)
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (p *clientConnPool) addConn(key string, cc *ClientConn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.conns == nil {
|
||||
p.conns = make(map[string][]*ClientConn)
|
||||
}
|
||||
if p.keys == nil {
|
||||
p.keys = make(map[*ClientConn][]string)
|
||||
}
|
||||
p.conns[key] = append(p.conns[key], cc)
|
||||
p.keys[cc] = append(p.keys[cc], key)
|
||||
}
|
||||
|
||||
func (p *clientConnPool) MarkDead(cc *ClientConn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, key := range p.keys[cc] {
|
||||
vv, ok := p.conns[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
newList := filterOutClientConn(vv, cc)
|
||||
if len(newList) > 0 {
|
||||
p.conns[key] = newList
|
||||
} else {
|
||||
delete(p.conns, key)
|
||||
}
|
||||
}
|
||||
delete(p.keys, cc)
|
||||
}
|
||||
|
||||
func (p *clientConnPool) closeIdleConnections() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
// TODO: don't close a cc if it was just added to the pool
|
||||
// milliseconds ago and has never been used. There's currently
|
||||
// a small race window with the HTTP/1 Transport's integration
|
||||
// where it can add an idle conn just before using it, and
|
||||
// somebody else can concurrently call CloseIdleConns and
|
||||
// break some caller's RoundTrip.
|
||||
for _, vv := range p.conns {
|
||||
for _, cc := range vv {
|
||||
cc.closeIfIdle()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
|
||||
out := in[:0]
|
||||
for _, v := range in {
|
||||
if v != exclude {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
// If we filtered it out, zero out the last item to prevent
|
||||
// the GC from seeing it.
|
||||
if len(in) != len(out) {
|
||||
in[len(in)-1] = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -57,20 +57,33 @@ type Transport struct {
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// TODO: switch to RWMutex
|
||||
// TODO: add support for sharing conns based on cert names
|
||||
// (e.g. share conn for googleapis.com and appspot.com)
|
||||
connMu sync.Mutex
|
||||
conns map[string][]*clientConn // key is host:port
|
||||
// ConnPool optionally specifies an alternate connection pool to use.
|
||||
// If nil, the default is used.
|
||||
ConnPool ClientConnPool
|
||||
|
||||
connPoolOnce sync.Once
|
||||
connPoolOrDef ClientConnPool // non-nil version of ConnPool
|
||||
}
|
||||
|
||||
// clientConn is the state of a single HTTP/2 client connection to an
|
||||
func (t *Transport) connPool() ClientConnPool {
|
||||
t.connPoolOnce.Do(t.initConnPool)
|
||||
return t.connPoolOrDef
|
||||
}
|
||||
|
||||
func (t *Transport) initConnPool() {
|
||||
if t.ConnPool != nil {
|
||||
t.connPoolOrDef = t.ConnPool
|
||||
} else {
|
||||
t.connPoolOrDef = &clientConnPool{t: t}
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConn is the state of a single HTTP/2 client connection to an
|
||||
// HTTP/2 server.
|
||||
type clientConn struct {
|
||||
type ClientConn struct {
|
||||
t *Transport
|
||||
tconn net.Conn
|
||||
tlsState *tls.ConnectionState
|
||||
connKey []string // key(s) this connection is cached in, in t.conns
|
||||
tconn net.Conn // usually *tls.Conn, except specialized impls
|
||||
tlsState *tls.ConnectionState // nil only for specialized impls
|
||||
|
||||
// readLoop goroutine fields:
|
||||
readerDone chan struct{} // closed on error
|
||||
@@ -102,7 +115,7 @@ type clientConn struct {
|
||||
// clientStream is the state for a single HTTP/2 stream. One of these
|
||||
// is created for each Transport.RoundTrip call.
|
||||
type clientStream struct {
|
||||
cc *clientConn
|
||||
cc *ClientConn
|
||||
ID uint32
|
||||
resc chan resAndError
|
||||
bufPipe pipe // buffered pipe with the flow-controlled response payload
|
||||
@@ -154,24 +167,28 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(authority string) (addr string) {
|
||||
if _, _, err := net.SplitHostPort(authority); err == nil {
|
||||
return authority
|
||||
}
|
||||
return net.JoinHostPort(authority, "443")
|
||||
}
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if req.URL.Scheme != "https" {
|
||||
return nil, errors.New("http2: unsupported scheme")
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(req.URL.Host)
|
||||
if err != nil {
|
||||
host = req.URL.Host
|
||||
port = "443"
|
||||
}
|
||||
|
||||
addr := authorityAddr(req.URL.Host)
|
||||
for {
|
||||
cc, err := t.getClientConn(host, port, opt.OnlyCachedConn)
|
||||
cc, err := t.connPool().GetClientConn(req, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := cc.roundTrip(req)
|
||||
res, err := cc.RoundTrip(req)
|
||||
if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
|
||||
continue
|
||||
}
|
||||
@@ -186,12 +203,8 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
|
||||
// connected from previous requests but are now sitting idle.
|
||||
// It does not interrupt any connections currently in use.
|
||||
func (t *Transport) CloseIdleConnections() {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
for _, vv := range t.conns {
|
||||
for _, cc := range vv {
|
||||
cc.closeIfIdle()
|
||||
}
|
||||
if cp, ok := t.connPool().(*clientConnPool); ok {
|
||||
cp.closeIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,100 +215,16 @@ func shouldRetryRequest(err error) bool {
|
||||
return err == errClientConnClosed
|
||||
}
|
||||
|
||||
func (t *Transport) removeClientConn(cc *clientConn) {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
for _, key := range cc.connKey {
|
||||
vv, ok := t.conns[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
newList := filterOutClientConn(vv, cc)
|
||||
if len(newList) > 0 {
|
||||
t.conns[key] = newList
|
||||
} else {
|
||||
delete(t.conns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
|
||||
out := in[:0]
|
||||
for _, v := range in {
|
||||
if v != exclude {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
// If we filtered it out, zero out the last item to prevent
|
||||
// the GC from seeing it.
|
||||
if len(in) != len(out) {
|
||||
in[len(in)-1] = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AddIdleConn adds c as an idle conn for Transport.
|
||||
// It assumes that c has not yet exchanged SETTINGS frames.
|
||||
// The addr maybe be either "host" or "host:port".
|
||||
func (t *Transport) AddIdleConn(addr string, c *tls.Conn) error {
|
||||
var key string
|
||||
_, _, err := net.SplitHostPort(addr)
|
||||
if err == nil {
|
||||
key = addr
|
||||
} else {
|
||||
key = addr + ":443"
|
||||
}
|
||||
cc, err := t.newClientConn(key, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.addConn(key, cc)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Transport) addConn(key string, cc *clientConn) {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
if t.conns == nil {
|
||||
t.conns = make(map[string][]*clientConn)
|
||||
}
|
||||
t.conns[key] = append(t.conns[key], cc)
|
||||
}
|
||||
|
||||
func (t *Transport) getClientConn(host, port string, onlyCached bool) (*clientConn, error) {
|
||||
key := net.JoinHostPort(host, port)
|
||||
|
||||
t.connMu.Lock()
|
||||
for _, cc := range t.conns[key] {
|
||||
if cc.canTakeNewRequest() {
|
||||
t.connMu.Unlock()
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
t.connMu.Unlock()
|
||||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
|
||||
// TODO(bradfitz): use a singleflight.Group to only lock once per 'key'.
|
||||
// Probably need to vendor it in as github.com/golang/sync/singleflight
|
||||
// though, since the net package already uses it? Also lines up with
|
||||
// sameer, bcmills, et al wanting to open source some sync stuff.
|
||||
cc, err := t.dialClientConn(host, port, key)
|
||||
func (t *Transport) dialClientConn(addr string) (*ClientConn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.addConn(key, cc)
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (t *Transport) dialClientConn(host, port, key string) (*clientConn, error) {
|
||||
tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
|
||||
tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.newClientConn(key, tconn)
|
||||
return t.NewClientConn(tconn)
|
||||
}
|
||||
|
||||
func (t *Transport) newTLSConfig(host string) *tls.Config {
|
||||
@@ -338,15 +267,14 @@ func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.C
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, error) {
|
||||
if _, err := tconn.Write(clientPreface); err != nil {
|
||||
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
|
||||
if _, err := c.Write(clientPreface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc := &clientConn{
|
||||
cc := &ClientConn{
|
||||
t: t,
|
||||
tconn: tconn,
|
||||
connKey: []string{key}, // TODO: cert's validated hostnames too
|
||||
tconn: c,
|
||||
readerDone: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
maxFrameSize: 16 << 10, // spec default
|
||||
@@ -359,15 +287,15 @@ func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, erro
|
||||
|
||||
// TODO: adjust this writer size to account for frame size +
|
||||
// MTU + crypto/tls record padding.
|
||||
cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
|
||||
cc.br = bufio.NewReader(tconn)
|
||||
cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
|
||||
cc.br = bufio.NewReader(c)
|
||||
cc.fr = NewFramer(cc.bw, cc.br)
|
||||
cc.henc = hpack.NewEncoder(&cc.hbuf)
|
||||
|
||||
type connectionStater interface {
|
||||
ConnectionState() tls.ConnectionState
|
||||
}
|
||||
if cs, ok := tconn.(connectionStater); ok {
|
||||
if cs, ok := c.(connectionStater); ok {
|
||||
state := cs.ConnectionState()
|
||||
cc.tlsState = &state
|
||||
}
|
||||
@@ -414,13 +342,13 @@ func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, erro
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (cc *clientConn) setGoAway(f *GoAwayFrame) {
|
||||
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cc.goAway = f
|
||||
}
|
||||
|
||||
func (cc *clientConn) canTakeNewRequest() bool {
|
||||
func (cc *ClientConn) CanTakeNewRequest() bool {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
return cc.goAway == nil &&
|
||||
@@ -428,7 +356,7 @@ func (cc *clientConn) canTakeNewRequest() bool {
|
||||
cc.nextStreamID < 2147483647
|
||||
}
|
||||
|
||||
func (cc *clientConn) closeIfIdle() {
|
||||
func (cc *ClientConn) closeIfIdle() {
|
||||
cc.mu.Lock()
|
||||
if len(cc.streams) > 0 {
|
||||
cc.mu.Unlock()
|
||||
@@ -447,7 +375,7 @@ const maxAllocFrameSize = 512 << 10
|
||||
// They're capped at the min of the peer's max frame size or 512KB
|
||||
// (kinda arbitrarily), but definitely capped so we don't allocate 4GB
|
||||
// bufers.
|
||||
func (cc *clientConn) frameScratchBuffer() []byte {
|
||||
func (cc *ClientConn) frameScratchBuffer() []byte {
|
||||
cc.mu.Lock()
|
||||
size := cc.maxFrameSize
|
||||
if size > maxAllocFrameSize {
|
||||
@@ -464,7 +392,7 @@ func (cc *clientConn) frameScratchBuffer() []byte {
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func (cc *clientConn) putFrameScratchBuffer(buf []byte) {
|
||||
func (cc *ClientConn) putFrameScratchBuffer(buf []byte) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
|
||||
@@ -481,7 +409,7 @@ func (cc *clientConn) putFrameScratchBuffer(buf []byte) {
|
||||
// forget about it.
|
||||
}
|
||||
|
||||
func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
cc.mu.Lock()
|
||||
|
||||
if cc.closed {
|
||||
@@ -649,7 +577,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error
|
||||
}
|
||||
|
||||
// requires cc.mu be held.
|
||||
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
|
||||
func (cc *ClientConn) encodeHeaders(req *http.Request) []byte {
|
||||
cc.hbuf.Reset()
|
||||
|
||||
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
|
||||
@@ -680,7 +608,7 @@ func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
|
||||
return cc.hbuf.Bytes()
|
||||
}
|
||||
|
||||
func (cc *clientConn) writeHeader(name, value string) {
|
||||
func (cc *ClientConn) writeHeader(name, value string) {
|
||||
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
@@ -690,7 +618,7 @@ type resAndError struct {
|
||||
}
|
||||
|
||||
// requires cc.mu be held.
|
||||
func (cc *clientConn) newStream() *clientStream {
|
||||
func (cc *ClientConn) newStream() *clientStream {
|
||||
cs := &clientStream{
|
||||
cc: cc,
|
||||
ID: cc.nextStreamID,
|
||||
@@ -706,7 +634,7 @@ func (cc *clientConn) newStream() *clientStream {
|
||||
return cs
|
||||
}
|
||||
|
||||
func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
|
||||
func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cs := cc.streams[id]
|
||||
@@ -718,7 +646,7 @@ func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
|
||||
|
||||
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
|
||||
type clientConnReadLoop struct {
|
||||
cc *clientConn
|
||||
cc *ClientConn
|
||||
activeRes map[uint32]*clientStream // keyed by streamID
|
||||
|
||||
// continueStreamID is the stream ID we're waiting for
|
||||
@@ -734,7 +662,7 @@ type clientConnReadLoop struct {
|
||||
}
|
||||
|
||||
// readLoop runs in its own goroutine and reads and dispatches frames.
|
||||
func (cc *clientConn) readLoop() {
|
||||
func (cc *ClientConn) readLoop() {
|
||||
rl := &clientConnReadLoop{
|
||||
cc: cc,
|
||||
activeRes: make(map[uint32]*clientStream),
|
||||
@@ -754,7 +682,7 @@ func (cc *clientConn) readLoop() {
|
||||
func (rl *clientConnReadLoop) cleanup() {
|
||||
cc := rl.cc
|
||||
defer cc.tconn.Close()
|
||||
defer cc.t.removeClientConn(cc)
|
||||
defer cc.t.connPool().MarkDead(cc)
|
||||
defer close(cc.readerDone)
|
||||
|
||||
// Close any response bodies if the server closes prematurely.
|
||||
@@ -978,7 +906,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
|
||||
|
||||
func (rl *clientConnReadLoop) processGoAway(f *GoAwayFrame) error {
|
||||
cc := rl.cc
|
||||
cc.t.removeClientConn(cc)
|
||||
cc.t.connPool().MarkDead(cc)
|
||||
if f.ErrCode != 0 {
|
||||
// TODO: deal with GOAWAY more. particularly the error code
|
||||
cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
|
||||
@@ -1066,7 +994,7 @@ func (rl *clientConnReadLoop) processPushPromise(f *PushPromiseFrame) error {
|
||||
return ConnectionError(ErrCodeProtocol)
|
||||
}
|
||||
|
||||
func (cc *clientConn) writeStreamReset(streamID uint32, code ErrCode, err error) {
|
||||
func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error) {
|
||||
// TODO: do something with err? send it as a debug frame to the peer?
|
||||
// But that's only in GOAWAY. Invent a new frame type? Is there one already?
|
||||
cc.wmu.Lock()
|
||||
@@ -1108,11 +1036,11 @@ func (rl *clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *clientConn) logf(format string, args ...interface{}) {
|
||||
func (cc *ClientConn) logf(format string, args ...interface{}) {
|
||||
cc.t.logf(format, args...)
|
||||
}
|
||||
|
||||
func (cc *clientConn) vlogf(format string, args ...interface{}) {
|
||||
func (cc *ClientConn) vlogf(format string, args ...interface{}) {
|
||||
cc.t.vlogf(format, args...)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user