mirror of
https://github.com/golang/net.git
synced 2026-04-01 02:47:08 +09:00
http2: clearer distinction between test server types
newServerTester is used to create an HTTP/2 server for testing. It returns a *serverTester, which includes a number of methods for sending frames to and reading frames from a server connection under test. Many tests also use newServerTester to simply start an *httptest.Server. These tests pass an "optOnlyServer" to indicate that they don't need a client connection to interact with. They interact with the *httptest.Server, and use no methods or fields of *serverTester. Make a clearer distinction between test types. Add a newTestServer function which starts and returns an *httptest.Server. This function replaces use of newServerTester with optOnlyServer. Change-Id: Ia590c9b4dcc23c17e530b0cc273b9120965da11a Reviewed-on: https://go-review.googlesource.com/c/net/+/586155 Reviewed-by: Jonathan Amsterdam <jba@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
@@ -101,12 +101,50 @@ func resetHooks() {
|
||||
testHookOnPanicMu.Unlock()
|
||||
}
|
||||
|
||||
func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *httptest.Server {
|
||||
ts := httptest.NewUnstartedServer(handler)
|
||||
ts.EnableHTTP2 = true
|
||||
ts.Config.ErrorLog = log.New(twriter{t: t}, "", log.LstdFlags)
|
||||
h2server := new(Server)
|
||||
for _, opt := range opts {
|
||||
switch v := opt.(type) {
|
||||
case func(*httptest.Server):
|
||||
v(ts)
|
||||
case func(*Server):
|
||||
v(h2server)
|
||||
default:
|
||||
t.Fatalf("unknown newTestServer option type %T", v)
|
||||
}
|
||||
}
|
||||
ConfigureServer(ts.Config, h2server)
|
||||
|
||||
// ConfigureServer populates ts.Config.TLSConfig.
|
||||
// Copy it to ts.TLS as well.
|
||||
ts.TLS = ts.Config.TLSConfig
|
||||
|
||||
// Go 1.22 changes the default minimum TLS version to TLS 1.2,
|
||||
// in order to properly test cases where we want to reject low
|
||||
// TLS versions, we need to explicitly configure the minimum
|
||||
// version here.
|
||||
ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
|
||||
|
||||
ts.StartTLS()
|
||||
t.Cleanup(func() {
|
||||
ts.CloseClientConnections()
|
||||
ts.Close()
|
||||
})
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
type serverTesterOpt string
|
||||
|
||||
var optOnlyServer = serverTesterOpt("only_server")
|
||||
var optQuiet = serverTesterOpt("quiet_logging")
|
||||
var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")
|
||||
|
||||
var optQuiet = func(ts *httptest.Server) {
|
||||
ts.Config.ErrorLog = log.New(io.Discard, "", 0)
|
||||
}
|
||||
|
||||
func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
|
||||
resetHooks()
|
||||
|
||||
@@ -117,7 +155,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
|
||||
NextProtos: []string{NextProtoTLS},
|
||||
}
|
||||
|
||||
var onlyServer, quiet, framerReuseFrames bool
|
||||
var framerReuseFrames bool
|
||||
h2server := new(Server)
|
||||
for _, opt := range opts {
|
||||
switch v := opt.(type) {
|
||||
@@ -129,10 +167,6 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
|
||||
v(h2server)
|
||||
case serverTesterOpt:
|
||||
switch v {
|
||||
case optOnlyServer:
|
||||
onlyServer = true
|
||||
case optQuiet:
|
||||
quiet = true
|
||||
case optFramerReuseFrames:
|
||||
framerReuseFrames = true
|
||||
}
|
||||
@@ -159,9 +193,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
|
||||
st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
|
||||
|
||||
ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
|
||||
if quiet {
|
||||
ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
|
||||
} else {
|
||||
if ts.Config.ErrorLog == nil {
|
||||
ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
|
||||
}
|
||||
ts.StartTLS()
|
||||
@@ -175,32 +207,30 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
|
||||
st.sc = v
|
||||
}
|
||||
log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
|
||||
if !onlyServer {
|
||||
cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
st.cc = cc
|
||||
st.fr = NewFramer(cc, cc)
|
||||
if framerReuseFrames {
|
||||
st.fr.SetReuseFrames()
|
||||
}
|
||||
if !logFrameReads && !logFrameWrites {
|
||||
st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
|
||||
m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
|
||||
st.frameReadLogMu.Lock()
|
||||
fmt.Fprintf(&st.frameReadLogBuf, m, v...)
|
||||
st.frameReadLogMu.Unlock()
|
||||
}
|
||||
st.cc = cc
|
||||
st.fr = NewFramer(cc, cc)
|
||||
if framerReuseFrames {
|
||||
st.fr.SetReuseFrames()
|
||||
}
|
||||
if !logFrameReads && !logFrameWrites {
|
||||
st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
|
||||
m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
|
||||
st.frameReadLogMu.Lock()
|
||||
fmt.Fprintf(&st.frameReadLogBuf, m, v...)
|
||||
st.frameReadLogMu.Unlock()
|
||||
}
|
||||
st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
|
||||
m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
|
||||
st.frameWriteLogMu.Lock()
|
||||
fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
|
||||
st.frameWriteLogMu.Unlock()
|
||||
}
|
||||
st.fr.logReads = true
|
||||
st.fr.logWrites = true
|
||||
st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
|
||||
m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
|
||||
st.frameWriteLogMu.Lock()
|
||||
fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
|
||||
st.frameWriteLogMu.Unlock()
|
||||
}
|
||||
st.fr.logReads = true
|
||||
st.fr.logWrites = true
|
||||
}
|
||||
return st
|
||||
}
|
||||
@@ -3067,16 +3097,15 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
|
||||
func TestServerWritesUndeclaredTrailers(t *testing.T) {
|
||||
const trailer = "Trailer-Header"
|
||||
const value = "hi1"
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(http.TrailerPrefix+trailer, value)
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
cl := &http.Client{Transport: tr}
|
||||
resp, err := cl.Get(st.ts.URL)
|
||||
resp, err := cl.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -3731,7 +3760,7 @@ func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
|
||||
doRead := make(chan bool, 1)
|
||||
defer close(doRead) // fallback cleanup
|
||||
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, msg)
|
||||
w.(http.Flusher).Flush()
|
||||
|
||||
@@ -3740,14 +3769,12 @@ func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
|
||||
r.Body.Read(make([]byte, 10))
|
||||
|
||||
io.WriteString(w, msg2)
|
||||
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
|
||||
req, _ := http.NewRequest("POST", ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
|
||||
req.Header.Set("Expect", "100-continue")
|
||||
|
||||
res, err := tr.RoundTrip(req)
|
||||
@@ -3808,14 +3835,13 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) {
|
||||
unblock := make(chan bool, 1)
|
||||
defer close(unblock)
|
||||
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
// Don't read the 16KB request body. Wait until the client's
|
||||
// done sending it and then return. This should cause the Server
|
||||
// to then return those 16KB of flow control to the client.
|
||||
tt.reqFn(r)
|
||||
<-unblock
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
@@ -3833,7 +3859,7 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) {
|
||||
return 0, io.EOF
|
||||
}),
|
||||
)
|
||||
req, _ := http.NewRequest("POST", st.ts.URL, body)
|
||||
req, _ := http.NewRequest("POST", ts.URL, body)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(tt.name, err)
|
||||
@@ -3949,22 +3975,21 @@ func TestIssue20704Race(t *testing.T) {
|
||||
itemCount = 100
|
||||
)
|
||||
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
for i := 0; i < itemCount; i++ {
|
||||
_, err := w.Write(make([]byte, itemSize))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
cl := &http.Client{Transport: tr}
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
resp, err := cl.Get(st.ts.URL)
|
||||
resp, err := cl.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -4241,26 +4266,25 @@ func TestContentEncodingNoSniffing(t *testing.T) {
|
||||
|
||||
for _, tt := range resps {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if tt.contentEncoding != nil {
|
||||
w.Header().Set("Content-Encoding", tt.contentEncoding.(string))
|
||||
}
|
||||
w.Write(tt.body)
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, _ := http.NewRequest("GET", st.ts.URL, nil)
|
||||
req, _ := http.NewRequest("GET", ts.URL, nil)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GET %s: %v", st.ts.URL, err)
|
||||
t.Fatalf("GET %s: %v", ts.URL, err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
g := res.Header.Get("Content-Encoding")
|
||||
t.Logf("%s: Content-Encoding: %s", st.ts.URL, g)
|
||||
t.Logf("%s: Content-Encoding: %s", ts.URL, g)
|
||||
|
||||
if w := tt.contentEncoding; g != w {
|
||||
if w != nil { // The case where contentEncoding was set explicitly.
|
||||
@@ -4274,7 +4298,7 @@ func TestContentEncodingNoSniffing(t *testing.T) {
|
||||
if w := tt.wantContentType; g != w {
|
||||
t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
|
||||
}
|
||||
t.Logf("%s: Content-Type: %s", st.ts.URL, g)
|
||||
t.Logf("%s: Content-Type: %s", ts.URL, g)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4606,7 +4630,7 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) {
|
||||
// We should not access the slice after this point.
|
||||
func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
|
||||
donec := make(chan struct{})
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(donec)
|
||||
buf := make([]byte, 1<<20)
|
||||
var i byte
|
||||
@@ -4620,13 +4644,12 @@ func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, _ := http.NewRequest("GET", st.ts.URL, nil)
|
||||
req, _ := http.NewRequest("GET", ts.URL, nil)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -4642,7 +4665,7 @@ func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
|
||||
// We should not access the slice after this point.
|
||||
func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
|
||||
donec := make(chan struct{}, 1)
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
donec <- struct{}{}
|
||||
defer close(donec)
|
||||
buf := make([]byte, 1<<20)
|
||||
@@ -4657,20 +4680,19 @@ func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
})
|
||||
|
||||
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, _ := http.NewRequest("GET", st.ts.URL, nil)
|
||||
req, _ := http.NewRequest("GET", ts.URL, nil)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
<-donec
|
||||
st.ts.Config.Close()
|
||||
ts.Config.Close()
|
||||
<-donec
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user