websocket: add support for dialing with context

Right now there is no way to pass context.Context to websocket.Dial.
In addition, this method can block indefinitely in the NewClient call.

Fixes golang/go#57953.

Change-Id: Ic52d4b8306cd0850e78d683abb1bf11f0d4247ca
GitHub-Last-Rev: 5e8c3a7cba
GitHub-Pull-Request: golang/net#160
Reviewed-on: https://go-review.googlesource.com/c/net/+/463097
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
Aleksei Besogonov
2024-01-12 07:38:27 +00:00
committed by Gopher Robot
parent fa11427993
commit 3dfd003ad3
3 changed files with 91 additions and 17 deletions

View File

@@ -6,10 +6,12 @@ package websocket
import (
"bufio"
"context"
"io"
"net"
"net/http"
"net/url"
"time"
)
// DialError is an error that occurs while dialling a websocket server.
@@ -77,30 +79,60 @@ func parseAuthority(location *url.URL) string {
return location.Host
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
return config.DialContext(context.Background())
}
// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
client, err = dialWithDialer(dialer, config)
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
client.Close()
goto Error
}
return
Error:
return nil, &DialError{config, err}
client, err := dialWithDialer(ctx, dialer, config)
if err != nil {
return nil, &DialError{config, err}
}
// Cleanup the connection if we fail to create the websocket successfully
success := false
defer func() {
if !success {
_ = client.Close()
}
}()
var ws *Conn
var wsErr error
doneConnecting := make(chan struct{})
go func() {
defer close(doneConnecting)
ws, err = NewClient(config, client)
if err != nil {
wsErr = &DialError{config, err}
}
}()
// The websocket.NewClient() function can block indefinitely, make sure that we
// respect the deadlines specified by the context.
select {
case <-ctx.Done():
// Force the pending operations to fail, terminating the pending connection attempt
_ = client.SetDeadline(time.Now())
<-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
return nil, &DialError{config, ctx.Err()}
case <-doneConnecting:
if wsErr == nil {
success = true // Disarm the deferred connection cleanup
}
return ws, wsErr
}
}

View File

@@ -5,18 +5,23 @@
package websocket
import (
"context"
"crypto/tls"
"net"
)
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
case "wss":
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
tlsDialer := &tls.Dialer{
NetDialer: dialer,
Config: config.TlsConfig,
}
conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
default:
err = ErrBadScheme
}

View File

@@ -5,10 +5,13 @@
package websocket
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
@@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) {
t.Fatalf("expected timeout error, got %#v", neterr)
}
}
func TestDialConfigTLSWithTimeouts(t *testing.T) {
t.Parallel()
finishedRequest := make(chan bool)
// Context for cancellation
ctx, cancel := context.WithCancel(context.Background())
// This is a TLS server that blocks each request indefinitely (and cancels the context)
tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cancel()
<-finishedRequest
}))
tlsServerAddr := tlsServer.Listener.Addr().String()
log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
defer tlsServer.Close()
defer close(finishedRequest)
config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
config.TlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
_, err := config.DialContext(ctx)
dialerr, ok := err.(*DialError)
if !ok {
t.Fatalf("DialError expected, got %#v", err)
}
if !errors.Is(dialerr.Err, context.Canceled) {
t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err)
}
}