mirror of
https://github.com/golang/net.git
synced 2026-03-31 10:27:08 +09:00
go.net/websocket: allow server configurable
Add websocket.Server to configure WebSocket server handler. - Config.Header is additional headers to send, so you can use it to send cookies or so. To read cookies, you can use Conn.Request().Header. - factor out Handshake. You can set func to check origin, subprotocol etc. Handler checks origin by default. Fixes golang/go#4198. Fixes golang/go#5178. R=golang-dev, mikioh.mikioh, crobin CC=golang-dev https://golang.org/cl/8731044
This commit is contained in:
committed by
Mikio Hara
parent
94458b3b47
commit
0005f0a0c0
@@ -9,6 +9,7 @@ import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
@@ -34,6 +35,7 @@ func NewConfig(server, origin string) (config *Config, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
config.Header = http.Header(make(map[string][]string))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,17 @@ var (
|
||||
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
|
||||
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
|
||||
ErrNotImplemented = &ProtocolError{"not implemented"}
|
||||
|
||||
handshakeHeader = map[string]bool{
|
||||
"Host": true,
|
||||
"Upgrade": true,
|
||||
"Connection": true,
|
||||
"Sec-Websocket-Key": true,
|
||||
"Sec-Websocket-Origin": true,
|
||||
"Sec-Websocket-Version": true,
|
||||
"Sec-Websocket-Protocol": true,
|
||||
"Sec-Websocket-Accept": true,
|
||||
}
|
||||
)
|
||||
|
||||
// A hybiFrameHeader is a frame header as defined in hybi draft.
|
||||
@@ -408,8 +419,11 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er
|
||||
if len(config.Protocol) > 0 {
|
||||
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
|
||||
}
|
||||
// TODO(ukai): send extensions.
|
||||
// TODO(ukai): send cookie if any.
|
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
err = config.Header.WriteSubset(bw, handshakeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bw.WriteString("\r\n")
|
||||
if err = bw.Flush(); err != nil {
|
||||
@@ -483,21 +497,14 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
|
||||
return http.StatusBadRequest, ErrChallengeResponse
|
||||
}
|
||||
version := req.Header.Get("Sec-Websocket-Version")
|
||||
var origin string
|
||||
switch version {
|
||||
case "13":
|
||||
c.Version = ProtocolVersionHybi13
|
||||
origin = req.Header.Get("Origin")
|
||||
case "8":
|
||||
c.Version = ProtocolVersionHybi08
|
||||
origin = req.Header.Get("Sec-Websocket-Origin")
|
||||
default:
|
||||
return http.StatusBadRequest, ErrBadWebSocketVersion
|
||||
}
|
||||
c.Origin, err = url.ParseRequestURI(origin)
|
||||
if err != nil {
|
||||
return http.StatusForbidden, err
|
||||
}
|
||||
var scheme string
|
||||
if req.TLS != nil {
|
||||
scheme = "wss"
|
||||
@@ -520,6 +527,22 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
|
||||
return http.StatusSwitchingProtocols, nil
|
||||
}
|
||||
|
||||
// Origin parses Origin header in "req".
|
||||
// If origin is "null", returns (nil, nil).
|
||||
func Origin(config *Config, req *http.Request) (*url.URL, error) {
|
||||
var origin string
|
||||
switch config.Version {
|
||||
case ProtocolVersionHybi13:
|
||||
origin = req.Header.Get("Origin")
|
||||
case ProtocolVersionHybi08:
|
||||
origin = req.Header.Get("Sec-Websocket-Origin")
|
||||
}
|
||||
if origin == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
return url.ParseRequestURI(origin)
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
|
||||
if len(c.Protocol) > 0 {
|
||||
if len(c.Protocol) != 1 {
|
||||
@@ -533,7 +556,13 @@ func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
|
||||
if len(c.Protocol) > 0 {
|
||||
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
|
||||
}
|
||||
// TODO(ukai): support extensions
|
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
if c.Header != nil {
|
||||
err := c.Header.WriteSubset(buf, handshakeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Flush()
|
||||
}
|
||||
|
||||
@@ -92,6 +92,71 @@ Sec-WebSocket-Protocol: chat
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybiClientHandshakeWithHeader(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
bw := bufio.NewWriter(b)
|
||||
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
|
||||
Upgrade: websocket
|
||||
Connection: Upgrade
|
||||
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|
||||
Sec-WebSocket-Protocol: chat
|
||||
|
||||
`))
|
||||
var err error
|
||||
config := new(Config)
|
||||
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
|
||||
if err != nil {
|
||||
t.Fatal("location url", err)
|
||||
}
|
||||
config.Origin, err = url.ParseRequestURI("http://example.com")
|
||||
if err != nil {
|
||||
t.Fatal("origin url", err)
|
||||
}
|
||||
config.Protocol = append(config.Protocol, "chat")
|
||||
config.Protocol = append(config.Protocol, "superchat")
|
||||
config.Version = ProtocolVersionHybi13
|
||||
config.Header = http.Header(make(map[string][]string))
|
||||
config.Header.Add("User-Agent", "test")
|
||||
|
||||
config.handshakeData = map[string]string{
|
||||
"key": "dGhlIHNhbXBsZSBub25jZQ==",
|
||||
}
|
||||
err = hybiClientHandshake(config, br, bw)
|
||||
if err != nil {
|
||||
t.Errorf("handshake failed: %v", err)
|
||||
}
|
||||
req, err := http.ReadRequest(bufio.NewReader(b))
|
||||
if err != nil {
|
||||
t.Fatalf("read request: %v", err)
|
||||
}
|
||||
if req.Method != "GET" {
|
||||
t.Errorf("request method expected GET, but got %q", req.Method)
|
||||
}
|
||||
if req.URL.Path != "/chat" {
|
||||
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
|
||||
}
|
||||
if req.Proto != "HTTP/1.1" {
|
||||
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
|
||||
}
|
||||
if req.Host != "server.example.com" {
|
||||
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
|
||||
}
|
||||
var expectedHeader = map[string]string{
|
||||
"Connection": "Upgrade",
|
||||
"Upgrade": "websocket",
|
||||
"Sec-Websocket-Key": config.handshakeData["key"],
|
||||
"Origin": config.Origin.String(),
|
||||
"Sec-Websocket-Protocol": "chat, superchat",
|
||||
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
|
||||
"User-Agent": "test",
|
||||
}
|
||||
for k, v := range expectedHeader {
|
||||
if req.Header.Get(k) != v {
|
||||
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybiClientHandshakeHybi08(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
bw := bufio.NewWriter(b)
|
||||
|
||||
@@ -11,8 +11,7 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request) (conn *Conn, err error) {
|
||||
config := new(Config)
|
||||
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
|
||||
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
|
||||
code, err := hs.ReadHandshake(buf.Reader, req)
|
||||
if err == ErrBadWebSocketVersion {
|
||||
@@ -38,8 +37,16 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
config.Protocol = nil
|
||||
|
||||
if handshake != nil {
|
||||
err = handshake(config, req)
|
||||
if err != nil {
|
||||
code = http.StatusForbidden
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
}
|
||||
err = hs.AcceptHandshake(buf.Writer)
|
||||
if err != nil {
|
||||
code = http.StatusBadRequest
|
||||
@@ -52,11 +59,26 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Handler is an interface to a WebSocket.
|
||||
type Handler func(*Conn)
|
||||
// Server represents a server of a WebSocket.
|
||||
type Server struct {
|
||||
// Config is a WebSocket configuration for new WebSocket connection.
|
||||
Config
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a Web Socket
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// Handshake is an optional function in WebSocket handshake.
|
||||
// For example, you can check, or don't check Origin header.
|
||||
// Another example, you can select config.Protocol.
|
||||
Handshake func(*Config, *http.Request) error
|
||||
|
||||
// Handler handles a WebSocket connection.
|
||||
Handler
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s.serveWebSocket(w, req)
|
||||
}
|
||||
|
||||
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
|
||||
rwc, buf, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
panic("Hijack failed: " + err.Error())
|
||||
@@ -66,12 +88,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// the client did not send a handshake that matches with protocol
|
||||
// specification.
|
||||
defer rwc.Close()
|
||||
conn, err := newServerConn(rwc, buf, req)
|
||||
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
panic("unexpected nil conn")
|
||||
}
|
||||
h(conn)
|
||||
s.Handler(conn)
|
||||
}
|
||||
|
||||
// Handler is a simple interface to a WebSocket browser client.
|
||||
// It checks if Origin header is valid URL by default.
|
||||
// You might want to verify websocket.Conn.Config().Origin in the func.
|
||||
// If you use Server instead of Handler, you could call websocket.Origin and
|
||||
// check the origin in your Handshake func. So, if you want to accept
|
||||
// non-browser client, which doesn't send Origin header, you could use Server
|
||||
//. that doesn't check origin in its Handshake.
|
||||
type Handler func(*Conn)
|
||||
|
||||
func checkOrigin(config *Config, req *http.Request) (err error) {
|
||||
config.Origin, err = Origin(config, req)
|
||||
if err == nil && config.Origin == nil {
|
||||
return fmt.Errorf("null origin")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s := Server{Handler: h, Handshake: checkOrigin}
|
||||
s.serveWebSocket(w, req)
|
||||
}
|
||||
|
||||
@@ -87,6 +87,9 @@ type Config struct {
|
||||
// TLS config for secure WebSocket (wss).
|
||||
TlsConfig *tls.Config
|
||||
|
||||
// Additional header fields to be sent in WebSocket opening handshake.
|
||||
Header http.Header
|
||||
|
||||
handshakeData map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -44,9 +44,30 @@ func countServer(ws *Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func subProtocolHandshake(config *Config, req *http.Request) error {
|
||||
for _, proto := range config.Protocol {
|
||||
if proto == "chat" {
|
||||
config.Protocol = []string{proto}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrBadWebSocketProtocol
|
||||
}
|
||||
|
||||
func subProtoServer(ws *Conn) {
|
||||
for _, proto := range ws.Config().Protocol {
|
||||
io.WriteString(ws, proto)
|
||||
}
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
http.Handle("/echo", Handler(echoServer))
|
||||
http.Handle("/count", Handler(countServer))
|
||||
subproto := Server{
|
||||
Handshake: subProtocolHandshake,
|
||||
Handler: Handler(subProtoServer),
|
||||
}
|
||||
http.Handle("/subproto", subproto)
|
||||
server := httptest.NewServer(nil)
|
||||
serverAddr = server.Listener.Addr().String()
|
||||
log.Print("Test WebSocket server listening on ", serverAddr)
|
||||
@@ -177,7 +198,7 @@ func TestWithQuery(t *testing.T) {
|
||||
ws.Close()
|
||||
}
|
||||
|
||||
func TestWithProtocol(t *testing.T) {
|
||||
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
|
||||
once.Do(startServer)
|
||||
|
||||
client, err := net.Dial("tcp", serverAddr)
|
||||
@@ -185,15 +206,47 @@ func TestWithProtocol(t *testing.T) {
|
||||
t.Fatal("dialing", err)
|
||||
}
|
||||
|
||||
config := newConfig(t, "/echo")
|
||||
config.Protocol = append(config.Protocol, "test")
|
||||
config := newConfig(t, "/subproto")
|
||||
config.Protocol = subproto
|
||||
|
||||
ws, err := NewClient(config, client)
|
||||
if err != nil {
|
||||
t.Errorf("WebSocket handshake: %v", err)
|
||||
return
|
||||
return "", err
|
||||
}
|
||||
msg := make([]byte, 16)
|
||||
n, err := ws.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ws.Close()
|
||||
return string(msg[:n]), nil
|
||||
}
|
||||
|
||||
func TestWithProtocol(t *testing.T) {
|
||||
proto, err := testWithProtocol(t, []string{"chat"})
|
||||
if err != nil {
|
||||
t.Errorf("SubProto: unexpected error: %v", err)
|
||||
}
|
||||
if proto != "chat" {
|
||||
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTwoProtocol(t *testing.T) {
|
||||
proto, err := testWithProtocol(t, []string{"test", "chat"})
|
||||
if err != nil {
|
||||
t.Errorf("SubProto: unexpected error: %v", err)
|
||||
}
|
||||
if proto != "chat" {
|
||||
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithBadProtocol(t *testing.T) {
|
||||
_, err := testWithProtocol(t, []string{"test"})
|
||||
if err != ErrBadStatus {
|
||||
t.Errorf("SubProto: expected %q, got %q", ErrBadStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user