From a8e8f92cd65f5ff48599999fa4001d9b95ecea85 Mon Sep 17 00:00:00 2001 From: Tom Bergan Date: Mon, 15 May 2017 12:29:23 -0700 Subject: [PATCH] http2: remove extra goroutine stack from awaitGracefulShutdown This is a better fix that https://golang.org/cl/43455. Instead of creating a separate goroutine to wait for the global shutdown channel, we reuse the new serverMsgCh, which was added in a prior CL. We also use the new net/http.Server.RegisterOnShutdown method to register a shutdown callback for each http2.Server. Updates golang/go#20302 Updates golang/go#18471 Change-Id: Icf29d5e4f65b3779d1fb4ea92924e4fb6bdadb2a Reviewed-on: https://go-review.googlesource.com/43230 Run-TryBot: Tom Bergan Reviewed-by: Brad Fitzpatrick --- http2/go19.go | 16 ++++++ http2/go19_test.go | 60 +++++++++++++++++++++ http2/not_go19.go | 16 ++++++ http2/server.go | 122 ++++++++++++++++++++++++++----------------- http2/server_test.go | 45 ---------------- 5 files changed, 165 insertions(+), 94 deletions(-) create mode 100644 http2/go19.go create mode 100644 http2/go19_test.go create mode 100644 http2/not_go19.go diff --git a/http2/go19.go b/http2/go19.go new file mode 100644 index 00000000..38124ba5 --- /dev/null +++ b/http2/go19.go @@ -0,0 +1,16 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.9 + +package http2 + +import ( + "net/http" +) + +func configureServer19(s *http.Server, conf *Server) error { + s.RegisterOnShutdown(conf.state.startGracefulShutdown) + return nil +} diff --git a/http2/go19_test.go b/http2/go19_test.go new file mode 100644 index 00000000..1675d248 --- /dev/null +++ b/http2/go19_test.go @@ -0,0 +1,60 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.9 + +package http2 + +import ( + "context" + "net/http" + "reflect" + "testing" + "time" +) + +func TestServerGracefulShutdown(t *testing.T) { + var st *serverTester + handlerDone := make(chan struct{}) + st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + go st.ts.Config.Shutdown(context.Background()) + + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } + if ga.LastStreamID != 1 { + t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID) + } + + w.Header().Set("x-foo", "bar") + }) + defer st.Close() + + st.greet() + st.bodylessReq1() + + select { + case <-handlerDone: + case <-time.After(5 * time.Second): + t.Fatalf("server did not shutdown?") + } + hf := st.wantHeaders() + goth := st.decodeHeader(hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"x-foo", "bar"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", "0"}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + + n, err := st.cc.Read([]byte{0}) + if n != 0 || err == nil { + t.Errorf("Read = %v, %v; want 0, non-nil", n, err) + } +} diff --git a/http2/not_go19.go b/http2/not_go19.go new file mode 100644 index 00000000..5ae07726 --- /dev/null +++ b/http2/not_go19.go @@ -0,0 +1,16 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.9 + +package http2 + +import ( + "net/http" +) + +func configureServer19(s *http.Server, conf *Server) error { + // not supported prior to go1.9 + return nil +} diff --git a/http2/server.go b/http2/server.go index 3175a08c..7367b31c 100644 --- a/http2/server.go +++ b/http2/server.go @@ -126,6 +126,11 @@ type Server struct { // NewWriteScheduler constructs a write scheduler for a connection. // If nil, a default scheduler is chosen. NewWriteScheduler func() WriteScheduler + + // Internal state. This is a pointer (rather than embedded directly) + // so that we don't embed a Mutex in this struct, which will make the + // struct non-copyable, which might break some callers. + state *serverInternalState } func (s *Server) initialConnRecvWindowSize() int32 { @@ -156,6 +161,40 @@ func (s *Server) maxConcurrentStreams() uint32 { return defaultMaxStreams } +type serverInternalState struct { + mu sync.Mutex + activeConns map[*serverConn]struct{} +} + +func (s *serverInternalState) registerConn(sc *serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + s.activeConns[sc] = struct{}{} + s.mu.Unlock() +} + +func (s *serverInternalState) unregisterConn(sc *serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() +} + +func (s *serverInternalState) startGracefulShutdown() { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + for sc := range s.activeConns { + sc.startGracefulShutdown() + } + s.mu.Unlock() +} + // ConfigureServer adds HTTP/2 support to a net/http Server. // // The configuration conf may be nil. @@ -168,9 +207,13 @@ func ConfigureServer(s *http.Server, conf *Server) error { if conf == nil { conf = new(Server) } + conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})} if err := configureServer18(s, conf); err != nil { return err } + if err := configureServer19(s, conf); err != nil { + return err + } if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -305,6 +348,9 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { pushEnabled: true, } + s.state.registerConn(sc) + defer s.state.unregisterConn(sc) + // The net/http package sets the write deadline from the // http.Server.WriteTimeout during the TLS handshake, but then // passes the connection off to us with the deadline already set. @@ -445,6 +491,9 @@ type serverConn struct { // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer hpackEncoder *hpack.Encoder + + // Used by startGracefulShutdown. + shutdownOnce sync.Once } func (sc *serverConn) maxHeaderListSize() uint32 { @@ -749,15 +798,6 @@ func (sc *serverConn) serve() { defer sc.idleTimer.Stop() } - var gracefulShutdownCh chan struct{} - if sc.hs != nil { - ch := h1ServerShutdownChan(sc.hs) - if ch != nil { - gracefulShutdownCh = make(chan struct{}) - go sc.awaitGracefulShutdown(ch, gracefulShutdownCh) - } - } - go sc.readFrames() // closed by defer sc.conn.Close above settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) @@ -786,14 +826,11 @@ func (sc *serverConn) serve() { } case m := <-sc.bodyReadCh: sc.noteBodyRead(m.st, m.n) - case <-gracefulShutdownCh: - gracefulShutdownCh = nil - sc.startGracefulShutdown() case msg := <-sc.serveMsgCh: switch v := msg.(type) { case func(int): v(loopNum) // for testing - case *timerMessage: + case *serverMessage: switch v { case settingsTimerMsg: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) @@ -804,6 +841,8 @@ func (sc *serverConn) serve() { case shutdownTimerMsg: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case gracefulShutdownMsg: + sc.startGracefulShutdownInternal() default: panic("unknown timer") } @@ -828,13 +867,14 @@ func (sc *serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh } } -type timerMessage int +type serverMessage int -// Timeout message values sent to serveMsgCh. +// Message values sent to serveMsgCh. var ( - settingsTimerMsg = new(timerMessage) - idleTimerMsg = new(timerMessage) - shutdownTimerMsg = new(timerMessage) + settingsTimerMsg = new(serverMessage) + idleTimerMsg = new(serverMessage) + shutdownTimerMsg = new(serverMessage) + gracefulShutdownMsg = new(serverMessage) ) func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } @@ -1166,10 +1206,19 @@ func (sc *serverConn) scheduleFrameWrite() { sc.inFrameScheduleLoop = false } -// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the -// client we're gracefully shutting down. The connection isn't closed -// until all current streams are done. +// startGracefulShutdown gracefully shuts down a connection. This +// sends GOAWAY with ErrCodeNo to tell the client we're gracefully +// shutting down. The connection isn't closed until all current +// streams are done. +// +// startGracefulShutdown returns immediately; it does not wait until +// the connection has shut down. func (sc *serverConn) startGracefulShutdown() { + sc.serveG.checkNotOn() // NOT + sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) }) +} + +func (sc *serverConn) startGracefulShutdownInternal() { sc.goAwayIn(ErrCodeNo, 0) } @@ -1399,7 +1448,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if h1ServerKeepAlivesDisabled(sc.hs) { - sc.startGracefulShutdown() + sc.startGracefulShutdownInternal() } } if p := st.body; p != nil { @@ -1586,7 +1635,7 @@ func (sc *serverConn) processGoAway(f *GoAwayFrame) error { } else { sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) } - sc.startGracefulShutdown() + sc.startGracefulShutdownInternal() // http://tools.ietf.org/html/rfc7540#section-6.8 // We should not create any new streams, which means we should disable push. sc.pushEnabled = false @@ -2653,7 +2702,7 @@ func (sc *serverConn) startPush(msg *startPushRequest) { // A server that is unable to establish a new stream identifier can send a GOAWAY // frame so that the client is forced to open a new connection for new streams. if sc.maxPushPromiseID+2 >= 1<<31 { - sc.startGracefulShutdown() + sc.startGracefulShutdownInternal() return 0, ErrPushLimitReached } sc.maxPushPromiseID += 2 @@ -2778,31 +2827,6 @@ var badTrailer = map[string]bool{ "Www-Authenticate": true, } -// h1ServerShutdownChan returns a channel that will be closed when the -// provided *http.Server wants to shut down. -// -// This is a somewhat hacky way to get at http1 innards. It works -// when the http2 code is bundled into the net/http package in the -// standard library. The alternatives ended up making the cmd/go tool -// depend on http Servers. This is the lightest option for now. -// This is tested via the TestServeShutdown* tests in net/http. -func h1ServerShutdownChan(hs *http.Server) <-chan struct{} { - if fn := testh1ServerShutdownChan; fn != nil { - return fn(hs) - } - var x interface{} = hs - type I interface { - getDoneChan() <-chan struct{} - } - if hs, ok := x.(I); ok { - return hs.getDoneChan() - } - return nil -} - -// optional test hook for h1ServerShutdownChan. -var testh1ServerShutdownChan func(hs *http.Server) <-chan struct{} - // h1ServerKeepAlivesDisabled reports whether hs has its keep-alives // disabled. See comments on h1ServerShutdownChan above for why // the code is written this way. diff --git a/http2/server_test.go b/http2/server_test.go index 5cb24905..638d2a4a 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -3685,48 +3685,3 @@ func TestRequestBodyReadCloseRace(t *testing.T) { <-done } } - -func TestServerGracefulShutdown(t *testing.T) { - shutdownCh := make(chan struct{}) - defer func() { testh1ServerShutdownChan = nil }() - testh1ServerShutdownChan = func(*http.Server) <-chan struct{} { return shutdownCh } - - var st *serverTester - handlerDone := make(chan struct{}) - st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - defer close(handlerDone) - close(shutdownCh) - - ga := st.wantGoAway() - if ga.ErrCode != ErrCodeNo { - t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) - } - if ga.LastStreamID != 1 { - t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID) - } - - w.Header().Set("x-foo", "bar") - }) - defer st.Close() - - st.greet() - st.bodylessReq1() - - <-handlerDone - hf := st.wantHeaders() - goth := st.decodeHeader(hf.HeaderBlockFragment()) - wanth := [][2]string{ - {":status", "200"}, - {"x-foo", "bar"}, - {"content-type", "text/plain; charset=utf-8"}, - {"content-length", "0"}, - } - if !reflect.DeepEqual(goth, wanth) { - t.Errorf("Got headers %v; want %v", goth, wanth) - } - - n, err := st.cc.Read([]byte{0}) - if n != 0 || err == nil { - t.Errorf("Read = %v, %v; want 0, non-nil", n, err) - } -}