diff --git a/proxy/dial_test.go b/proxy/dial_test.go index 3edab49d..608835b5 100644 --- a/proxy/dial_test.go +++ b/proxy/dial_test.go @@ -7,30 +7,26 @@ package proxy import ( "context" "fmt" - "net" "os" "testing" "time" "golang.org/x/net/internal/sockstest" + "golang.org/x/net/nettest" ) func TestDial(t *testing.T) { ResetProxyEnv() t.Run("DirectWithCancel", func(t *testing.T) { defer ResetProxyEnv() - l, err := net.Listen("tcp", "127.0.0.1:0") + l, err := nettest.NewLocalListener("tcp") if err != nil { t.Fatal(err) } defer l.Close() - _, port, err := net.SplitHostPort(l.Addr().String()) - if err != nil { - t.Fatal(err) - } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) + c, err := Dial(ctx, l.Addr().Network(), l.Addr().String()) if err != nil { t.Fatal(err) } @@ -38,18 +34,14 @@ func TestDial(t *testing.T) { }) t.Run("DirectWithTimeout", func(t *testing.T) { defer ResetProxyEnv() - l, err := net.Listen("tcp", "127.0.0.1:0") + l, err := nettest.NewLocalListener("tcp") if err != nil { t.Fatal(err) } defer l.Close() - _, port, err := net.SplitHostPort(l.Addr().String()) - if err != nil { - t.Fatal(err) - } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) + c, err := Dial(ctx, l.Addr().Network(), l.Addr().String()) if err != nil { t.Fatal(err) } @@ -57,19 +49,15 @@ func TestDial(t *testing.T) { }) t.Run("DirectWithTimeoutExceeded", func(t *testing.T) { defer ResetProxyEnv() - l, err := net.Listen("tcp", "127.0.0.1:0") + l, err := nettest.NewLocalListener("tcp") if err != nil { t.Fatal(err) } defer l.Close() - _, port, err := net.SplitHostPort(l.Addr().String()) - if err != nil { - t.Fatal(err) - } ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) time.Sleep(time.Millisecond) defer cancel() - c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) + c, err := Dial(ctx, l.Addr().Network(), l.Addr().String()) if err == nil { defer c.Close() t.Fatal("failed to timeout")