netutil: release semaphore on error

Also rewrite it a bit for clarity (IMO).

LGTM=pzm, r
R=pzm, adg, r
CC=golang-codereviews
https://golang.org/cl/96560043
This commit is contained in:
Brad Fitzpatrick
2014-05-22 11:59:35 -07:00
parent 1e5c0004cd
commit a479876f52
2 changed files with 43 additions and 13 deletions

View File

@@ -14,37 +14,35 @@ import (
// LimitListener returns a Listener that accepts at most n simultaneous
// connections from the provided Listener.
func LimitListener(l net.Listener, n int) net.Listener {
ch := make(chan struct{}, n)
for i := 0; i < n; i++ {
ch <- struct{}{}
}
return &limitListener{l, ch}
return &limitListener{l, make(chan struct{}, n)}
}
type limitListener struct {
net.Listener
ch chan struct{}
sem chan struct{}
}
func (l *limitListener) acquire() { l.sem <- struct{}{} }
func (l *limitListener) release() { <-l.sem }
func (l *limitListener) Accept() (net.Conn, error) {
<-l.ch
l.acquire()
c, err := l.Listener.Accept()
if err != nil {
l.release()
return nil, err
}
return &limitListenerConn{Conn: c, ch: l.ch}, nil
return &limitListenerConn{Conn: c, release: l.release}, nil
}
type limitListenerConn struct {
net.Conn
ch chan<- struct{}
close sync.Once
releaseOnce sync.Once
release func()
}
func (l *limitListenerConn) Close() error {
err := l.Conn.Close()
l.close.Do(func() {
l.ch <- struct{}{}
})
l.releaseOnce.Do(l.release)
return err
}

View File

@@ -10,6 +10,7 @@
package netutil
import (
"errors"
"fmt"
"io"
"io/ioutil"
@@ -69,3 +70,34 @@ func TestLimitListener(t *testing.T) {
t.Errorf("too many Gets failed: %v", failed)
}
}
type errorListener struct {
net.Listener
}
func (errorListener) Accept() (net.Conn, error) {
return nil, errFake
}
var errFake = errors.New("fake error from errorListener")
// This used to hang.
func TestLimitListenerError(t *testing.T) {
donec := make(chan bool, 1)
go func() {
const n = 2
ll := LimitListener(errorListener{}, n)
for i := 0; i < n+1; i++ {
_, err := ll.Accept()
if err != errFake {
t.Fatalf("Accept error = %v; want errFake", err)
}
}
donec <- true
}()
select {
case <-donec:
case <-time.After(5 * time.Second):
t.Fatal("timeout. deadlock?")
}
}