http2: use synthetic time in server tests

Change newServerTester to return a server using fake time
and a fake net.Conn.

Change-Id: I9d5db0cbe75696aed6d99ff1cd2369c2dea426c3
Reviewed-on: https://go-review.googlesource.com/c/net/+/586247
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
This commit is contained in:
Damien Neil
2024-05-18 12:55:39 -07:00
parent 022530c415
commit 03c24c2d76
7 changed files with 318 additions and 156 deletions

View File

@@ -17,6 +17,7 @@ package http2 // import "golang.org/x/net/http2"
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
@@ -26,6 +27,7 @@ import (
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/http/httpguts"
)
@@ -377,3 +379,14 @@ func validPseudoPath(v string) bool {
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
type incomparable [0]func()
// synctestGroupInterface is the methods of synctestGroup used by Server and Transport.
// It's defined as an interface here to let us keep synctestGroup entirely test-only
// and not a part of non-test builds.
type synctestGroupInterface interface {
Join()
Now() time.Time
NewTimer(d time.Duration) timer
AfterFunc(d time.Duration, f func()) timer
ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
}

View File

@@ -154,6 +154,39 @@ type Server struct {
// so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers.
state *serverInternalState
// Synchronization group used for testing.
// Outside of tests, this is nil.
group synctestGroupInterface
}
func (s *Server) markNewGoroutine() {
if s.group != nil {
s.group.Join()
}
}
func (s *Server) now() time.Time {
if s.group != nil {
return s.group.Now()
}
return time.Now()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (s *Server) newTimer(d time.Duration) timer {
if s.group != nil {
return s.group.NewTimer(d)
}
return timeTimer{time.NewTimer(d)}
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (s *Server) afterFunc(d time.Duration, f func()) timer {
if s.group != nil {
return s.group.AfterFunc(d, f)
}
return timeTimer{time.AfterFunc(d, f)}
}
func (s *Server) initialConnRecvWindowSize() int32 {
@@ -400,6 +433,10 @@ func (o *ServeConnOpts) handler() http.Handler {
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
s.serveConn(c, opts, nil)
}
func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel()
@@ -426,6 +463,9 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
}
if newf != nil {
newf(sc)
}
s.state.registerConn(sc)
defer s.state.unregisterConn(sc)
@@ -599,8 +639,8 @@ type serverConn struct {
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
goAwayCode ErrCode
shutdownTimer *time.Timer // nil until used
idleTimer *time.Timer // nil if unused
shutdownTimer timer // nil until used
idleTimer timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -649,12 +689,12 @@ type stream struct {
flow outflow // limits writing from Handler to client
inflow inflow // what the client is allowed to POST/etc to us
state streamState
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused
writeDeadline *time.Timer // nil if unused
closeErr error // set before cw is closed
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline timer // nil if unused
writeDeadline timer // nil if unused
closeErr error // set before cw is closed
trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer
@@ -811,6 +851,7 @@ type readFrameResult struct {
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
sc.srv.markNewGoroutine()
gate := make(chan struct{})
gateDone := func() { gate <- struct{}{} }
for {
@@ -843,6 +884,7 @@ type frameWriteResult struct {
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
sc.srv.markNewGoroutine()
var err error
if wd == nil {
err = wr.write.writeFrame(sc)
@@ -922,13 +964,13 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
loopNum := 0
@@ -1057,10 +1099,10 @@ func (sc *serverConn) readPreface() error {
errc <- nil
}
}()
timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server?
timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop()
select {
case <-timer.C:
case <-timer.C():
return errPrefaceTimeout
case err := <-errc:
if err == nil {
@@ -1425,7 +1467,7 @@ func (sc *serverConn) goAway(code ErrCode) {
func (sc *serverConn) shutDownIn(d time.Duration) {
sc.serveG.check()
sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
}
func (sc *serverConn) resetStream(se StreamError) {
@@ -2022,7 +2064,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
return sc.scheduleHandler(id, rw, req, handler)
@@ -2120,7 +2162,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
@@ -2344,6 +2386,7 @@ func (sc *serverConn) handlerDone() {
// Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
sc.srv.markNewGoroutine()
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true
defer func() {
@@ -2640,7 +2683,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
var date string
if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure.
date = time.Now().UTC().Format(http.TimeFormat)
date = rws.conn.srv.now().UTC().Format(http.TimeFormat)
}
for _, v := range rws.snapHeader["Trailer"] {
@@ -2762,7 +2805,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
@@ -2778,9 +2821,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
st.readDeadline = sc.srv.afterFunc(deadline.Sub(w.rws.conn.srv.now()), st.onReadTimeout)
} else {
st.readDeadline.Reset(deadline.Sub(time.Now()))
st.readDeadline.Reset(deadline.Sub(w.rws.conn.srv.now()))
}
})
return nil
@@ -2788,7 +2831,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
@@ -2804,9 +2847,9 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
st.writeDeadline = sc.srv.afterFunc(deadline.Sub(w.rws.conn.srv.now()), st.onWriteTimeout)
} else {
st.writeDeadline.Reset(deadline.Sub(time.Now()))
st.writeDeadline.Reset(deadline.Sub(w.rws.conn.srv.now()))
}
})
return nil

View File

@@ -105,7 +105,7 @@ func TestServer_Push_Success(t *testing.T) {
errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
}
})
stURL = st.ts.URL
stURL = "https://" + st.authority()
// Send one request, which should push two responses.
st.greet()
@@ -169,7 +169,7 @@ func TestServer_Push_Success(t *testing.T) {
return checkPushPromise(f, 2, [][2]string{
{":method", "GET"},
{":scheme", "https"},
{":authority", st.ts.Listener.Addr().String()},
{":authority", st.authority()},
{":path", "/pushed?get"},
{"user-agent", userAgent},
})
@@ -178,7 +178,7 @@ func TestServer_Push_Success(t *testing.T) {
return checkPushPromise(f, 4, [][2]string{
{":method", "HEAD"},
{":scheme", "https"},
{":authority", st.ts.Listener.Addr().String()},
{":authority", st.authority()},
{":path", "/pushed?head"},
{"cookie", cookie},
{"user-agent", userAgent},

View File

@@ -15,6 +15,7 @@ import (
"fmt"
"io"
"log"
"math"
"net"
"net/http"
"net/http/httptest"
@@ -66,7 +67,9 @@ func (sb *safeBuffer) Len() int {
type serverTester struct {
cc net.Conn // client conn
t testing.TB
ts *httptest.Server
group *synctestGroup
h1server *http.Server
h2server *Server
fr *Framer
serverLogBuf safeBuffer // logger for httptest.Server
logFilter []string // substrings to filter out
@@ -109,6 +112,8 @@ func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{})
switch v := opt.(type) {
case func(*httptest.Server):
v(ts)
case func(*http.Server):
v(ts.Config)
case func(*Server):
v(h2server)
default:
@@ -140,14 +145,95 @@ type serverTesterOpt string
var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")
var optQuiet = func(ts *httptest.Server) {
ts.Config.ErrorLog = log.New(io.Discard, "", 0)
var optQuiet = func(server *http.Server) {
server.ErrorLog = log.New(io.Discard, "", 0)
}
func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
t.Helper()
g := newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))
h1server := &http.Server{}
h2server := &Server{
group: g,
}
tlsState := tls.ConnectionState{
Version: tls.VersionTLS13,
ServerName: "go.dev",
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
}
for _, opt := range opts {
switch v := opt.(type) {
case func(*Server):
v(h2server)
case func(*http.Server):
v(h1server)
case func(*tls.ConnectionState):
v(&tlsState)
default:
t.Fatalf("unknown newServerTester option type %T", v)
}
}
ConfigureServer(h1server, h2server)
cli, srv := synctestNetPipe(g)
cli.SetReadDeadline(g.Now())
cli.autoWait = true
st := &serverTester{
t: t,
cc: cli,
group: g,
h1server: h1server,
h2server: h2server,
}
st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
if h1server.ErrorLog == nil {
h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
}
t.Cleanup(func() {
st.Close()
})
connc := make(chan *serverConn)
go func() {
g.Join()
h2server.serveConn(&netConnWithConnectionState{
Conn: srv,
state: tlsState,
}, &ServeConnOpts{
Handler: handler,
BaseConfig: h1server,
}, func(sc *serverConn) {
connc <- sc
})
}()
st.sc = <-connc
st.fr = NewFramer(st.cc, st.cc)
g.Wait()
return st
}
type netConnWithConnectionState struct {
net.Conn
state tls.ConnectionState
}
func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState {
return c.state
}
// newServerTesterWithRealConn creates a test server listening on a localhost port.
// Mostly superseded by newServerTester, which creates a test server using a fake
// net.Conn and synthetic time. This function is still around because some benchmarks
// rely on it; new tests should use newServerTester.
func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
resetHooks()
ts := httptest.NewUnstartedServer(handler)
t.Cleanup(ts.Close)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
@@ -162,6 +248,8 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
v(tlsConfig)
case func(*httptest.Server):
v(ts)
case func(*http.Server):
v(ts.Config)
case func(*Server):
v(h2server)
case serverTesterOpt:
@@ -185,8 +273,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
st := &serverTester{
t: t,
ts: ts,
t: t,
}
st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
@@ -234,6 +321,20 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
return st
}
// sync waits for all goroutines to idle.
func (st *serverTester) sync() {
st.group.Wait()
}
// advance advances synthetic time by a duration.
func (st *serverTester) advance(d time.Duration) {
st.group.AdvanceTime(d)
}
func (st *serverTester) authority() string {
return "dummy.tld"
}
func (st *serverTester) closeConn() {
st.scMu.Lock()
defer st.scMu.Unlock()
@@ -309,7 +410,6 @@ func (st *serverTester) Close() {
st.cc.Close()
}
}
st.ts.Close()
if st.cc != nil {
st.cc.Close()
}
@@ -438,7 +538,7 @@ func (st *serverTester) encodeHeader(headers ...string) []byte {
}
st.headerBuf.Reset()
defaultAuthority := st.ts.Listener.Addr().String()
defaultAuthority := st.authority()
if len(headers) == 0 {
// Fast path, mostly for benchmarks, so test code doesn't pollute
@@ -1245,38 +1345,32 @@ func (l *filterListener) Accept() (net.Conn, error) {
}
func TestServer_MaxQueuedControlFrames(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
// Goroutine debugging makes this test very slow.
disableGoroutineTracking(t)
st := newServerTester(t, nil, func(ts *httptest.Server) {
// TCP buffer sizes on test systems aren't under our control and can be large.
// Create a conn that blocks after 10000 bytes written.
ts.Listener = &filterListener{
Listener: ts.Listener,
accept: func(conn net.Conn) (net.Conn, error) {
return newBlockingWriteConn(conn, 10000), nil
},
}
})
defer st.Close()
st := newServerTester(t, nil)
st.greet()
const extraPings = 500000 // enough to fill the TCP buffers
st.cc.(*synctestNetConn).SetReadBufferSize(0) // all writes block
st.cc.(*synctestNetConn).autoWait = false // don't sync after every write
// Send maxQueuedControlFrames pings, plus a few extra
// to account for ones that enter the server's write buffer.
const extraPings = 2
for i := 0; i < maxQueuedControlFrames+extraPings; i++ {
pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
if err := st.fr.WritePing(false, pingData); err != nil {
if i == 0 {
t.Fatal(err)
}
// We expect the connection to get closed by the server when the TCP
// buffer fills up and the write queue reaches MaxQueuedControlFrames.
t.Logf("sent %d PING frames", i)
return
}
st.fr.WritePing(false, pingData)
}
t.Errorf("unexpected success sending all PING frames")
st.group.Wait()
// Unblock the server.
// It should have closed the connection after exceeding the control frame limit.
st.cc.(*synctestNetConn).SetReadBufferSize(math.MaxInt)
fr, err := st.readFrame()
if err != nil {
return
}
t.Errorf("unexpected frame after exceeding maxQueuedControlFrames; want closed conn\n%v", fr)
}
func TestServer_RejectsLargeFrames(t *testing.T) {
@@ -1762,6 +1856,7 @@ func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
writeReq(st)
st.wantGoAway()
st.advance(goAwayTimeout)
fr, err := st.fr.ReadFrame()
if err == nil {
@@ -2611,13 +2706,12 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
func testRejectTLS(t *testing.T, max uint16) {
st := newServerTester(t, nil, func(c *tls.Config) {
func testRejectTLS(t *testing.T, version uint16) {
st := newServerTester(t, nil, func(state *tls.ConnectionState) {
// As of 1.18 the default minimum Go TLS version is
// 1.2. In order to test rejection of lower versions,
// manually set the minimum version to 1.0
c.MinVersion = tls.VersionTLS10
c.MaxVersion = max
// manually set the version to 1.0
state.Version = version
})
defer st.Close()
gf := st.wantGoAway()
@@ -2627,24 +2721,9 @@ func testRejectTLS(t *testing.T, max uint16) {
}
func TestServer_Rejects_TLSBadCipher(t *testing.T) {
st := newServerTester(t, nil, func(c *tls.Config) {
// All TLS 1.3 ciphers are good. Test with TLS 1.2.
c.MaxVersion = tls.VersionTLS12
// Only list bad ones:
c.CipherSuites = []uint16{
tls.TLS_RSA_WITH_RC4_128_SHA,
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
}
st := newServerTester(t, nil, func(state *tls.ConnectionState) {
state.Version = tls.VersionTLS12
state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
})
defer st.Close()
gf := st.wantGoAway()
@@ -2654,18 +2733,30 @@ func TestServer_Rejects_TLSBadCipher(t *testing.T) {
}
func TestServer_Advertises_Common_Cipher(t *testing.T) {
const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
st := newServerTester(t, nil, func(c *tls.Config) {
// Have the client only support the one required by the spec.
c.CipherSuites = []uint16{requiredSuite}
}, func(ts *httptest.Server) {
var srv *http.Server = ts.Config
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
}, func(srv *http.Server) {
// Have the server configured with no specific cipher suites.
// This tests that Go's defaults include the required one.
srv.TLSConfig = nil
})
defer st.Close()
st.greet()
// Have the client only support the one required by the spec.
const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
tlsConfig := tlsConfigInsecure.Clone()
tlsConfig.MaxVersion = tls.VersionTLS12
tlsConfig.CipherSuites = []uint16{requiredSuite}
tr := &Transport{TLSClientConfig: tlsConfig}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
}
func (st *serverTester) onHeaderField(f hpack.HeaderField) {
@@ -2867,8 +2958,8 @@ func TestCompressionErrorOnWrite(t *testing.T) {
var serverConfig *http.Server
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
// No response body.
}, func(ts *httptest.Server) {
serverConfig = ts.Config
}, func(s *http.Server) {
serverConfig = s
serverConfig.MaxHeaderBytes = maxStrLen
})
st.addLogFilter("connection error: COMPRESSION_ERROR")
@@ -3141,11 +3232,11 @@ func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
}
func BenchmarkServerGets(b *testing.B) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world"
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, msg)
})
defer st.Close()
@@ -3173,11 +3264,11 @@ func BenchmarkServerGets(b *testing.B) {
}
func BenchmarkServerPosts(b *testing.B) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world"
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
// Consume the (empty) body from th peer before replying, otherwise
// the server will sometimes (depending on scheduling) send the peer a
// a RST_STREAM with the CANCEL error code.
@@ -3225,7 +3316,7 @@ func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
}
func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
b.ReportAllocs()
const msgLen = 1
// default window size
@@ -3241,7 +3332,7 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
return msg
}
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
// Consume the (empty) body from th peer before replying, otherwise
// the server will sometimes (depending on scheduling) send the peer a
// a RST_STREAM with the CANCEL error code.
@@ -3515,17 +3606,17 @@ func TestServerContentLengthCanBeDisabled(t *testing.T) {
}
}
func disableGoroutineTracking() (restore func()) {
func disableGoroutineTracking(t testing.TB) {
old := DebugGoroutines
DebugGoroutines = false
return func() { DebugGoroutines = old }
t.Cleanup(func() { DebugGoroutines = old })
}
func BenchmarkServer_GetRequest(b *testing.B) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world."
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
n, err := io.Copy(io.Discard, r.Body)
if err != nil || n > 0 {
b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
@@ -3554,10 +3645,10 @@ func BenchmarkServer_GetRequest(b *testing.B) {
}
func BenchmarkServer_PostRequest(b *testing.B) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
b.ReportAllocs()
const msg = "Hello, world."
st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
n, err := io.Copy(io.Discard, r.Body)
if err != nil || n > 0 {
b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
@@ -3901,6 +3992,7 @@ func TestServerIdleTimeout(t *testing.T) {
defer st.Close()
st.greet()
st.advance(500 * time.Millisecond)
ga := st.wantGoAway()
if ga.ErrCode != ErrCodeNo {
t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
@@ -3911,12 +4003,16 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
const timeout = 250 * time.Millisecond
const (
requestTimeout = 2 * time.Second
idleTimeout = 1 * time.Second
)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(timeout * 2)
var st *serverTester
st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
st.group.Sleep(requestTimeout)
}, func(h2s *Server) {
h2s.IdleTimeout = timeout
h2s.IdleTimeout = idleTimeout
})
defer st.Close()
@@ -3925,10 +4021,12 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) {
// Send a request which takes twice the timeout. Verifies the
// idle timeout doesn't fire while we're in a request:
st.bodylessReq1()
st.advance(requestTimeout)
st.wantHeaders()
// But the idle timeout should be rearmed after the request
// is done:
st.advance(idleTimeout)
ga := st.wantGoAway()
if ga.ErrCode != ErrCodeNo {
t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
@@ -4092,6 +4190,8 @@ func TestServerHandlerConnectionClose(t *testing.T) {
}
sawWindowUpdate = true
unblockHandler <- true
st.sync()
st.advance(goAwayTimeout)
default:
t.Logf("unexpected frame: %v", summarizeFrame(f))
}
@@ -4157,20 +4257,9 @@ func TestServer_Headers_HalfCloseRemote(t *testing.T) {
}
func TestServerGracefulShutdown(t *testing.T) {
var st *serverTester
handlerDone := make(chan struct{})
st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
defer close(handlerDone)
go st.ts.Config.Shutdown(context.Background())
ga := st.wantGoAway()
if ga.ErrCode != ErrCodeNo {
t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
}
if ga.LastStreamID != 1 {
t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
}
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
<-handlerDone
w.Header().Set("x-foo", "bar")
})
defer st.Close()
@@ -4178,7 +4267,20 @@ func TestServerGracefulShutdown(t *testing.T) {
st.greet()
st.bodylessReq1()
<-handlerDone
st.sync()
st.h1server.Shutdown(context.Background())
ga := st.wantGoAway()
if ga.ErrCode != ErrCodeNo {
t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
}
if ga.LastStreamID != 1 {
t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
}
close(handlerDone)
st.sync()
hf := st.wantHeaders()
goth := st.decodeHeader(hf.HeaderBlockFragment())
wanth := [][2]string{
@@ -4396,7 +4498,6 @@ func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) {
}
st.writeData(1, true, []byte(content))
time.Sleep(200 * time.Millisecond)
st.Close()
if bytes.Contains(st.serverLogBuf.Bytes(), []byte("PROTOCOL_ERROR")) {
@@ -4523,6 +4624,7 @@ func TestProtocolErrorAfterGoAway(t *testing.T) {
t.Fatal(err)
}
st.advance(goAwayTimeout)
for {
if _, err := st.readFrame(); err != nil {
if err != io.EOF {
@@ -4805,8 +4907,8 @@ Frames:
func TestServerContinuationFlood(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
fmt.Println(r.Header)
}, func(ts *httptest.Server) {
ts.Config.MaxHeaderBytes = 4096
}, func(s *http.Server) {
s.MaxHeaderBytes = 4096
})
defer st.Close()

View File

@@ -31,6 +31,9 @@ type goroutine struct {
// newSynctest creates a new group with the synthetic clock set the provided time.
func newSynctest(now time.Time) *synctestGroup {
return &synctestGroup{
gids: map[int]bool{
currentGoroutine(): true,
},
now: now,
}
}
@@ -39,9 +42,6 @@ func newSynctest(now time.Time) *synctestGroup {
func (g *synctestGroup) Join() {
g.mu.Lock()
defer g.mu.Unlock()
if g.gids == nil {
g.gids = map[int]bool{}
}
g.gids[currentGoroutine()] = true
}
@@ -154,6 +154,7 @@ func stacks(all bool) []goroutine {
// AdvanceTime advances the synthetic clock by d.
func (g *synctestGroup) AdvanceTime(d time.Duration) {
defer g.Wait()
g.mu.Lock()
defer g.mu.Unlock()
g.now = g.now.Add(d)
@@ -186,6 +187,12 @@ func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) {
return d, scheduled
}
// Sleep is time.Sleep, but using synthetic time.
func (g *synctestGroup) Sleep(d time.Duration) {
tm := g.NewTimer(d)
<-tm.C()
}
// NewTimer is time.NewTimer, but using synthetic time.
func (g *synctestGroup) NewTimer(d time.Duration) Timer {
return g.addTimer(d, &fakeTimer{

View File

@@ -194,12 +194,7 @@ type Transport struct {
type transportTestHooks struct {
newclientconn func(*ClientConn)
group interface {
Join()
NewTimer(d time.Duration) timer
AfterFunc(d time.Duration, f func()) timer
ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
}
group synctestGroupInterface
}
func (t *Transport) markNewGoroutine() {

View File

@@ -3658,7 +3658,7 @@ func TestTransportNoBodyMeansNoDATA(t *testing.T) {
}
func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
b.ReportAllocs()
ts := newTestServer(b,
func(w http.ResponseWriter, r *http.Request) {
@@ -3770,10 +3770,10 @@ func BenchmarkDownloadFrameSize(b *testing.B) {
b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
}
func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
defer disableGoroutineTracking()()
disableGoroutineTracking(b)
const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M
b.ReportAllocs()
st := newServerTester(b,
ts := newTestServer(b,
func(w http.ResponseWriter, r *http.Request) {
// test 1GB transfer
w.Header().Set("Content-Length", strconv.Itoa(transferSize))
@@ -3784,12 +3784,11 @@ func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
}
}, optQuiet,
)
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", st.ts.URL, nil)
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
b.Fatal(err)
}
@@ -4869,33 +4868,36 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
}
func TestClientConnReservations(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
}, func(s *Server) {
s.MaxConcurrentStreams = initialMaxConcurrentStreams
})
defer st.Close()
tc := newTestClientConn(t)
tc.greet(
Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams},
)
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.newClientConn(st.cc, false)
if err != nil {
t.Fatal(err)
doRoundTrip := func() {
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
rt := tc.roundTrip(req)
tc.wantFrameType(FrameHeaders)
tc.writeHeaders(HeadersFrameParam{
StreamID: rt.streamID(),
EndHeaders: true,
EndStream: true,
BlockFragment: tc.makeHeaderBlockFragment(
":status", "200",
),
})
rt.wantStatus(200)
}
req, _ := http.NewRequest("GET", st.ts.URL, nil)
n := 0
for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
for n <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
n++
}
if n != initialMaxConcurrentStreams {
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
}
if _, err := cc.RoundTrip(req); err != nil {
t.Fatalf("RoundTrip error = %v", err)
}
doRoundTrip()
n2 := 0
for n2 <= 5 && cc.ReserveNewRequest() {
for n2 <= 5 && tc.cc.ReserveNewRequest() {
n2++
}
if n2 != 1 {
@@ -4904,11 +4906,11 @@ func TestClientConnReservations(t *testing.T) {
// Use up all the reservations
for i := 0; i < n; i++ {
cc.RoundTrip(req)
doRoundTrip()
}
n2 = 0
for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
for n2 <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
n2++
}
if n2 != n {