mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
netutil: release semaphore on error
Also rewrite it a bit for clarity (IMO). LGTM=pzm, r R=pzm, adg, r CC=golang-codereviews https://golang.org/cl/96560043
This commit is contained in:
@@ -14,37 +14,35 @@ import (
|
||||
// LimitListener returns a Listener that accepts at most n simultaneous
|
||||
// connections from the provided Listener.
|
||||
func LimitListener(l net.Listener, n int) net.Listener {
|
||||
ch := make(chan struct{}, n)
|
||||
for i := 0; i < n; i++ {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
return &limitListener{l, ch}
|
||||
return &limitListener{l, make(chan struct{}, n)}
|
||||
}
|
||||
|
||||
type limitListener struct {
|
||||
net.Listener
|
||||
ch chan struct{}
|
||||
sem chan struct{}
|
||||
}
|
||||
|
||||
func (l *limitListener) acquire() { l.sem <- struct{}{} }
|
||||
func (l *limitListener) release() { <-l.sem }
|
||||
|
||||
func (l *limitListener) Accept() (net.Conn, error) {
|
||||
<-l.ch
|
||||
l.acquire()
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.release()
|
||||
return nil, err
|
||||
}
|
||||
return &limitListenerConn{Conn: c, ch: l.ch}, nil
|
||||
return &limitListenerConn{Conn: c, release: l.release}, nil
|
||||
}
|
||||
|
||||
type limitListenerConn struct {
|
||||
net.Conn
|
||||
ch chan<- struct{}
|
||||
close sync.Once
|
||||
releaseOnce sync.Once
|
||||
release func()
|
||||
}
|
||||
|
||||
func (l *limitListenerConn) Close() error {
|
||||
err := l.Conn.Close()
|
||||
l.close.Do(func() {
|
||||
l.ch <- struct{}{}
|
||||
})
|
||||
l.releaseOnce.Do(l.release)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -69,3 +70,34 @@ func TestLimitListener(t *testing.T) {
|
||||
t.Errorf("too many Gets failed: %v", failed)
|
||||
}
|
||||
}
|
||||
|
||||
type errorListener struct {
|
||||
net.Listener
|
||||
}
|
||||
|
||||
func (errorListener) Accept() (net.Conn, error) {
|
||||
return nil, errFake
|
||||
}
|
||||
|
||||
var errFake = errors.New("fake error from errorListener")
|
||||
|
||||
// This used to hang.
|
||||
func TestLimitListenerError(t *testing.T) {
|
||||
donec := make(chan bool, 1)
|
||||
go func() {
|
||||
const n = 2
|
||||
ll := LimitListener(errorListener{}, n)
|
||||
for i := 0; i < n+1; i++ {
|
||||
_, err := ll.Accept()
|
||||
if err != errFake {
|
||||
t.Fatalf("Accept error = %v; want errFake", err)
|
||||
}
|
||||
}
|
||||
donec <- true
|
||||
}()
|
||||
select {
|
||||
case <-donec:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout. deadlock?")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user