From 8d297f1cac94a449ef60461f8242b1982bdeb0bc Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 25 Feb 2026 09:41:57 -0800 Subject: [PATCH] http2: Move most tests from the http2 package to the http2_test package. This change makes it easier to move x/net/http2 into std. Moving the http2 package into std and importing it from net/http (rather than bundling it as net/http/h2_bundle.go) requires removing the http2->net/http dependency. Moving tests into the http2_test package allows them to continue importing net/http without creating a cycle. Change-Id: If0799a94a6d2c90f02d7f391e352e14e6a6a6964 Reviewed-on: https://go-review.googlesource.com/c/net/+/749280 Auto-Submit: Damien Neil Reviewed-by: Nicholas Husin LUCI-TryBot-Result: Go LUCI Reviewed-by: Nicholas Husin --- http2/clientconn_test.go | 63 +++-- http2/config_test.go | 4 +- http2/connframes_test.go | 7 +- http2/export_test.go | 237 ++++++++++++++++++ http2/frame_test.go | 33 +-- http2/http2_test.go | 57 +---- http2/netconn_test.go | 2 +- http2/server_internal_test.go | 84 +++++++ http2/server_push_test.go | 12 +- http2/server_test.go | 320 ++++++++--------------- http2/synctest_test.go | 10 +- http2/transport.go | 4 - http2/transport_internal_test.go | 293 ++++++++++++++++++++++ http2/transport_test.go | 418 ++++++------------------------- 14 files changed, 863 insertions(+), 681 deletions(-) create mode 100644 http2/export_test.go create mode 100644 http2/server_internal_test.go create mode 100644 http2/transport_internal_test.go diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 247a8d57..10b1fc12 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -5,7 +5,7 @@ // Infrastructure for testing ClientConn.RoundTrip. // Put actual tests in transport_test.go. -package http2 +package http2_test import ( "bytes" @@ -20,6 +20,7 @@ import ( "testing/synctest" "time" + . "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "golang.org/x/net/internal/gate" ) @@ -108,29 +109,32 @@ type testClientConn struct { netconn *synctestNetConn } -func newTestClientConnFromClientConn(t testing.TB, cc *ClientConn) *testClientConn { +func newTestClientConnFromClientConn(t testing.TB, tr *Transport, cc *ClientConn) *testClientConn { tc := &testClientConn{ t: t, - tr: cc.t, + tr: tr, cc: cc, } // srv is the side controlled by the test. var srv *synctestNetConn - if cc.tconn == nil { + if tconn := cc.TestNetConn(); tconn == nil { // If cc.tconn is nil, we're being called with a new conn created by the // Transport's client pool. This path skips dialing the server, and we // create a test connection pair here. - cc.tconn, srv = synctestNetPipe() + var cli *synctestNetConn + cli, srv = synctestNetPipe() + cc.TestSetNetConn(cli) } else { // If cc.tconn is non-nil, we're in a test which provides a conn to the // Transport via a TLSNextProto hook. Extract the test connection pair. - if tc, ok := cc.tconn.(*tls.Conn); ok { + if tc, ok := tconn.(*tls.Conn); ok { // Unwrap any *tls.Conn to the underlying net.Conn, // to avoid dealing with encryption in tests. - cc.tconn = tc.NetConn() + tconn = tc.NetConn() + cc.TestSetNetConn(tconn) } - srv = cc.tconn.(*synctestNetConn).peer + srv = tconn.(*synctestNetConn).peer } srv.SetReadDeadline(time.Now()) @@ -141,7 +145,7 @@ func newTestClientConnFromClientConn(t testing.TB, cc *ClientConn) *testClientCo tc.testConnFramer = testConnFramer{ t: t, fr: tc.fr, - dec: hpack.NewDecoder(initialHeaderTableSize, nil), + dec: hpack.NewDecoder(InitialHeaderTableSize, nil), } tc.fr.SetMaxReadFrameSize(10 << 20) t.Cleanup(func() { @@ -154,12 +158,12 @@ func newTestClientConnFromClientConn(t testing.TB, cc *ClientConn) *testClientCo func (tc *testClientConn) readClientPreface() { tc.t.Helper() // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. - buf := make([]byte, len(clientPreface)) + buf := make([]byte, len(ClientPreface)) if _, err := io.ReadFull(tc.netconn, buf); err != nil { tc.t.Fatalf("reading preface: %v", err) } - if !bytes.Equal(buf, clientPreface) { - tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface) + if !bytes.Equal(buf, []byte(ClientPreface)) { + tc.t.Fatalf("client preface: %q, want %q", buf, ClientPreface) } } @@ -168,7 +172,7 @@ func newTestClientConn(t testing.TB, opts ...any) *testClientConn { tt := newTestTransport(t, opts...) const singleUse = false - _, err := tt.tr.newClientConn(nil, singleUse, nil) + _, err := tt.tr.TestNewClientConn(nil, singleUse, nil) if err != nil { t.Fatalf("newClientConn: %v", err) } @@ -303,8 +307,8 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { tc.roundtrips = append(tc.roundtrips, rt) go func() { defer close(rt.donec) - rt.resp, rt.respErr = tc.cc.roundTrip(req, func(cs *clientStream) { - rt.id.Store(cs.ID) + rt.resp, rt.respErr = tc.cc.TestRoundTrip(req, func(streamID uint32) { + rt.id.Store(streamID) }) }() synctest.Wait() @@ -348,17 +352,11 @@ func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { // inflowWindow returns the amount of inbound flow control available for a stream, // or for the connection if streamID is 0. func (tc *testClientConn) inflowWindow(streamID uint32) int32 { - tc.cc.mu.Lock() - defer tc.cc.mu.Unlock() - if streamID == 0 { - return tc.cc.inflow.avail + tc.cc.inflow.unsent + w, err := tc.cc.TestInflowWindow(streamID) + if err != nil { + tc.t.Error(err) } - cs := tc.cc.streams[streamID] - if cs == nil { - tc.t.Errorf("no stream with id %v", streamID) - return -1 - } - return cs.inflow.avail + cs.inflow.unsent + return w } // testRoundTrip manages a RoundTrip in progress. @@ -508,10 +506,7 @@ func newTestTransport(t testing.TB, opts ...any) *testTransport { for _, o := range opts { switch o := o.(type) { case func(*http.Transport): - if tr.t1 == nil { - tr.t1 = &http.Transport{} - } - o(tr.t1) + o(tr.TestTransport()) case func(*Transport): o(tr) case *Transport: @@ -520,12 +515,10 @@ func newTestTransport(t testing.TB, opts ...any) *testTransport { } tt.tr = tr - tr.transportTestHooks = &transportTestHooks{ - newclientconn: func(cc *ClientConn) { - tc := newTestClientConnFromClientConn(t, cc) - tt.ccs = append(tt.ccs, tc) - }, - } + tr.TestSetNewClientConnHook(func(cc *ClientConn) { + tc := newTestClientConnFromClientConn(t, tr, cc) + tt.ccs = append(tt.ccs, tc) + }) t.Cleanup(func() { synctest.Wait() diff --git a/http2/config_test.go b/http2/config_test.go index 01395725..5da266aa 100644 --- a/http2/config_test.go +++ b/http2/config_test.go @@ -2,12 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "net/http" "testing" "time" + + . "golang.org/x/net/http2" ) func TestConfigServerSettings(t *testing.T) { synctestTest(t, testConfigServerSettings) } diff --git a/http2/connframes_test.go b/http2/connframes_test.go index 4508a580..d4f09303 100644 --- a/http2/connframes_test.go +++ b/http2/connframes_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "bytes" @@ -13,6 +13,7 @@ import ( "slices" "testing" + . "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) @@ -257,7 +258,7 @@ func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) { tf.t.Helper() fr := readFrame[*RSTStreamFrame](tf.t, tf) if fr.StreamID != streamID || fr.ErrCode != code { - tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(fr), streamID, code) + tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", SummarizeFrame(fr), streamID, code) } } @@ -291,7 +292,7 @@ func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) { tf.t.Helper() fr := readFrame[*GoAwayFrame](tf.t, tf) if fr.LastStreamID != maxStreamID || fr.ErrCode != code { - tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(fr), maxStreamID, code) + tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", SummarizeFrame(fr), maxStreamID, code) } } diff --git a/http2/export_test.go b/http2/export_test.go new file mode 100644 index 00000000..4a35e765 --- /dev/null +++ b/http2/export_test.go @@ -0,0 +1,237 @@ +package http2 + +import ( + "context" + "fmt" + "net" + "net/http" + "net/textproto" + "sync" + "testing" + "time" + + "golang.org/x/net/http2/hpack" + "golang.org/x/net/internal/httpcommon" +) + +const ( + DefaultMaxReadFrameSize = defaultMaxReadFrameSize + DefaultMaxStreams = defaultMaxStreams + InflowMinRefresh = inflowMinRefresh + InitialHeaderTableSize = initialHeaderTableSize + InitialMaxConcurrentStreams = initialMaxConcurrentStreams + InitialWindowSize = initialWindowSize + MaxFrameSize = maxFrameSize + MaxQueuedControlFrames = maxQueuedControlFrames + MinMaxFrameSize = minMaxFrameSize +) + +type ( + ServerConn = serverConn + Stream = stream + StreamState = streamState + + PseudoHeaderError = pseudoHeaderError + HeaderFieldNameError = headerFieldNameError + HeaderFieldValueError = headerFieldValueError +) + +const ( + StateIdle = stateIdle + StateOpen = stateOpen + StateHalfClosedLocal = stateHalfClosedLocal + StateHalfClosedRemote = stateHalfClosedRemote + StateClosed = stateClosed +) + +var ( + ErrClientConnForceClosed = errClientConnForceClosed + ErrClientConnNotEstablished = errClientConnNotEstablished + ErrClientConnUnusable = errClientConnUnusable + ErrExtendedConnectNotSupported = errExtendedConnectNotSupported + ErrReqBodyTooLong = errReqBodyTooLong + ErrRequestHeaderListSize = errRequestHeaderListSize + ErrResponseHeaderListSize = errResponseHeaderListSize +) + +func (s *Server) TestServeConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) { + s.serveConn(c, opts, newf) +} + +func (sc *serverConn) TestFlowControlConsumed() (consumed int32) { + conf := configFromServer(sc.hs, sc.srv) + donec := make(chan struct{}) + sc.sendServeMsg(func(sc *serverConn) { + defer close(donec) + initial := conf.MaxUploadBufferPerConnection + avail := sc.inflow.avail + sc.inflow.unsent + consumed = initial - avail + }) + <-donec + return consumed +} + +func (sc *serverConn) TestStreamExists(id uint32) bool { + ch := make(chan bool, 1) + sc.serveMsgCh <- func(int) { + ch <- (sc.streams[id] != nil) + } + return <-ch +} + +func (sc *serverConn) TestStreamState(id uint32) streamState { + ch := make(chan streamState, 1) + sc.serveMsgCh <- func(int) { + state, _ := sc.state(id) + ch <- state + } + return <-ch +} + +func (sc *serverConn) StartGracefulShutdown() { sc.startGracefulShutdown() } + +func (sc *serverConn) TestHPACKEncoder() *hpack.Encoder { + return sc.hpackEncoder +} + +func (sc *serverConn) TestFramerMaxHeaderStringLen() int { + return sc.framer.maxHeaderStringLen() +} + +func (t *Transport) DialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + return t.dialClientConn(ctx, addr, singleUse) +} + +func (t *Transport) TestNewClientConn(c net.Conn, singleUse bool, internalStateHook func()) (*ClientConn, error) { + return t.newClientConn(c, singleUse, internalStateHook) +} + +func (t *Transport) TestSetNewClientConnHook(f func(*ClientConn)) { + t.transportTestHooks = &transportTestHooks{ + newclientconn: f, + } +} + +func (t *Transport) TestTransport() *http.Transport { + if t.t1 == nil { + t.t1 = &http.Transport{} + } + return t.t1 +} + +func (cc *ClientConn) TestNetConn() net.Conn { return cc.tconn } +func (cc *ClientConn) TestSetNetConn(c net.Conn) { cc.tconn = c } +func (cc *ClientConn) TestRoundTrip(req *http.Request, f func(stremaID uint32)) (*http.Response, error) { + return cc.roundTrip(req, func(cs *clientStream) { + f(cs.ID) + }) +} + +func (cc *ClientConn) TestHPACKEncoder() *hpack.Encoder { + return cc.henc +} + +func (cc *ClientConn) TestPeerMaxHeaderTableSize() uint32 { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.peerMaxHeaderTableSize +} + +func (cc *ClientConn) TestInflowWindow(streamID uint32) (int32, error) { + cc.mu.Lock() + defer cc.mu.Unlock() + if streamID == 0 { + return cc.inflow.avail + cc.inflow.unsent, nil + } + cs := cc.streams[streamID] + if cs == nil { + return -1, fmt.Errorf("no stream with id %v", streamID) + } + return cs.inflow.avail + cs.inflow.unsent, nil +} + +func (fr *Framer) TestSetDebugReadLoggerf(f func(string, ...any)) { + fr.logReads = true + fr.debugReadLoggerf = f +} + +func (fr *Framer) TestSetDebugWriteLoggerf(f func(string, ...any)) { + fr.logWrites = true + fr.debugWriteLoggerf = f +} + +func SummarizeFrame(f Frame) string { + return summarizeFrame(f) +} + +func SetTestHookGetServerConn(t testing.TB, f func(*serverConn)) { + SetForTest(t, &testHookGetServerConn, f) +} + +func init() { + testHookOnPanicMu = new(sync.Mutex) +} + +func SetTestHookOnPanic(t testing.TB, f func(sc *serverConn, panicVal interface{}) (rePanic bool)) { + testHookOnPanicMu.Lock() + defer testHookOnPanicMu.Unlock() + old := testHookOnPanic + testHookOnPanic = f + t.Cleanup(func() { + testHookOnPanicMu.Lock() + defer testHookOnPanicMu.Unlock() + testHookOnPanic = old + }) +} + +func SetTestHookGot1xx(t testing.TB, f func(int, textproto.MIMEHeader) error) { + SetForTest(t, &got1xxFuncForTests, f) +} + +func SetDisableExtendedConnectProtocol(t testing.TB, v bool) { + SetForTest(t, &disableExtendedConnectProtocol, v) +} + +func LogFrameReads() bool { return logFrameReads } +func LogFrameWrites() bool { return logFrameWrites } + +const GoAwayTimeout = 25 * time.Millisecond + +func init() { + goAwayTimeout = GoAwayTimeout +} + +func EncodeHeaderRaw(t testing.TB, headers ...string) []byte { + return encodeHeaderRaw(t, headers...) +} + +func NewPriorityWriteSchedulerRFC7540(cfg *PriorityWriteSchedulerConfig) WriteScheduler { + return newPriorityWriteSchedulerRFC7540(cfg) +} + +func NewPriorityWriteSchedulerRFC9218() WriteScheduler { + return newPriorityWriteSchedulerRFC9218() +} + +func NewRoundRobinWriteScheduler() WriteScheduler { + return newRoundRobinWriteScheduler() +} + +func DisableGoroutineTracking(t testing.TB) { + disableDebugGoroutines.Store(true) + t.Cleanup(func() { + disableDebugGoroutines.Store(false) + }) +} + +func InvalidHTTP1LookingFrameHeader() FrameHeader { + return invalidHTTP1LookingFrameHeader() +} + +func NewNoDialClientConnPool() ClientConnPool { + return noDialClientConnPool{new(clientConnPool)} +} + +func EncodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { + return encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, headerf) +} diff --git a/http2/frame_test.go b/http2/frame_test.go index 287a65e3..488b2eb3 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -1139,8 +1139,7 @@ func TestMetaFrameHeader(t *testing.T) { 0: { name: "single_headers", w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/") + all := encodeHeaderRaw(t, ":method", "GET", ":path", "/") write(f, all) }, want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"), @@ -1148,8 +1147,7 @@ func TestMetaFrameHeader(t *testing.T) { 1: { name: "with_continuation", w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") + all := encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") write(f, all[:1], all[1:]) }, want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"), @@ -1157,8 +1155,7 @@ func TestMetaFrameHeader(t *testing.T) { 2: { name: "with_two_continuation", w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") + all := encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar") write(f, all[:2], all[2:4], all[4:]) }, want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"), @@ -1166,8 +1163,7 @@ func TestMetaFrameHeader(t *testing.T) { 3: { name: "big_string_okay", w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) + all := encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) write(f, all[:2], all[2:]) }, want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString), @@ -1175,8 +1171,7 @@ func TestMetaFrameHeader(t *testing.T) { 4: { name: "big_string_error", w: func(f *Framer) { - var he hpackEncoder - all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) + all := encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString) write(f, all[:2], all[2:]) }, maxHeaderListSize: (1 << 10) / 2, @@ -1185,12 +1180,11 @@ func TestMetaFrameHeader(t *testing.T) { 5: { name: "max_header_list_truncated", w: func(f *Framer) { - var he hpackEncoder var pairs = []string{":method", "GET", ":path", "/"} for i := 0; i < 100; i++ { pairs = append(pairs, "foo", "bar") } - all := he.encodeHeaderRaw(t, pairs...) + all := encodeHeaderRaw(t, pairs...) write(f, all[:2], all[2:]) }, maxHeaderListSize: (1 << 10) / 2, @@ -1412,9 +1406,18 @@ func readAndVerifyDataFrame(data string, length byte, fr *Framer, buf *bytes.Buf return df } -func encodeHeaderRaw(t *testing.T, pairs ...string) []byte { - var he hpackEncoder - return he.encodeHeaderRaw(t, pairs...) +func encodeHeaderRaw(t testing.TB, headers ...string) []byte { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + for len(headers) > 0 { + k, v := headers[0], headers[1] + err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } + headers = headers[2:] + } + return buf.Bytes() } func TestSettingsDuplicates(t *testing.T) { diff --git a/http2/http2_test.go b/http2/http2_test.go index 76948b7b..89003fd6 100644 --- a/http2/http2_test.go +++ b/http2/http2_test.go @@ -5,7 +5,6 @@ package http2 import ( - "bytes" "flag" "fmt" "net/http" @@ -15,8 +14,6 @@ import ( "strings" "testing" "time" - - "golang.org/x/net/http2/hpack" ) var knownFailing = flag.Bool("known_failing", false, "Run known-failing tests.") @@ -48,44 +45,6 @@ func TestSettingString(t *testing.T) { } } -type twriter struct { - t testing.TB - st *serverTester // optional -} - -func (w twriter) Write(p []byte) (n int, err error) { - if w.st != nil { - ps := string(p) - for _, phrase := range w.st.logFilter { - if strings.Contains(ps, phrase) { - return len(p), nil // no logging - } - } - } - w.t.Logf("%s", p) - return len(p), nil -} - -// like encodeHeader, but don't add implicit pseudo headers. -func encodeHeaderNoImplicit(t testing.TB, headers ...string) []byte { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - for len(headers) > 0 { - k, v := headers[0], headers[1] - headers = headers[2:] - if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil { - t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) - } - } - return buf.Bytes() -} - -func cleanDate(res *http.Response) { - if d := res.Header["Date"]; len(d) == 1 { - d[0] = "XXX" - } -} - func TestSorterPoolAllocs(t *testing.T) { ss := []string{"a", "b", "c"} h := http.Header{ @@ -254,8 +213,8 @@ func TestNoUnicodeStrings(t *testing.T) { } } -// setForTest sets *p = v, and restores its original value in t.Cleanup. -func setForTest[T any](t testing.TB, p *T, v T) { +// SetForTest sets *p = v, and restores its original value in t.Cleanup. +func SetForTest[T any](t testing.TB, p *T, v T) { orig := *p t.Cleanup(func() { *p = orig @@ -263,18 +222,10 @@ func setForTest[T any](t testing.TB, p *T, v T) { *p = v } -// must returns v if err is nil, or panics otherwise. -func must[T any](v T, err error) T { +// Must returns v if err is nil, or panics otherwise. +func Must[T any](v T, err error) T { if err != nil { panic(err) } return v } - -// synctestSubtest starts a subtest and runs f in a synctest bubble within it. -func synctestSubtest(t *testing.T, name string, f func(testing.TB)) { - t.Helper() - t.Run(name, func(t *testing.T) { - synctestTest(t, f) - }) -} diff --git a/http2/netconn_test.go b/http2/netconn_test.go index 8665e454..6eba9e2d 100644 --- a/http2/netconn_test.go +++ b/http2/netconn_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "bytes" diff --git a/http2/server_internal_test.go b/http2/server_internal_test.go new file mode 100644 index 00000000..763976e1 --- /dev/null +++ b/http2/server_internal_test.go @@ -0,0 +1,84 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "errors" + "fmt" + "net/http" + "strings" + "testing" +) + +func TestCheckValidHTTP2Request(t *testing.T) { + tests := []struct { + h http.Header + want error + }{ + { + h: http.Header{"Te": {"trailers"}}, + want: nil, + }, + { + h: http.Header{"Te": {"trailers", "bogus"}}, + want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`), + }, + { + h: http.Header{"Foo": {""}}, + want: nil, + }, + { + h: http.Header{"Connection": {""}}, + want: errors.New(`request header "Connection" is not valid in HTTP/2`), + }, + { + h: http.Header{"Proxy-Connection": {""}}, + want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`), + }, + { + h: http.Header{"Keep-Alive": {""}}, + want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`), + }, + { + h: http.Header{"Upgrade": {""}}, + want: errors.New(`request header "Upgrade" is not valid in HTTP/2`), + }, + } + for i, tt := range tests { + got := checkValidHTTP2RequestHeaders(tt.h) + if !equalError(got, tt.want) { + t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want) + } + } +} + +// TestCanonicalHeaderCacheGrowth verifies that the canonical header cache +// size is capped to a reasonable level. +func TestCanonicalHeaderCacheGrowth(t *testing.T) { + for _, size := range []int{1, (1 << 20) - 10} { + base := strings.Repeat("X", size) + sc := &serverConn{ + serveG: newGoroutineLock(), + } + count := 0 + added := 0 + for added < 10*maxCachedCanonicalHeadersKeysSize { + h := fmt.Sprintf("%v-%v", base, count) + c := sc.canonicalHeader(h) + if len(h) != len(c) { + t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c) + } + count++ + added += len(h) + } + total := 0 + for k, v := range sc.canonHeader { + total += len(k) + len(v) + 100 + } + if total > maxCachedCanonicalHeadersKeysSize { + t.Errorf("after adding %v ~%v-byte headers, canonHeader cache is ~%v bytes, want <%v", count, size, total, maxCachedCanonicalHeadersKeysSize) + } + } +} diff --git a/http2/server_push_test.go b/http2/server_push_test.go index d27d2d7d..105589f6 100644 --- a/http2/server_push_test.go +++ b/http2/server_push_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "errors" @@ -14,6 +14,8 @@ import ( "testing" "testing/synctest" "time" + + . "golang.org/x/net/http2" ) func TestServer_Push_Success(t *testing.T) { synctestTest(t, testServer_Push_Success) } @@ -463,16 +465,16 @@ func testServer_Push_StateTransitions(t testing.TB) { defer st.Close() st.greet() - if st.stream(2) != nil { + if st.streamExists(2) { t.Fatal("stream 2 should be empty") } - if got, want := st.streamState(2), stateIdle; got != want { + if got, want := st.streamState(2), StateIdle; got != want { t.Fatalf("streamState(2)=%v, want %v", got, want) } getSlash(st) // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote. _ = readFrame[*PushPromiseFrame](t, st) - if got, want := st.streamState(2), stateHalfClosedRemote; got != want { + if got, want := st.streamState(2), StateHalfClosedRemote; got != want { t.Fatalf("streamState(2)=%v, want %v", got, want) } // We stall the HTTP handler for "/pushed" until the above check. If we don't @@ -484,7 +486,7 @@ func testServer_Push_StateTransitions(t testing.TB) { streamID: 2, endStream: false, }) - if got, want := st.streamState(2), stateClosed; got != want { + if got, want := st.streamState(2), StateClosed; got != want { t.Fatalf("streamState(2)=%v, want %v", got, want) } close(finishedPush) diff --git a/http2/server_test.go b/http2/server_test.go index 7da28af5..73530858 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "bytes" @@ -30,6 +30,7 @@ import ( "testing/synctest" "time" + . "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) @@ -74,7 +75,7 @@ type serverTester struct { serverLogBuf safeBuffer // logger for httptest.Server logFilter []string // substrings to filter out scMu sync.Mutex // guards sc - sc *serverConn + sc *ServerConn testConnFramer callsMu sync.Mutex @@ -94,15 +95,22 @@ type serverTester struct { hpackEnc *hpack.Encoder } -func init() { - testHookOnPanicMu = new(sync.Mutex) - goAwayTimeout = 25 * time.Millisecond +type twriter struct { + t testing.TB + st *serverTester // optional } -func resetHooks() { - testHookOnPanicMu.Lock() - testHookOnPanic = nil - testHookOnPanicMu.Unlock() +func (w twriter) Write(p []byte) (n int, err error) { + if w.st != nil { + ps := string(p) + for _, phrase := range w.st.logFilter { + if strings.Contains(ps, phrase) { + return len(p), nil // no logging + } + } + } + w.t.Logf("%s", p) + return len(p), nil } func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *httptest.Server { @@ -196,18 +204,18 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} t.Cleanup(func() { st.Close() - time.Sleep(goAwayTimeout) // give server time to shut down + time.Sleep(GoAwayTimeout) // give server time to shut down }) - connc := make(chan *serverConn) + connc := make(chan *ServerConn) go func() { - h2server.serveConn(&netConnWithConnectionState{ + h2server.TestServeConn(&netConnWithConnectionState{ Conn: srv, state: tlsState, }, &ServeConnOpts{ Handler: handler, BaseConfig: h1server, - }, func(sc *serverConn) { + }, func(sc *ServerConn) { connc <- sc }) }() @@ -217,7 +225,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} st.testConnFramer = testConnFramer{ t: t, fr: NewFramer(st.cc, st.cc), - dec: hpack.NewDecoder(initialHeaderTableSize, nil), + dec: hpack.NewDecoder(InitialHeaderTableSize, nil), } synctest.Wait() return st @@ -281,8 +289,6 @@ func (call *serverHandlerCall) exit() { // net.Conn and synthetic time. This function is still around because some benchmarks // rely on it; new tests should use newServerTester. func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { - resetHooks() - ts := httptest.NewUnstartedServer(handler) t.Cleanup(ts.Close) @@ -337,11 +343,11 @@ func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts .. if VerboseLogs { t.Logf("Running test server at: %s", ts.URL) } - testHookGetServerConn = func(v *serverConn) { + SetTestHookGetServerConn(t, func(v *ServerConn) { st.scMu.Lock() defer st.scMu.Unlock() st.sc = v - } + }) log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st})) cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) if err != nil { @@ -351,26 +357,24 @@ func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts .. st.testConnFramer = testConnFramer{ t: t, fr: NewFramer(st.cc, st.cc), - dec: hpack.NewDecoder(initialHeaderTableSize, nil), + dec: hpack.NewDecoder(InitialHeaderTableSize, nil), } if framerReuseFrames { st.fr.SetReuseFrames() } - if !logFrameReads && !logFrameWrites { - st.fr.debugReadLoggerf = func(m string, v ...interface{}) { + if !LogFrameReads() && !LogFrameWrites() { + st.fr.TestSetDebugReadLoggerf(func(m string, v ...interface{}) { m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" st.frameReadLogMu.Lock() fmt.Fprintf(&st.frameReadLogBuf, m, v...) st.frameReadLogMu.Unlock() - } - st.fr.debugWriteLoggerf = func(m string, v ...interface{}) { + }) + st.fr.TestSetDebugWriteLoggerf(func(m string, v ...interface{}) { m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" st.frameWriteLogMu.Lock() fmt.Fprintf(&st.frameWriteLogBuf, m, v...) st.frameWriteLogMu.Unlock() - } - st.fr.logReads = true - st.fr.logWrites = true + }) } return st } @@ -390,12 +394,6 @@ func (st *serverTester) authority() string { return "dummy.tld" } -func (st *serverTester) closeConn() { - st.scMu.Lock() - defer st.scMu.Unlock() - st.sc.conn.Close() -} - func (st *serverTester) addLogFilter(phrase string) { st.logFilter = append(st.logFilter, phrase) } @@ -413,30 +411,12 @@ func (st *serverTester) nextHandlerCall() *serverHandlerCall { return call } -func (st *serverTester) stream(id uint32) *stream { - ch := make(chan *stream, 1) - st.sc.serveMsgCh <- func(int) { - ch <- st.sc.streams[id] - } - return <-ch +func (st *serverTester) streamExists(id uint32) bool { + return st.sc.TestStreamExists(id) } -func (st *serverTester) streamState(id uint32) streamState { - ch := make(chan streamState, 1) - st.sc.serveMsgCh <- func(int) { - state, _ := st.sc.state(id) - ch <- state - } - return <-ch -} - -// loopNum reports how many times this conn's select loop has gone around. -func (st *serverTester) loopNum() int { - lastc := make(chan int, 1) - st.sc.serveMsgCh <- func(loopNum int) { - lastc <- loopNum - } - return <-lastc +func (st *serverTester) streamState(id uint32) StreamState { + return st.sc.TestStreamState(id) } func (st *serverTester) Close() { @@ -502,11 +482,6 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error if f.FrameHeader.StreamID != 0 { st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID) } - conf := configFromServer(st.sc.hs, st.sc.srv) - incr := uint32(conf.MaxUploadBufferPerConnection - initialWindowSize) - if f.Increment != incr { - st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr) - } gotWindowUpdate = true default: @@ -523,12 +498,12 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error } func (st *serverTester) writePreface() { - n, err := st.cc.Write(clientPreface) + n, err := st.cc.Write([]byte(ClientPreface)) if err != nil { st.t.Fatalf("Error writing client preface: %v", err) } - if n != len(clientPreface) { - st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface)) + if n != len(ClientPreface) { + st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(ClientPreface)) } } @@ -631,18 +606,9 @@ func (st *serverTester) bodylessReq1(headers ...string) { } func (st *serverTester) wantConnFlowControlConsumed(consumed int32) { - conf := configFromServer(st.sc.hs, st.sc.srv) - donec := make(chan struct{}) - st.sc.sendServeMsg(func(sc *serverConn) { - defer close(donec) - var avail int32 - initial := conf.MaxUploadBufferPerConnection - avail = sc.inflow.avail + sc.inflow.unsent - if got, want := initial-avail, consumed; got != want { - st.t.Errorf("connection flow control consumed: %v, want %v", got, want) - } - }) - <-donec + if got, want := st.sc.TestFlowControlConsumed(), consumed; got != want { + st.t.Errorf("connection flow control consumed: %v, want %v", got, want) + } } func TestServer(t *testing.T) { synctestTest(t, testServer) } @@ -1269,7 +1235,7 @@ func TestServer_MaxQueuedControlFrames(t *testing.T) { } func testServer_MaxQueuedControlFrames(t testing.TB) { // Goroutine debugging makes this test very slow. - disableGoroutineTracking(t) + DisableGoroutineTracking(t) st := newServerTester(t, nil) st.greet() @@ -1280,7 +1246,7 @@ func testServer_MaxQueuedControlFrames(t testing.TB) { // Send maxQueuedControlFrames pings, plus a few extra // to account for ones that enter the server's write buffer. const extraPings = 2 - for i := 0; i < maxQueuedControlFrames+extraPings; i++ { + for i := 0; i < MaxQueuedControlFrames+extraPings; i++ { pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} st.fr.WritePing(false, pingData) } @@ -1290,7 +1256,7 @@ func testServer_MaxQueuedControlFrames(t testing.TB) { // It should have closed the connection after exceeding the control frame limit. st.cc.(*synctestNetConn).SetReadBufferSize(math.MaxInt) - st.advance(goAwayTimeout) + st.advance(GoAwayTimeout) // Some frames may have persisted in the server's buffers. for i := 0; i < 10; i++ { if st.readFrame() == nil { @@ -1312,10 +1278,10 @@ func testServer_RejectsLargeFrames(t testing.TB) { // Write too large of a frame (too large by one byte) // We ignore the return value because it's expected that the server // will only read the first 9 bytes (the headre) and then disconnect. - st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1)) + st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, DefaultMaxReadFrameSize+1)) st.wantGoAway(0, ErrCodeFrameSize) - st.advance(goAwayTimeout) + st.advance(GoAwayTimeout) st.wantClosed() } @@ -1600,27 +1566,27 @@ func testServer_StateTransitions(t testing.TB) { leaveHandler := make(chan bool) st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { inHandler <- true - if st.stream(1) == nil { - t.Errorf("nil stream 1 in handler") + if !st.streamExists(1) { + t.Errorf("stream 1 does not exist in handler") } - if got, want := st.streamState(1), stateOpen; got != want { + if got, want := st.streamState(1), StateOpen; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } writeData <- true if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF { t.Errorf("body read = %d, %v; want 0, EOF", n, err) } - if got, want := st.streamState(1), stateHalfClosedRemote; got != want { + if got, want := st.streamState(1), StateHalfClosedRemote; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } <-leaveHandler }) st.greet() - if st.stream(1) != nil { + if st.streamExists(1) { t.Fatal("stream 1 should be empty") } - if got := st.streamState(1); got != stateIdle { + if got := st.streamState(1); got != StateIdle { t.Fatalf("stream 1 should be idle; got %v", got) } @@ -1640,10 +1606,10 @@ func testServer_StateTransitions(t testing.TB) { endStream: true, }) - if got, want := st.streamState(1), stateClosed; got != want { + if got, want := st.streamState(1), StateClosed; got != want { t.Errorf("at end, state is %v; want %v", got, want) } - if st.stream(1) != nil { + if st.streamExists(1) { t.Fatal("at end, stream 1 should be gone") } } @@ -1704,7 +1670,7 @@ func testServer_Rejects_HeadersEnd_Then_Continuation(t testing.TB) { streamID: 1, endStream: true, }) - if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil { + if err := st.fr.WriteContinuation(1, true, EncodeHeaderRaw(t, "foo", "bar")); err != nil { t.Fatal(err) } st.wantGoAway(1, ErrCodeProtocol) @@ -1722,7 +1688,7 @@ func testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t testing.TB) EndStream: true, EndHeaders: false, }) - if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil { + if err := st.fr.WriteContinuation(3, true, EncodeHeaderRaw(t, "foo", "bar")); err != nil { t.Fatal(err) } st.wantGoAway(0, ErrCodeProtocol) @@ -1782,7 +1748,7 @@ func TestServer_Rejects_PriorityUpdateUnparsable(t *testing.T) { } func testServer_Rejects_PriorityUnparsable(t testing.TB) { st := newServerTester(t, nil, func(s *Server) { - s.NewWriteScheduler = newPriorityWriteSchedulerRFC9218 + s.NewWriteScheduler = NewPriorityWriteSchedulerRFC9218 }) defer st.Close() st.greet() @@ -2482,12 +2448,12 @@ func testServer_Rejects_Too_Many_Streams(t testing.TB) { EndHeaders: true, }) } - for i := 0; i < defaultMaxStreams; i++ { + for i := 0; i < DefaultMaxStreams; i++ { sendReq(streamID()) <-inHandler } defer func() { - for i := 0; i < defaultMaxStreams; i++ { + for i := 0; i < DefaultMaxStreams; i++ { leaveHandler <- true } }() @@ -2611,14 +2577,12 @@ func testServer_NoCrash_HandlerClose_Then_ClientClose(t testing.TB) { panicVal interface{} ) - testHookOnPanicMu.Lock() - testHookOnPanic = func(sc *serverConn, pv interface{}) bool { + SetTestHookOnPanic(t, func(sc *ServerConn, pv interface{}) bool { panMu.Lock() panicVal = pv panMu.Unlock() return true - } - testHookOnPanicMu.Unlock() + }) // Now force the serve loop to end, via closing the connection. st.cc.Close() @@ -2738,7 +2702,7 @@ func TestServer_MaxDecoderHeaderTableSize(t *testing.T) { synctestTest(t, testServer_MaxDecoderHeaderTableSize) } func testServer_MaxDecoderHeaderTableSize(t testing.TB) { - wantHeaderTableSize := uint32(initialHeaderTableSize * 2) + wantHeaderTableSize := uint32(InitialHeaderTableSize * 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) { s.MaxDecoderHeaderTableSize = wantHeaderTableSize }) @@ -2764,7 +2728,7 @@ func TestServer_MaxEncoderHeaderTableSize(t *testing.T) { synctestTest(t, testServer_MaxEncoderHeaderTableSize) } func testServer_MaxEncoderHeaderTableSize(t testing.TB) { - wantHeaderTableSize := uint32(initialHeaderTableSize / 2) + wantHeaderTableSize := uint32(InitialHeaderTableSize / 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) { s.MaxEncoderHeaderTableSize = wantHeaderTableSize }) @@ -2772,7 +2736,7 @@ func testServer_MaxEncoderHeaderTableSize(t testing.TB) { st.greet() - if got, want := st.sc.hpackEncoder.MaxDynamicTableSize(), wantHeaderTableSize; got != want { + if got, want := st.sc.TestHPACKEncoder().MaxDynamicTableSize(), wantHeaderTableSize; got != want { t.Errorf("server encoder is using a header table size of %d, want %d", got, want) } } @@ -2784,15 +2748,15 @@ func testServerDoS_MaxHeaderListSize(t testing.TB) { defer st.Close() // shake hands - frameSize := defaultMaxReadFrameSize + frameSize := DefaultMaxReadFrameSize var advHeaderListSize *uint32 st.greetAndCheckSettings(func(s Setting) error { switch s.ID { case SettingMaxFrameSize: - if s.Val < minMaxFrameSize { - frameSize = minMaxFrameSize - } else if s.Val > maxFrameSize { - frameSize = maxFrameSize + if s.Val < MinMaxFrameSize { + frameSize = MinMaxFrameSize + } else if s.Val > MaxFrameSize { + frameSize = MaxFrameSize } else { frameSize = int(s.Val) } @@ -2883,7 +2847,7 @@ func testCompressionErrorOnWrite(t testing.TB) { defer st.Close() st.greet() - maxAllowed := st.sc.framer.maxHeaderStringLen() + maxAllowed := st.sc.TestFramerMaxHeaderStringLen() // Crank this up, now that we have a conn connected with the // hpack.Decoder's max string length set has been initialized @@ -3136,7 +3100,7 @@ func testServerDoesntWriteInvalidHeaders(t testing.TB) { } func BenchmarkServerGets(b *testing.B) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world" @@ -3167,7 +3131,7 @@ func BenchmarkServerGets(b *testing.B) { } func BenchmarkServerPosts(b *testing.B) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world" @@ -3218,7 +3182,7 @@ func BenchmarkServerToClientStreamReuseFrames(b *testing.B) { } func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() const msgLen = 1 // default window size @@ -3360,7 +3324,7 @@ func testServeConnNilOpts(t testing.TB) { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { gotRequest = r.URL.Path }) - setForTest(t, &http.DefaultServeMux, &mux) + SetForTest(t, &http.DefaultServeMux, &mux) srvConn, cliConn := net.Pipe() defer srvConn.Close() @@ -3544,15 +3508,8 @@ func testServerContentLengthCanBeDisabled(t testing.TB) { }) } -func disableGoroutineTracking(t testing.TB) { - disableDebugGoroutines.Store(true) - t.Cleanup(func() { - disableDebugGoroutines.Store(false) - }) -} - func BenchmarkServer_GetRequest(b *testing.B) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world." st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { @@ -3584,7 +3541,7 @@ func BenchmarkServer_GetRequest(b *testing.B) { } func BenchmarkServer_PostRequest(b *testing.B) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world." st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { @@ -3644,7 +3601,7 @@ func testServerHandleCustomConn(t testing.TB) { return } if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() { - t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f)) + t.Errorf("Got %v; want non-ACK SettingsFrame", SummarizeFrame(f)) return } f, err = fr.ReadFrame() @@ -3653,7 +3610,7 @@ func testServerHandleCustomConn(t testing.TB) { return } if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() { - t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f)) + t.Errorf("Got %v; want ACK SettingsFrame", SummarizeFrame(f)) return } var henc hpackEncoder @@ -3667,6 +3624,7 @@ func testServerHandleCustomConn(t testing.TB) { <-handlerDone }() const testString = "my custom ConnectionState" + const cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F // defined in ciphers.go fakeConnState := tls.ConnectionState{ ServerName: testString, Version: tls.VersionTLS12, @@ -3734,49 +3692,6 @@ func (he *hpackEncoder) encodeHeaderRaw(t testing.TB, headers ...string) []byte return he.buf.Bytes() } -func TestCheckValidHTTP2Request(t *testing.T) { synctestTest(t, testCheckValidHTTP2Request) } -func testCheckValidHTTP2Request(t testing.TB) { - tests := []struct { - h http.Header - want error - }{ - { - h: http.Header{"Te": {"trailers"}}, - want: nil, - }, - { - h: http.Header{"Te": {"trailers", "bogus"}}, - want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`), - }, - { - h: http.Header{"Foo": {""}}, - want: nil, - }, - { - h: http.Header{"Connection": {""}}, - want: errors.New(`request header "Connection" is not valid in HTTP/2`), - }, - { - h: http.Header{"Proxy-Connection": {""}}, - want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`), - }, - { - h: http.Header{"Keep-Alive": {""}}, - want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`), - }, - { - h: http.Header{"Upgrade": {""}}, - want: errors.New(`request header "Upgrade" is not valid in HTTP/2`), - }, - } - for i, tt := range tests { - got := checkValidHTTP2RequestHeaders(tt.h) - if !equalError(got, tt.want) { - t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want) - } - } -} - // golang.org/issue/14030 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) { synctestTest(t, testExpect100ContinueAfterHandlerWrites) @@ -3923,7 +3838,7 @@ func testServerReturnsStreamAndConnFlowControlOnBodyClose(t testing.TB) { streamID: 1, endStream: false, }) - const size = inflowMinRefresh // enough to trigger flow control return + const size = InflowMinRefresh // enough to trigger flow control return st.writeData(1, false, make([]byte, size)) st.wantWindowUpdate(0, size) // conn-level flow control is returned unblockHandler <- struct{}{} @@ -4116,7 +4031,7 @@ func testServerHandlerConnectionClose(t testing.TB) { case *GoAwayFrame: sawGoAway = true if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo { - t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f)) + t.Errorf("unexpected GOAWAY frame: %v", SummarizeFrame(f)) } // Create a stream and reset it. // The server should ignore the stream. @@ -4150,11 +4065,11 @@ func testServerHandlerConnectionClose(t testing.TB) { sawRes = true case *DataFrame: if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 { - t.Errorf("unexpected DATA frame: %v", summarizeFrame(f)) + t.Errorf("unexpected DATA frame: %v", SummarizeFrame(f)) } case *WindowUpdateFrame: if !sawGoAway { - t.Errorf("unexpected WINDOW_UPDATE frame: %v", summarizeFrame(f)) + t.Errorf("unexpected WINDOW_UPDATE frame: %v", SummarizeFrame(f)) return } if f.StreamID != 0 { @@ -4164,9 +4079,9 @@ func testServerHandlerConnectionClose(t testing.TB) { sawWindowUpdate = true unblockHandler <- true st.sync() - st.advance(goAwayTimeout) + st.advance(GoAwayTimeout) default: - t.Logf("unexpected frame: %v", summarizeFrame(f)) + t.Logf("unexpected frame: %v", SummarizeFrame(f)) } } if !sawGoAway { @@ -4190,17 +4105,17 @@ func testServer_Headers_HalfCloseRemote(t testing.TB) { writeHeaders := make(chan bool) leaveHandler := make(chan bool) st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - if st.stream(1) == nil { - t.Errorf("nil stream 1 in handler") + if !st.streamExists(1) { + t.Errorf("stream 1 does not exist in handler") } - if got, want := st.streamState(1), stateOpen; got != want { + if got, want := st.streamState(1), StateOpen; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } writeData <- true if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF { t.Errorf("body read = %d, %v; want 0, EOF", n, err) } - if got, want := st.streamState(1), stateHalfClosedRemote; got != want { + if got, want := st.streamState(1), StateHalfClosedRemote; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } writeHeaders <- true @@ -4453,7 +4368,7 @@ func testNoErrorLoggedOnPostAfterGOAWAY(t testing.TB) { endStream: true, }) - st.sc.startGracefulShutdown() + st.sc.StartGracefulShutdown() st.wantRSTStream(1, ErrCodeNo) st.wantGoAway(1, ErrCodeNo) @@ -4579,7 +4494,7 @@ func testProtocolErrorAfterGoAway(t testing.TB) { t.Fatal(err) } - st.advance(goAwayTimeout) + st.advance(GoAwayTimeout) st.wantGoAway(1, ErrCodeNo) st.wantClosed() } @@ -4637,36 +4552,6 @@ func TestServerInitialFlowControlWindow(t *testing.T) { } } -// TestCanonicalHeaderCacheGrowth verifies that the canonical header cache -// size is capped to a reasonable level. -func TestCanonicalHeaderCacheGrowth(t *testing.T) { synctestTest(t, testCanonicalHeaderCacheGrowth) } -func testCanonicalHeaderCacheGrowth(t testing.TB) { - for _, size := range []int{1, (1 << 20) - 10} { - base := strings.Repeat("X", size) - sc := &serverConn{ - serveG: newGoroutineLock(), - } - count := 0 - added := 0 - for added < 10*maxCachedCanonicalHeadersKeysSize { - h := fmt.Sprintf("%v-%v", base, count) - c := sc.canonicalHeader(h) - if len(h) != len(c) { - t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c) - } - count++ - added += len(h) - } - total := 0 - for k, v := range sc.canonHeader { - total += len(k) + len(v) + 100 - } - if total > maxCachedCanonicalHeadersKeysSize { - t.Errorf("after adding %v ~%v-byte headers, canonHeader cache is ~%v bytes, want <%v", count, size, total, maxCachedCanonicalHeadersKeysSize) - } - } -} - // TestServerWriteDoesNotRetainBufferAfterReturn checks for access to // the slice passed to ResponseWriter.Write after Write returns. // @@ -5160,12 +5045,12 @@ func testServerSettingNoRFC7540Priorities(t testing.TB) { }{ { ws: func() WriteScheduler { - return newPriorityWriteSchedulerRFC7540(nil) + return NewPriorityWriteSchedulerRFC7540(nil) }, wantNoRFC7540Setting: false, }, { - ws: newPriorityWriteSchedulerRFC9218, + ws: NewPriorityWriteSchedulerRFC9218, wantNoRFC7540Setting: true, }, { @@ -5173,7 +5058,7 @@ func testServerSettingNoRFC7540Priorities(t testing.TB) { wantNoRFC7540Setting: true, }, { - ws: newRoundRobinWriteScheduler, + ws: NewRoundRobinWriteScheduler, wantNoRFC7540Setting: true, }, } @@ -5228,7 +5113,7 @@ func testServerRFC7540PrioritySmallPayload(t testing.TB) { } }, func(s *Server) { s.NewWriteScheduler = func() WriteScheduler { - return newPriorityWriteSchedulerRFC7540(nil) + return NewPriorityWriteSchedulerRFC7540(nil) } }) if syncConn, ok := st.cc.(*synctestNetConn); ok { @@ -5295,7 +5180,7 @@ func testServerRFC9218PrioritySmallPayload(t testing.TB) { } } }, func(s *Server) { - s.NewWriteScheduler = newPriorityWriteSchedulerRFC9218 + s.NewWriteScheduler = NewPriorityWriteSchedulerRFC9218 }) if syncConn, ok := st.cc.(*synctestNetConn); ok { syncConn.SetReadBufferSize(1) @@ -5355,7 +5240,7 @@ func testServerRFC9218Priority(t testing.TB) { f.Flush() } }, func(s *Server) { - s.NewWriteScheduler = newPriorityWriteSchedulerRFC9218 + s.NewWriteScheduler = NewPriorityWriteSchedulerRFC9218 }) defer st.Close() if syncConn, ok := st.cc.(*synctestNetConn); ok { @@ -5363,8 +5248,9 @@ func testServerRFC9218Priority(t testing.TB) { } else { t.Fatal("Server connection is not synctestNetConn") } - st.sc.flow.add(1 << 30) st.greet() + st.writeWindowUpdate(0, 1<<30) + synctest.Wait() // Create 8 streams, where streams with larger ID has lower urgency value // (i.e. more urgent). @@ -5409,7 +5295,7 @@ func testServerRFC9218PriorityIgnoredWhenProxied(t testing.TB) { f.Flush() } }, func(s *Server) { - s.NewWriteScheduler = newPriorityWriteSchedulerRFC9218 + s.NewWriteScheduler = NewPriorityWriteSchedulerRFC9218 }) defer st.Close() if syncConn, ok := st.cc.(*synctestNetConn); ok { @@ -5417,8 +5303,9 @@ func testServerRFC9218PriorityIgnoredWhenProxied(t testing.TB) { } else { t.Fatal("Server connection is not synctestNetConn") } - st.sc.flow.add(1 << 30) st.greet() + st.writeWindowUpdate(0, 1<<30) + synctest.Wait() // Create 8 streams, where streams with larger ID has lower urgency value // (i.e. more urgent). These should be ignored since the requests are @@ -5457,7 +5344,7 @@ func testServerRFC9218PriorityAware(t testing.TB) { f.Flush() } }, func(s *Server) { - s.NewWriteScheduler = newPriorityWriteSchedulerRFC9218 + s.NewWriteScheduler = NewPriorityWriteSchedulerRFC9218 }) defer st.Close() if syncConn, ok := st.cc.(*synctestNetConn); ok { @@ -5465,8 +5352,9 @@ func testServerRFC9218PriorityAware(t testing.TB) { } else { t.Fatal("Server connection is not synctestNetConn") } - st.sc.flow.add(1 << 30) st.greet() + st.writeWindowUpdate(0, 1<<30) + synctest.Wait() // When there is no indication that the client is aware of RFC 9218 // priority, it should process streams in a round-robin manner. diff --git a/http2/synctest_test.go b/http2/synctest_test.go index 1f4402b8..fec7ad65 100644 --- a/http2/synctest_test.go +++ b/http2/synctest_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "testing" @@ -16,3 +16,11 @@ func synctestTest(t *testing.T, f func(t testing.TB)) { f(t) }) } + +// synctestSubtest starts a subtest and runs f in a synctest bubble within it. +func synctestSubtest(t *testing.T, name string, f func(testing.TB)) { + t.Helper() + t.Run(name, func(t *testing.T) { + synctestTest(t, f) + }) +} diff --git a/http2/transport.go b/http2/transport.go index 603387b7..2e9c2f6a 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -3229,10 +3229,6 @@ func (gz *gzipReader) Close() error { return gz.body.Close() } -type errorReader struct{ err error } - -func (r errorReader) Read(p []byte) (int, error) { return 0, r.err } - // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. func isConnectionCloseRequest(req *http.Request) bool { diff --git a/http2/transport_internal_test.go b/http2/transport_internal_test.go new file mode 100644 index 00000000..2f8532fd --- /dev/null +++ b/http2/transport_internal_test.go @@ -0,0 +1,293 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "fmt" + "io" + "io/fs" + "net/http" + "reflect" + "strings" + "testing" + "time" +) + +type panicReader struct{} + +func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } +func (panicReader) Close() error { panic("unexpected Close") } + +func TestActualContentLength(t *testing.T) { + tests := []struct { + req *http.Request + want int64 + }{ + // Verify we don't read from Body: + 0: { + req: &http.Request{Body: panicReader{}}, + want: -1, + }, + // nil Body means 0, regardless of ContentLength: + 1: { + req: &http.Request{Body: nil, ContentLength: 5}, + want: 0, + }, + // ContentLength is used if set. + 2: { + req: &http.Request{Body: panicReader{}, ContentLength: 5}, + want: 5, + }, + // http.NoBody means 0, not -1. + 3: { + req: &http.Request{Body: http.NoBody}, + want: 0, + }, + } + for i, tt := range tests { + got := actualContentLength(tt.req) + if got != tt.want { + t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) + } + } +} + +// Tests that gzipReader doesn't crash on a second Read call following +// the first Read call's gzip.NewReader returning an error. +func TestGzipReader_DoubleReadCrash(t *testing.T) { + gz := &gzipReader{ + body: io.NopCloser(strings.NewReader("0123456789")), + } + var buf [1]byte + n, err1 := gz.Read(buf[:]) + if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") { + t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1) + } + n, err2 := gz.Read(buf[:]) + if n != 0 || err2 != err1 { + t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1) + } +} + +func TestGzipReader_ReadAfterClose(t *testing.T) { + body := bytes.Buffer{} + w := gzip.NewWriter(&body) + w.Write([]byte("012345679")) + w.Close() + gz := &gzipReader{ + body: io.NopCloser(&body), + } + var buf [1]byte + n, err := gz.Read(buf[:]) + if n != 1 || err != nil { + t.Fatalf("first Read = %v, %v; want 1, nil", n, err) + } + if err := gz.Close(); err != nil { + t.Fatalf("gz Close error: %v", err) + } + n, err = gz.Read(buf[:]) + if n != 0 || err != fs.ErrClosed { + t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err) + } +} + +func TestTransportNewTLSConfig(t *testing.T) { + tests := [...]struct { + conf *tls.Config + host string + want *tls.Config + }{ + // Normal case. + 0: { + conf: nil, + host: "foo.com", + want: &tls.Config{ + ServerName: "foo.com", + NextProtos: []string{NextProtoTLS}, + }, + }, + + // User-provided name (bar.com) takes precedence: + 1: { + conf: &tls.Config{ + ServerName: "bar.com", + }, + host: "foo.com", + want: &tls.Config{ + ServerName: "bar.com", + NextProtos: []string{NextProtoTLS}, + }, + }, + + // NextProto is prepended: + 2: { + conf: &tls.Config{ + NextProtos: []string{"foo", "bar"}, + }, + host: "example.com", + want: &tls.Config{ + ServerName: "example.com", + NextProtos: []string{NextProtoTLS, "foo", "bar"}, + }, + }, + + // NextProto is not duplicated: + 3: { + conf: &tls.Config{ + NextProtos: []string{"foo", "bar", NextProtoTLS}, + }, + host: "example.com", + want: &tls.Config{ + ServerName: "example.com", + NextProtos: []string{"foo", "bar", NextProtoTLS}, + }, + }, + } + for i, tt := range tests { + // Ignore the session ticket keys part, which ends up populating + // unexported fields in the Config: + if tt.conf != nil { + tt.conf.SessionTicketsDisabled = true + } + + tr := &Transport{TLSClientConfig: tt.conf} + got := tr.newTLSConfig(tt.host) + + got.SessionTicketsDisabled = false + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("%d. got %#v; want %#v", i, got, tt.want) + } + } +} + +func TestAuthorityAddr(t *testing.T) { + tests := []struct { + scheme, authority string + want string + }{ + {"http", "foo.com", "foo.com:80"}, + {"https", "foo.com", "foo.com:443"}, + {"https", "foo.com:", "foo.com:443"}, + {"https", "foo.com:1234", "foo.com:1234"}, + {"https", "1.2.3.4:1234", "1.2.3.4:1234"}, + {"https", "1.2.3.4", "1.2.3.4:443"}, + {"https", "1.2.3.4:", "1.2.3.4:443"}, + {"https", "[::1]:1234", "[::1]:1234"}, + {"https", "[::1]", "[::1]:443"}, + {"https", "[::1]:", "[::1]:443"}, + } + for _, tt := range tests { + got := authorityAddr(tt.scheme, tt.authority) + if got != tt.want { + t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) + } + } +} + +// Issue 25009: use Request.GetBody if present, even if it seems like +// we might not need it. Apparently something else can still read from +// the original request body. Data race? In any case, rewinding +// unconditionally on retry is a nicer model anyway and should +// simplify code in the future (after the Go 1.11 freeze) +func TestTransportUsesGetBodyWhenPresent(t *testing.T) { + calls := 0 + someBody := func() io.ReadCloser { + return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))} + } + req := &http.Request{ + Body: someBody(), + GetBody: func() (io.ReadCloser, error) { + calls++ + return someBody(), nil + }, + } + + req2, err := shouldRetryRequest(req, errClientConnUnusable) + if err != nil { + t.Fatal(err) + } + if calls != 1 { + t.Errorf("Calls = %d; want 1", calls) + } + if req2 == req { + t.Error("req2 changed") + } + if req2 == nil { + t.Fatal("req2 is nil") + } + if req2.Body == nil { + t.Fatal("req2.Body is nil") + } + if req2.GetBody == nil { + t.Fatal("req2.GetBody is nil") + } + if req2.Body == req.Body { + t.Error("req2.Body unchanged") + } +} + +// Issue 22891: verify that the "https" altproto we register with net/http +// is a certain type: a struct with one field with our *http2.Transport in it. +func TestNoDialH2RoundTripperType(t *testing.T) { + t1 := new(http.Transport) + t2 := new(Transport) + rt := noDialH2RoundTripper{t2} + if err := registerHTTPSProtocol(t1, rt); err != nil { + t.Fatal(err) + } + rv := reflect.ValueOf(rt) + if rv.Type().Kind() != reflect.Struct { + t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind()) + } + if n := rv.Type().NumField(); n != 1 { + t.Fatalf("fields = %d; net/http expects 1", n) + } + v := rv.Field(0) + if _, ok := v.Interface().(*Transport); !ok { + t.Fatalf("wrong kind %T; want *Transport", v.Interface()) + } +} + +func TestClientConnTooIdle(t *testing.T) { + tests := []struct { + cc func() *ClientConn + want bool + }{ + { + func() *ClientConn { + return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} + }, + true, + }, + { + func() *ClientConn { + return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} + }, + false, + }, + { + func() *ClientConn { + return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} + }, + false, + }, + { + func() *ClientConn { + return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} + }, + false, + }, + } + for i, tt := range tests { + got := tt.cc().tooIdleLocked() + if got != tt.want { + t.Errorf("%d. got %v; want %v", i, got, tt.want) + } + } +} diff --git a/http2/transport_test.go b/http2/transport_test.go index b267483f..d948b881 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http2 +package http2_test import ( "bufio" @@ -16,7 +16,6 @@ import ( "flag" "fmt" "io" - "io/fs" "log" "math/rand" "net" @@ -36,6 +35,7 @@ import ( "testing/synctest" "time" + . "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) @@ -72,6 +72,7 @@ type fakeTLSConn struct { } func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { + const cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F // defined in ciphers.go return tls.ConnectionState{ Version: tls.VersionTLS12, CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, @@ -250,9 +251,12 @@ func TestTransport(t *testing.T) { wantHeader := http.Header{ "Content-Length": []string{"3"}, "Content-Type": []string{"text/plain; charset=utf-8"}, - "Date": []string{"XXX"}, // see cleanDate + "Date": []string{"XXX"}, // see below + } + // replace date with XXX + if d := res.Header["Date"]; len(d) == 1 { + d[0] = "XXX" } - cleanDate(res) if !reflect.DeepEqual(res.Header, wantHeader) { t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader) } @@ -289,7 +293,7 @@ func TestTransportFailureErrorForHTTP1Response(t *testing.T) { }, { name: "with enough frame size to start reading", - maxFrameSize: invalidHTTP1LookingFrameHeader().Length + 1, + maxFrameSize: InvalidHTTP1LookingFrameHeader().Length + 1, }, } { t.Run(tc.name, func(t *testing.T) { @@ -324,7 +328,7 @@ func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq fun }) tr := &Transport{TLSClientConfig: tlsConfigInsecure} if useClient { - tr.ConnPool = noDialClientConnPool{new(clientConnPool)} + tr.ConnPool = NewNoDialClientConnPool() } defer tr.CloseIdleConnections() get := func() string { @@ -612,45 +616,6 @@ func randString(n int) string { return string(b) } -type panicReader struct{} - -func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } -func (panicReader) Close() error { panic("unexpected Close") } - -func TestActualContentLength(t *testing.T) { - tests := []struct { - req *http.Request - want int64 - }{ - // Verify we don't read from Body: - 0: { - req: &http.Request{Body: panicReader{}}, - want: -1, - }, - // nil Body means 0, regardless of ContentLength: - 1: { - req: &http.Request{Body: nil, ContentLength: 5}, - want: 0, - }, - // ContentLength is used if set. - 2: { - req: &http.Request{Body: panicReader{}, ContentLength: 5}, - want: 5, - }, - // http.NoBody means 0, not -1. - 3: { - req: &http.Request{Body: http.NoBody}, - want: 0, - }, - } - for i, tt := range tests { - got := actualContentLength(tt.req) - if got != tt.want { - t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) - } - } -} - func TestTransportBody(t *testing.T) { bodyTests := []struct { body string @@ -1180,11 +1145,10 @@ func testTransportResPatternBubble(t testing.TB, expect100Continue, resHeader he func TestTransportUnknown1xx(t *testing.T) { synctestTest(t, testTransportUnknown1xx) } func testTransportUnknown1xx(t testing.TB) { var buf bytes.Buffer - defer func() { got1xxFuncForTests = nil }() - got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { + SetTestHookGot1xx(t, func(code int, header textproto.MIMEHeader) error { fmt.Fprintf(&buf, "code=%d header=%v\n", code, header) return nil - } + }) tc := newTestClientConn(t) tc.greet() @@ -1268,7 +1232,7 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { testTransportInvalidTrailer_Pseudo(t, splitHeader) } func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), + testInvalidTrailer(t, trailers, PseudoHeaderError(":colon"), ":colon", "foo", "foo", "bar", ) @@ -1281,18 +1245,18 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) { testTransportInvalidTrailer_Capital(t, splitHeader) } func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), + testInvalidTrailer(t, trailers, HeaderFieldNameError("Capital"), "foo", "bar", "Capital", "bad", ) } func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldNameError(""), + testInvalidTrailer(t, oneHeader, HeaderFieldNameError(""), "", "bad", ) } func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), + testInvalidTrailer(t, oneHeader, HeaderFieldValueError("x"), "x", "has\nnewline", ) } @@ -1483,13 +1447,14 @@ func testTransportChecksRequestHeaderListSize(t testing.TB) { headerListSizeForRequest := func(req *http.Request) (size uint64) { const addGzipHeader = true const peerMaxHeaderListSize = 0xffffffffffffffff - _, err := encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { + _, err := EncodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { hf := hpack.HeaderField{Name: name, Value: value} size += uint64(hf.Size()) }) if err != nil { t.Fatal(err) } + fmt.Println(size) return size } // Create a new Request for each test, rather than reusing the @@ -1521,26 +1486,26 @@ func testTransportChecksRequestHeaderListSize(t testing.TB) { req = newRequest() req.Header = make(http.Header) padHeaders(t, req.Header, peerSize, filler) - checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit") + checkRoundTrip(req, ErrRequestHeaderListSize, "Headers over limit") // Push trailers over the limit. req = newRequest() req.Trailer = make(http.Header) padHeaders(t, req.Trailer, peerSize+1, filler) - checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit") + checkRoundTrip(req, ErrRequestHeaderListSize, "Trailers over limit") // Send headers with a single large value. req = newRequest() filler = strings.Repeat("*", int(peerSize)) req.Header = make(http.Header) req.Header.Set("Big", filler) - checkRoundTrip(req, errRequestHeaderListSize, "Single large header") + checkRoundTrip(req, ErrRequestHeaderListSize, "Single large header") // Send trailers with a single large value. req = newRequest() req.Trailer = make(http.Header) req.Trailer.Set("Big", filler) - checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer") + checkRoundTrip(req, ErrRequestHeaderListSize, "Single large trailer") } func TestTransportChecksResponseHeaderListSize(t *testing.T) { @@ -1578,7 +1543,7 @@ func testTransportChecksResponseHeaderListSize(t testing.TB) { if e, ok := err.(StreamError); ok { err = e.Cause } - if err != errResponseHeaderListSize { + if err != ErrResponseHeaderListSize { size := int64(0) if res != nil { res.Body.Close() @@ -1706,9 +1671,6 @@ func TestTransportDisableKeepAlives(t *testing.T) { connClosed := make(chan struct{}) // closed on tls.Conn.Close tr := &Transport{ - t1: &http.Transport{ - DisableKeepAlives: true, - }, TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -1718,6 +1680,7 @@ func TestTransportDisableKeepAlives(t *testing.T) { return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil }, } + tr.TestTransport().DisableKeepAlives = true c := &http.Client{Transport: tr} res, err := c.Get(ts.URL) if err != nil { @@ -1750,9 +1713,6 @@ func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { var dials int32 var conns sync.WaitGroup tr := &Transport{ - t1: &http.Transport{ - DisableKeepAlives: true, - }, TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { tc, err := tls.Dial(network, addr, cfg) @@ -1764,6 +1724,7 @@ func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil }, } + tr.TestTransport().DisableKeepAlives = true c := &http.Client{Transport: tr} var reqs sync.WaitGroup const N = 20 @@ -1834,10 +1795,8 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { func testTransportResponseHeaderTimeout(t testing.TB, body bool) { const bodySize = 4 << 20 - tc := newTestClientConn(t, func(tr *Transport) { - tr.t1 = &http.Transport{ - ResponseHeaderTimeout: 5 * time.Millisecond, - } + tc := newTestClientConn(t, func(t1 *http.Transport) { + t1.ResponseHeaderTimeout = 5 * time.Millisecond }) tc.greet() @@ -1915,10 +1874,8 @@ func TestTransportDisableCompression(t *testing.T) { tr := &Transport{ TLSClientConfig: tlsConfigInsecure, - t1: &http.Transport{ - DisableCompression: true, - }, } + tr.TestTransport().DisableCompression = true defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", ts.URL, nil) @@ -2179,115 +2136,6 @@ func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) { } } -// Tests that gzipReader doesn't crash on a second Read call following -// the first Read call's gzip.NewReader returning an error. -func TestGzipReader_DoubleReadCrash(t *testing.T) { - gz := &gzipReader{ - body: io.NopCloser(strings.NewReader("0123456789")), - } - var buf [1]byte - n, err1 := gz.Read(buf[:]) - if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") { - t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1) - } - n, err2 := gz.Read(buf[:]) - if n != 0 || err2 != err1 { - t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1) - } -} - -func TestGzipReader_ReadAfterClose(t *testing.T) { - body := bytes.Buffer{} - w := gzip.NewWriter(&body) - w.Write([]byte("012345679")) - w.Close() - gz := &gzipReader{ - body: io.NopCloser(&body), - } - var buf [1]byte - n, err := gz.Read(buf[:]) - if n != 1 || err != nil { - t.Fatalf("first Read = %v, %v; want 1, nil", n, err) - } - if err := gz.Close(); err != nil { - t.Fatalf("gz Close error: %v", err) - } - n, err = gz.Read(buf[:]) - if n != 0 || err != fs.ErrClosed { - t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err) - } -} - -func TestTransportNewTLSConfig(t *testing.T) { - tests := [...]struct { - conf *tls.Config - host string - want *tls.Config - }{ - // Normal case. - 0: { - conf: nil, - host: "foo.com", - want: &tls.Config{ - ServerName: "foo.com", - NextProtos: []string{NextProtoTLS}, - }, - }, - - // User-provided name (bar.com) takes precedence: - 1: { - conf: &tls.Config{ - ServerName: "bar.com", - }, - host: "foo.com", - want: &tls.Config{ - ServerName: "bar.com", - NextProtos: []string{NextProtoTLS}, - }, - }, - - // NextProto is prepended: - 2: { - conf: &tls.Config{ - NextProtos: []string{"foo", "bar"}, - }, - host: "example.com", - want: &tls.Config{ - ServerName: "example.com", - NextProtos: []string{NextProtoTLS, "foo", "bar"}, - }, - }, - - // NextProto is not duplicated: - 3: { - conf: &tls.Config{ - NextProtos: []string{"foo", "bar", NextProtoTLS}, - }, - host: "example.com", - want: &tls.Config{ - ServerName: "example.com", - NextProtos: []string{"foo", "bar", NextProtoTLS}, - }, - }, - } - for i, tt := range tests { - // Ignore the session ticket keys part, which ends up populating - // unexported fields in the Config: - if tt.conf != nil { - tt.conf.SessionTicketsDisabled = true - } - - tr := &Transport{TLSClientConfig: tt.conf} - got := tr.newTLSConfig(tt.host) - - got.SessionTicketsDisabled = false - - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("%d. got %#v; want %#v", i, got, tt.want) - } - } -} - // The Google GFE responds to HEAD requests with a HEADERS frame // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) @@ -2575,7 +2423,7 @@ func testTransportReturnsUnusedFlowControl(t testing.TB, oneDataFrame bool) { tc.wantUnorderedFrames( func(f *RSTStreamFrame) bool { if f.ErrCode != ErrCodeCancel { - t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) + t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", SummarizeFrame(f)) } if !oneDataFrame { // Send the remaining data now. @@ -2589,7 +2437,7 @@ func testTransportReturnsUnusedFlowControl(t testing.TB, oneDataFrame bool) { t.Fatalf("Got WindowUpdateFrame, don't expect one yet") } if f.Increment != 5000 { - t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) + t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", SummarizeFrame(f)) } return true }, @@ -2640,7 +2488,7 @@ func testTransportAdjustsFlowControl(t testing.TB) { gotBytes += int64(len(f.Data())) // After we've got half the client's initial flow control window's worth // of request body data, give it just enough flow control to finish. - if gotBytes >= initialWindowSize/2 { + if gotBytes >= InitialWindowSize/2 { break } } @@ -2733,14 +2581,14 @@ func testTransportReturnsErrorOnBadResponseHeaders(t testing.TB) { }) err := rt.err() - want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + want := StreamError{1, ErrCodeProtocol, HeaderFieldNameError(" content-type")} if !reflect.DeepEqual(err, want) { t.Fatalf("RoundTrip error = %#v; want %#v", err, want) } fr := readFrame[*RSTStreamFrame](t, tc) if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", SummarizeFrame(fr)) } } @@ -2894,12 +2742,12 @@ func TestTransportRequestPathPseudo(t *testing.T) { const addGzipHeader = false const peerMaxHeaderListSize = 0xffffffffffffffff - _, err := encodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { + _, err := EncodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { henc.WriteField(hpack.HeaderField{Name: name, Value: value}) }) hdrs := hbuf.Bytes() var got result - hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) { + hpackDec := hpack.NewDecoder(InitialHeaderTableSize, func(f hpack.HeaderField) { if f.Name == ":path" { got.path = f.Value } @@ -2933,7 +2781,7 @@ func testRoundTripDoesntConsumeRequestBodyEarly(t testing.TB) { const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body))) rt := tc.roundTrip(req) - if err := rt.err(); err != errClientConnNotEstablished { + if err := rt.err(); err != ErrClientConnNotEstablished { t.Fatalf("RoundTrip = %v; want errClientConnNotEstablished", err) } @@ -2951,7 +2799,7 @@ func TestClientConnPing(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() ctx := context.Background() - cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false) + cc, err := tr.DialClientConn(ctx, ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) } @@ -3430,7 +3278,7 @@ func TestTransportMaxFrameReadSize(t *testing.T) { want: 64000, }, { maxReadFrameSize: 1024, - want: minMaxFrameSize, + want: MinMaxFrameSize, }} { synctestSubtest(t, fmt.Sprint(test.maxReadFrameSize), func(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { @@ -3608,9 +3456,7 @@ func testTransportMaxDecoderHeaderTableSize(t testing.TB) { } tc.writeSettings(Setting{SettingHeaderTableSize, resSize}) - tc.cc.mu.Lock() - defer tc.cc.mu.Unlock() - if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want { + if got, want := tc.cc.TestPeerMaxHeaderTableSize(), resSize; got != want { t.Fatalf("peerHeaderTableSize = %d, want %d", got, want) } } @@ -3625,35 +3471,11 @@ func testTransportMaxEncoderHeaderTableSize(t testing.TB) { }) tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}) - if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want { + if got, want := tc.cc.TestHPACKEncoder().MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want { t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want) } } -func TestAuthorityAddr(t *testing.T) { - tests := []struct { - scheme, authority string - want string - }{ - {"http", "foo.com", "foo.com:80"}, - {"https", "foo.com", "foo.com:443"}, - {"https", "foo.com:", "foo.com:443"}, - {"https", "foo.com:1234", "foo.com:1234"}, - {"https", "1.2.3.4:1234", "1.2.3.4:1234"}, - {"https", "1.2.3.4", "1.2.3.4:443"}, - {"https", "1.2.3.4:", "1.2.3.4:443"}, - {"https", "[::1]:1234", "[::1]:1234"}, - {"https", "[::1]", "[::1]:443"}, - {"https", "[::1]:", "[::1]:443"}, - } - for _, tt := range tests { - got := authorityAddr(tt.scheme, tt.authority) - if got != tt.want { - t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) - } - } -} - // Issue 20448: stop allocating for DATA frames' payload after // Response.Body.Close is called. func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { @@ -3724,7 +3546,7 @@ func testTransportNoBodyMeansNoDATA(t testing.TB) { } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() ts := newTestServer(b, func(w http.ResponseWriter, r *http.Request) { @@ -3839,7 +3661,7 @@ func BenchmarkDownloadFrameSize(b *testing.B) { b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) }) } func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M b.ReportAllocs() ts := newTestServer(b, @@ -3886,7 +3708,7 @@ func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { } func BenchmarkClientGzip(b *testing.B) { - disableGoroutineTracking(b) + DisableGoroutineTracking(b) b.ReportAllocs() const responseSize = 1024 * 1024 @@ -3950,7 +3772,7 @@ func testClientConnCloseAtHeaders(t testing.TB) { tc.cc.Close() synctest.Wait() - if err := rt.err(); err != errClientConnForceClosed { + if err := rt.err(); err != ErrClientConnForceClosed { t.Fatalf("RoundTrip error = %v, want errClientConnForceClosed", err) } } @@ -4060,70 +3882,6 @@ func testClientConnShutdownCancel(t testing.TB) { } } -// Issue 25009: use Request.GetBody if present, even if it seems like -// we might not need it. Apparently something else can still read from -// the original request body. Data race? In any case, rewinding -// unconditionally on retry is a nicer model anyway and should -// simplify code in the future (after the Go 1.11 freeze) -func TestTransportUsesGetBodyWhenPresent(t *testing.T) { - calls := 0 - someBody := func() io.ReadCloser { - return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))} - } - req := &http.Request{ - Body: someBody(), - GetBody: func() (io.ReadCloser, error) { - calls++ - return someBody(), nil - }, - } - - req2, err := shouldRetryRequest(req, errClientConnUnusable) - if err != nil { - t.Fatal(err) - } - if calls != 1 { - t.Errorf("Calls = %d; want 1", calls) - } - if req2 == req { - t.Error("req2 changed") - } - if req2 == nil { - t.Fatal("req2 is nil") - } - if req2.Body == nil { - t.Fatal("req2.Body is nil") - } - if req2.GetBody == nil { - t.Fatal("req2.GetBody is nil") - } - if req2.Body == req.Body { - t.Error("req2.Body unchanged") - } -} - -// Issue 22891: verify that the "https" altproto we register with net/http -// is a certain type: a struct with one field with our *http2.Transport in it. -func TestNoDialH2RoundTripperType(t *testing.T) { - t1 := new(http.Transport) - t2 := new(Transport) - rt := noDialH2RoundTripper{t2} - if err := registerHTTPSProtocol(t1, rt); err != nil { - t.Fatal(err) - } - rv := reflect.ValueOf(rt) - if rv.Type().Kind() != reflect.Struct { - t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind()) - } - if n := rv.Type().NumField(); n != 1 { - t.Fatalf("fields = %d; net/http expects 1", n) - } - v := rv.Field(0) - if _, ok := v.Interface().(*Transport); !ok { - t.Fatalf("wrong kind %T; want *Transport", v.Interface()) - } -} - type errReader struct { body []byte err error @@ -4254,46 +4012,8 @@ func testTransportBodyLargerThanSpecifiedContentLength(t testing.TB, body *chunk req, _ := http.NewRequest("POST", ts.URL, body) req.ContentLength = contentLen _, err := tr.RoundTrip(req) - if err != errReqBodyTooLong { - t.Fatalf("expected %v, got %v", errReqBodyTooLong, err) - } -} - -func TestClientConnTooIdle(t *testing.T) { - tests := []struct { - cc func() *ClientConn - want bool - }{ - { - func() *ClientConn { - return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} - }, - true, - }, - { - func() *ClientConn { - return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}} - }, - false, - }, - { - func() *ClientConn { - return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)} - }, - false, - }, - { - func() *ClientConn { - return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)} - }, - false, - }, - } - for i, tt := range tests { - got := tt.cc().tooIdleLocked() - if got != tt.want { - t.Errorf("%d. got %v; want %v", i, got, tt.want) - } + if err != ErrReqBodyTooLong { + t.Fatalf("expected %v, got %v", ErrReqBodyTooLong, err) } } @@ -4347,7 +4067,7 @@ func testTransportRoundtripCloseOnWriteError(t testing.TB) { } rt2 := tc.roundTrip(req) - if err := rt2.err(); err != errClientConnUnusable { + if err := rt2.err(); err != ErrClientConnUnusable { t.Fatalf("RoundTrip error %v, want errClientConnUnusable", err) } } @@ -4396,6 +4116,10 @@ func TestTransportBodyRewindRace(t *testing.T) { wg.Wait() } +type errorReader struct{ err error } + +func (r errorReader) Read(p []byte) (int, error) { return 0, r.err } + // Issue 42498: A request with a body will never be sent if the stream is // reset prior to sending any data. func TestTransportServerResetStreamAtHeaders(t *testing.T) { @@ -4811,7 +4535,7 @@ func TestTransportCloseRequestBody(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() ctx := context.Background() - cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false) + cc, err := tr.DialClientConn(ctx, ts.Listener.Addr().String(), false) if err != nil { t.Fatal(err) } @@ -4903,7 +4627,7 @@ func TestClientConnReservations(t *testing.T) { synctestTest(t, testClientConnRe func testClientConnReservations(t testing.TB) { tc := newTestClientConn(t) tc.greet( - Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams}, + Setting{ID: SettingMaxConcurrentStreams, Val: InitialMaxConcurrentStreams}, ) doRoundTrip := func() { @@ -4922,11 +4646,11 @@ func testClientConnReservations(t testing.TB) { } n := 0 - for n <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() { + for n <= InitialMaxConcurrentStreams && tc.cc.ReserveNewRequest() { n++ } - if n != initialMaxConcurrentStreams { - t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) + if n != InitialMaxConcurrentStreams { + t.Errorf("did %v reservations; want %v", n, InitialMaxConcurrentStreams) } doRoundTrip() n2 := 0 @@ -4943,7 +4667,7 @@ func testClientConnReservations(t testing.TB) { } n2 = 0 - for n2 <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() { + for n2 <= InitialMaxConcurrentStreams && tc.cc.ReserveNewRequest() { n2++ } if n2 != n { @@ -5581,7 +5305,7 @@ func testTransportSendPingWithReset(t testing.TB) { // Start several requests. var rts []*testRoundTrip for i := range maxConcurrent + 1 { - req := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tc.roundTrip(req) if i >= maxConcurrent { tc.wantIdle() @@ -5626,7 +5350,7 @@ func testTransportNoPingAfterResetWithFrames(t testing.TB) { // Start request #1. // The server immediately responds with request headers. - req1 := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req1 := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt1 := tc.roundTrip(req1) tc.wantFrameType(FrameHeaders) tc.writeHeaders(HeadersFrameParam{ @@ -5640,7 +5364,7 @@ func testTransportNoPingAfterResetWithFrames(t testing.TB) { // Start request #2. // The connection is at its concurrency limit, so this request is not yet sent. - req2 := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req2 := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt2 := tc.roundTrip(req2) tc.wantIdle() @@ -5670,7 +5394,7 @@ func testTransportSendNoMoreThanOnePingWithReset(t testing.TB) { makeAndResetRequest := func() { t.Helper() ctx, cancel := context.WithCancel(context.Background()) - req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)) rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) cancel() @@ -5740,7 +5464,7 @@ func testTransportConnBecomesUnresponsive(t testing.TB) { const maxConcurrent = 3 t.Logf("first request opens a new connection and succeeds") - req1 := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req1 := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt1 := tt.roundTrip(req1) tc1 := tt.getConn() tc1.wantFrameType(FrameSettings) @@ -5765,7 +5489,7 @@ func testTransportConnBecomesUnresponsive(t testing.TB) { for i := 0; i < maxConcurrent; i++ { t.Logf("request %v receives no response and is canceled", i) ctx, cancel := context.WithCancel(context.Background()) - req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)) tt.roundTrip(req) if tt.hasConn() { t.Fatalf("new connection created; expect existing conn to be reused") @@ -5781,7 +5505,7 @@ func testTransportConnBecomesUnresponsive(t testing.TB) { // The conn has hit its concurrency limit. // The next request is sent on a new conn. - req2 := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req2 := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt2 := tt.roundTrip(req2) tc2 := tt.getConn() tc2.wantFrameType(FrameSettings) @@ -5820,7 +5544,7 @@ func testTransportTLSNextProtoConnOK(t testing.TB) { // Send a request on the Transport. // It uses the conn we provided. - req := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tt.roundTrip(req) tc.wantHeaders(wantHeader{ streamID: 1, @@ -5867,7 +5591,7 @@ func testTransportTLSNextProtoConnImmediateFailureUsed(t testing.TB) { // Send a request on the Transport. // // It should fail, because we have no usable connections, but not with ErrNoCachedConn. - req := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tt.roundTrip(req) if err := rt.err(); err == nil || errors.Is(err, ErrNoCachedConn) { t.Fatalf("RoundTrip with broken conn: got %v, want an error other than ErrNoCachedConn", err) @@ -5910,7 +5634,7 @@ func testTransportTLSNextProtoConnIdleTimoutBeforeUse(t testing.TB) { // Send a request on the Transport. // // It should fail with ErrNoCachedConn. - req := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tt.roundTrip(req) if err := rt.err(); !errors.Is(err, ErrNoCachedConn) { t.Fatalf("RoundTrip with conn closed for idleness: got %v, want ErrNoCachedConn", err) @@ -5946,7 +5670,7 @@ func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) { // Send a request on the Transport. // // It should fail with ErrNoCachedConn, because the pool contains no conns. - req := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tt.roundTrip(req) if err := rt.err(); !errors.Is(err, ErrNoCachedConn) { t.Fatalf("RoundTrip after broken conn expires: got %v, want ErrNoCachedConn", err) @@ -5954,7 +5678,7 @@ func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) { } func TestExtendedConnectClientWithServerSupport(t *testing.T) { - setForTest(t, &disableExtendedConnectProtocol, false) + SetDisableExtendedConnectProtocol(t, false) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { if r.Header.Get(":protocol") != "extended-connect" { t.Fatalf("unexpected :protocol header received") @@ -5993,7 +5717,7 @@ func TestExtendedConnectClientWithServerSupport(t *testing.T) { } func TestExtendedConnectClientWithoutServerSupport(t *testing.T) { - setForTest(t, &disableExtendedConnectProtocol, true) + SetDisableExtendedConnectProtocol(t, true) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { io.Copy(w, r.Body) }) @@ -6016,7 +5740,7 @@ func TestExtendedConnectClientWithoutServerSupport(t *testing.T) { }() _, err := tr.RoundTrip(req) - if !errors.Is(err, errExtendedConnectNotSupported) { + if !errors.Is(err, ErrExtendedConnectNotSupported) { t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err) } }