diff --git a/netutil/listen.go b/netutil/listen.go index b23c6e99..a2591f83 100644 --- a/netutil/listen.go +++ b/netutil/listen.go @@ -14,37 +14,35 @@ import ( // LimitListener returns a Listener that accepts at most n simultaneous // connections from the provided Listener. func LimitListener(l net.Listener, n int) net.Listener { - ch := make(chan struct{}, n) - for i := 0; i < n; i++ { - ch <- struct{}{} - } - return &limitListener{l, ch} + return &limitListener{l, make(chan struct{}, n)} } type limitListener struct { net.Listener - ch chan struct{} + sem chan struct{} } +func (l *limitListener) acquire() { l.sem <- struct{}{} } +func (l *limitListener) release() { <-l.sem } + func (l *limitListener) Accept() (net.Conn, error) { - <-l.ch + l.acquire() c, err := l.Listener.Accept() if err != nil { + l.release() return nil, err } - return &limitListenerConn{Conn: c, ch: l.ch}, nil + return &limitListenerConn{Conn: c, release: l.release}, nil } type limitListenerConn struct { net.Conn - ch chan<- struct{} - close sync.Once + releaseOnce sync.Once + release func() } func (l *limitListenerConn) Close() error { err := l.Conn.Close() - l.close.Do(func() { - l.ch <- struct{}{} - }) + l.releaseOnce.Do(l.release) return err } diff --git a/netutil/listen_test.go b/netutil/listen_test.go index 1146d458..ac87e0ee 100644 --- a/netutil/listen_test.go +++ b/netutil/listen_test.go @@ -10,6 +10,7 @@ package netutil import ( + "errors" "fmt" "io" "io/ioutil" @@ -69,3 +70,34 @@ func TestLimitListener(t *testing.T) { t.Errorf("too many Gets failed: %v", failed) } } + +type errorListener struct { + net.Listener +} + +func (errorListener) Accept() (net.Conn, error) { + return nil, errFake +} + +var errFake = errors.New("fake error from errorListener") + +// This used to hang. +func TestLimitListenerError(t *testing.T) { + donec := make(chan bool, 1) + go func() { + const n = 2 + ll := LimitListener(errorListener{}, n) + for i := 0; i < n+1; i++ { + _, err := ll.Accept() + if err != errFake { + t.Fatalf("Accept error = %v; want errFake", err) + } + } + donec <- true + }() + select { + case <-donec: + case <-time.After(5 * time.Second): + t.Fatal("timeout. deadlock?") + } +}