mirror of
https://github.com/golang/net.git
synced 2026-03-31 02:17:08 +09:00
websocket: add support for dialing with context
Right now there is no way to pass context.Context to websocket.Dial.
In addition, this method can block indefinitely in the NewClient call.
Fixes golang/go#57953.
Change-Id: Ic52d4b8306cd0850e78d683abb1bf11f0d4247ca
GitHub-Last-Rev: 5e8c3a7cba
GitHub-Pull-Request: golang/net#160
Reviewed-on: https://go-review.googlesource.com/c/net/+/463097
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
committed by
Gopher Robot
parent
fa11427993
commit
3dfd003ad3
@@ -6,10 +6,12 @@ package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialError is an error that occurs while dialling a websocket server.
|
||||
@@ -77,30 +79,60 @@ func parseAuthority(location *url.URL) string {
|
||||
return location.Host
|
||||
}
|
||||
|
||||
// DialConfig opens a new client connection to a WebSocket with a config.
|
||||
func DialConfig(config *Config) (ws *Conn, err error) {
|
||||
var client net.Conn
|
||||
return config.DialContext(context.Background())
|
||||
}
|
||||
|
||||
// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
|
||||
func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
|
||||
if config.Location == nil {
|
||||
return nil, &DialError{config, ErrBadWebSocketLocation}
|
||||
}
|
||||
if config.Origin == nil {
|
||||
return nil, &DialError{config, ErrBadWebSocketOrigin}
|
||||
}
|
||||
|
||||
dialer := config.Dialer
|
||||
if dialer == nil {
|
||||
dialer = &net.Dialer{}
|
||||
}
|
||||
client, err = dialWithDialer(dialer, config)
|
||||
if err != nil {
|
||||
goto Error
|
||||
}
|
||||
ws, err = NewClient(config, client)
|
||||
if err != nil {
|
||||
client.Close()
|
||||
goto Error
|
||||
}
|
||||
return
|
||||
|
||||
Error:
|
||||
return nil, &DialError{config, err}
|
||||
client, err := dialWithDialer(ctx, dialer, config)
|
||||
if err != nil {
|
||||
return nil, &DialError{config, err}
|
||||
}
|
||||
|
||||
// Cleanup the connection if we fail to create the websocket successfully
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = client.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
var ws *Conn
|
||||
var wsErr error
|
||||
doneConnecting := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneConnecting)
|
||||
ws, err = NewClient(config, client)
|
||||
if err != nil {
|
||||
wsErr = &DialError{config, err}
|
||||
}
|
||||
}()
|
||||
|
||||
// The websocket.NewClient() function can block indefinitely, make sure that we
|
||||
// respect the deadlines specified by the context.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Force the pending operations to fail, terminating the pending connection attempt
|
||||
_ = client.SetDeadline(time.Now())
|
||||
<-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
|
||||
return nil, &DialError{config, ctx.Err()}
|
||||
case <-doneConnecting:
|
||||
if wsErr == nil {
|
||||
success = true // Disarm the deferred connection cleanup
|
||||
}
|
||||
return ws, wsErr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,18 +5,23 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
|
||||
func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
|
||||
switch config.Location.Scheme {
|
||||
case "ws":
|
||||
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
|
||||
conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
|
||||
|
||||
case "wss":
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
|
||||
tlsDialer := &tls.Dialer{
|
||||
NetDialer: dialer,
|
||||
Config: config.TlsConfig,
|
||||
}
|
||||
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
|
||||
default:
|
||||
err = ErrBadScheme
|
||||
}
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) {
|
||||
t.Fatalf("expected timeout error, got %#v", neterr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialConfigTLSWithTimeouts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
finishedRequest := make(chan bool)
|
||||
|
||||
// Context for cancellation
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// This is a TLS server that blocks each request indefinitely (and cancels the context)
|
||||
tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cancel()
|
||||
<-finishedRequest
|
||||
}))
|
||||
|
||||
tlsServerAddr := tlsServer.Listener.Addr().String()
|
||||
log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
|
||||
defer tlsServer.Close()
|
||||
defer close(finishedRequest)
|
||||
|
||||
config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
|
||||
config.TlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
|
||||
_, err := config.DialContext(ctx)
|
||||
dialerr, ok := err.(*DialError)
|
||||
if !ok {
|
||||
t.Fatalf("DialError expected, got %#v", err)
|
||||
}
|
||||
if !errors.Is(dialerr.Err, context.Canceled) {
|
||||
t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user