diff --git a/internal/http3/body.go b/internal/http3/body.go index d66c1959..6db183be 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -44,18 +44,34 @@ type bodyWriter struct { enc *qpackEncoder // QPACK encoder used by the connection. } -func (w *bodyWriter) Write(p []byte) (n int, err error) { - if w.remain >= 0 && int64(len(p)) > w.remain { +func (w *bodyWriter) write(ps ...[]byte) (n int, err error) { + var size int64 + for _, p := range ps { + size += int64(len(p)) + } + // If write is called with empty byte slices, just return instead of + // sending out a DATA frame containing nothing. + if size == 0 { + return 0, nil + } + if w.remain >= 0 && size > w.remain { return 0, &streamError{ code: errH3InternalError, message: w.name + " body longer than specified content length", } } w.st.writeVarint(int64(frameTypeData)) - w.st.writeVarint(int64(len(p))) - n, err = w.st.Write(p) - if w.remain >= 0 { - w.remain -= int64(n) + w.st.writeVarint(size) + for _, p := range ps { + var n2 int + n2, err = w.st.Write(p) + n += n2 + if w.remain >= 0 { + w.remain -= int64(n) + } + if err != nil { + break + } } if w.flush && err == nil { err = w.st.Flush() @@ -66,6 +82,10 @@ func (w *bodyWriter) Write(p []byte) (n int, err error) { return n, err } +func (w *bodyWriter) Write(p []byte) (n int, err error) { + return w.write(p) +} + func (w *bodyWriter) Close() error { if w.remain > 0 { return errors.New(w.name + " body shorter than specified content length") diff --git a/internal/http3/server.go b/internal/http3/server.go index 664250d9..6536eabe 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "time" "golang.org/x/net/http/httpguts" "golang.org/x/net/internal/httpcommon" @@ -253,10 +254,11 @@ func (sc *serverConn) handleRequestStream(st *stream) error { defer req.Body.Close() rw := &responseWriter{ - st: st, - headers: make(http.Header), - trailer: make(http.Header), - isHeadResp: req.Method == "HEAD", + st: st, + headers: make(http.Header), + trailer: make(http.Header), + bb: make(bodyBuffer, 0, defaultBodyBufferCap), + cannotHaveBody: req.Method == "HEAD", bw: &bodyWriter{ st: st, remain: -1, @@ -268,8 +270,18 @@ func (sc *serverConn) handleRequestStream(st *stream) error { defer rw.close() if reqInfo.NeedsContinue { req.Body.(*bodyReader).send100Continue = func() { - rw.WriteHeader(http.StatusContinue) - rw.Flush() + rw.mu.Lock() + defer rw.mu.Unlock() + if rw.wroteHeader { + return + } + encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) { + f(mayIndex, ":status", strconv.Itoa(http.StatusContinue)) + }) + rw.st.writeVarint(int64(frameTypeHeaders)) + rw.st.writeVarint(int64(len(encHeaders))) + rw.st.Write(encHeaders) + rw.st.Flush() } } @@ -290,14 +302,31 @@ func (sc *serverConn) abort(err error) { } } +// responseCanHaveBody reports whether a given response status code permits a +// body. See RFC 7230, section 3.3. +func responseCanHaveBody(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + type responseWriter struct { - st *stream - bw *bodyWriter - mu sync.Mutex - headers http.Header - trailer http.Header - wroteHeader bool // Non-1xx header has been (logically) written. - isHeadResp bool // response is for a HEAD request. + st *stream + bw *bodyWriter + mu sync.Mutex + headers http.Header + trailer http.Header + bb bodyBuffer + wroteHeader bool // Non-1xx header has been (logically) written. + statusCode int // Status of the response that will be sent in HEADERS frame. + statusCodeSet bool // Status of the response has been set via a call to WriteHeader. + cannotHaveBody bool // Response should not have a body (e.g. response to a HEAD request). } func (rw *responseWriter) Header() http.Header { @@ -322,11 +351,13 @@ func (rw *responseWriter) prepareTrailerForWriteLocked() { // Caller must hold rw.mu. If rw.wroteHeader is true, calling this method is a // no-op. -func (rw *responseWriter) writeHeaderLockedOnce(statusCode int) { +func (rw *responseWriter) writeHeaderLockedOnce() { if rw.wroteHeader { return } - + if !responseCanHaveBody(rw.statusCode) { + rw.cannotHaveBody = true + } // If there is any Trailer declared in headers, save them so we know which // trailers have been pre-declared. Also, write back the extracted value, // which is canonicalized, to rw.Header for consistency. @@ -335,10 +366,9 @@ func (rw *responseWriter) writeHeaderLockedOnce(statusCode int) { rw.headers.Set("Trailer", strings.Join(slices.Sorted(maps.Keys(rw.trailer)), ", ")) } - enc := &qpackEncoder{} - enc.init() - encHeaders := enc.encode(func(f func(itype indexType, name, value string)) { - f(mayIndex, ":status", strconv.Itoa(statusCode)) + rw.bb.inferHeader(rw.headers, rw.statusCode) + encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) { + f(mayIndex, ":status", strconv.Itoa(rw.statusCode)) for name, values := range rw.headers { if !httpguts.ValidHeaderFieldName(name) { continue @@ -352,45 +382,128 @@ func (rw *responseWriter) writeHeaderLockedOnce(statusCode int) { } } }) + rw.st.writeVarint(int64(frameTypeHeaders)) rw.st.writeVarint(int64(len(encHeaders))) rw.st.Write(encHeaders) - if statusCode >= http.StatusOK { + if rw.statusCode >= http.StatusOK { rw.wroteHeader = true } } func (rw *responseWriter) WriteHeader(statusCode int) { + // TODO: handle sending informational status headers (e.g. 103). rw.mu.Lock() defer rw.mu.Unlock() - rw.writeHeaderLockedOnce(statusCode) + if rw.statusCodeSet { + return + } + rw.statusCodeSet = true + rw.statusCode = statusCode } func (rw *responseWriter) Write(b []byte) (int, error) { + // Calling Write implicitly calls WriteHeader(200) if WriteHeader has not + // been called before. + rw.WriteHeader(http.StatusOK) rw.mu.Lock() defer rw.mu.Unlock() - rw.writeHeaderLockedOnce(http.StatusOK) - if rw.isHeadResp { - return 0, nil + + // If b fits entirely in our body buffer, save it to the buffer and return + // early so we can coalesce small writes. + // As a special case, we always want to save b to the buffer even when b is + // big if we had yet to write our header, so we can infer headers like + // "Content-Type" with as much information as possible. + initialBufLen := len(rw.bb) + if !rw.wroteHeader || len(b) <= cap(rw.bb)-len(rw.bb) { + b = rw.bb.write(b) + if len(b) == 0 { + return len(b), nil + } } - return rw.bw.Write(b) + + // Reaching this point means that our buffer has been sufficiently filled. + // Therefore, we now want to: + // 1. Infer and write response headers based on our body buffer, if not + // done yet. + // 2. Write our body buffer and the rest of b (if any). + // 3. Reset the current body buffer so it can be used again. + rw.writeHeaderLockedOnce() + if rw.cannotHaveBody { + return len(b), nil + } + if n, err := rw.bw.write(rw.bb, b); err != nil { + return max(0, n-initialBufLen), err + } + rw.bb.discard() + return len(b), nil } func (rw *responseWriter) Flush() { + // Calling Flush implicitly calls WriteHeader(200) if WriteHeader has not + // been called before. + rw.WriteHeader(http.StatusOK) rw.mu.Lock() - rw.writeHeaderLockedOnce(http.StatusOK) + rw.writeHeaderLockedOnce() + if !rw.cannotHaveBody { + rw.bw.Write(rw.bb) + rw.bb.discard() + } rw.mu.Unlock() - rw.bw.st.Flush() + rw.st.Flush() } func (rw *responseWriter) close() error { + rw.Flush() rw.mu.Lock() defer rw.mu.Unlock() - rw.writeHeaderLockedOnce(http.StatusOK) rw.prepareTrailerForWriteLocked() - if err := rw.bw.Close(); err != nil { return err } return rw.st.stream.Close() } + +// defaultBodyBufferCap is the default number of bytes of body that we are +// willing to save in a buffer for the sake of inferring headers and coalescing +// small writes. 512 was chosen to be consistent with how much +// http.DetectContentType is willing to read. +const defaultBodyBufferCap = 512 + +// bodyBuffer is a buffer used to store body content of a response. +type bodyBuffer []byte + +// write writes b to the buffer. It returns a new slice of b, which contains +// any remaining data that could not be written to the buffer, if any. +func (bb *bodyBuffer) write(b []byte) []byte { + n := min(len(b), cap(*bb)-len(*bb)) + *bb = append(*bb, b[:n]...) + return b[n:] +} + +// discard resets the buffer so it can be used again. +func (bb *bodyBuffer) discard() { + *bb = (*bb)[:0] +} + +// inferHeader populates h with the header values that we can infer from our +// current buffer content, if not already explicitly set. This method should be +// called only once with as much body content as possible in the buffer, before +// a HEADERS frame is sent, and before discard has been called. Doing so +// properly is the responsibility of the caller. +func (bb *bodyBuffer) inferHeader(h http.Header, status int) { + if _, ok := h["Date"]; !ok { + h.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + } + // If the Content-Encoding is non-blank, we shouldn't + // sniff the body. See Issue golang.org/issue/31753. + _, hasCE := h["Content-Encoding"] + _, hasCT := h["Content-Type"] + if !hasCE && !hasCT && responseCanHaveBody(status) && len(*bb) > 0 { + h.Set("Content-Type", http.DetectContentType(*bb)) + } + // We can technically infer Content-Length too here, as long as the entire + // response body fits within hi.buf and does not require flushing. However, + // we have chosen not to do so for now as Content-Length is not very + // important for HTTP/3, and such inconsistent behavior might be confusing. +} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index f167bb84..0583cf03 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go @@ -11,6 +11,8 @@ import ( "net/netip" "net/url" "reflect" + "slices" + "strconv" "testing" "testing/synctest" "time" @@ -82,7 +84,7 @@ func TestServerHeader(t *testing.T) { "header-from-client": {"that", "should", "be", "echoed"}, })) synctest.Wait() - reqStream.wantHeaders(http.Header{ + reqStream.wantSomeHeaders(http.Header{ ":status": {"204"}, "Header-From-Client": {"that", "should", "be", "echoed"}, }) @@ -129,7 +131,7 @@ func TestServerPseudoHeader(t *testing.T) { ":path": {"/some/path?query=value&query2=value2#fragment"}, }) synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"321"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"321"}}) reqStream.wantClosed("request is complete") reqStream = tc.newStream(streamTypeRequest) @@ -155,7 +157,7 @@ func TestServerInvalidHeader(t *testing.T) { reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() - reqStream.wantHeaders(http.Header{ + reqStream.wantSomeHeaders(http.Header{ ":status": {"200"}, "Valid-Name": {"valid value"}, "Valid-Name-2": {"valid value 2"}, @@ -183,9 +185,9 @@ func TestServerBody(t *testing.T) { reqStream.writeData(bodyContent) reqStream.stream.stream.CloseWrite() synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) - reqStream.wantData([]byte("/")) - reqStream.wantData(bodyContent) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) + // Small multiple calls to Write will be coalesced into one DATA frame. + reqStream.wantData(append([]byte("/"), bodyContent...)) reqStream.wantClosed("request is complete") }) } @@ -202,14 +204,14 @@ func TestServerHeadResponseNoBody(t *testing.T) { reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(bodyContent) reqStream.wantClosed("request is complete") reqStream = tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(http.Header{":method": {http.MethodHead}})) synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") }) } @@ -225,7 +227,7 @@ func TestServerHandlerEmpty(t *testing.T) { reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") }) } @@ -291,7 +293,7 @@ func TestServerHandlerStreaming(t *testing.T) { reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) for _, data := range []string{"a", "bunch", "of", "things", "to", "stream"} { stream <- data @@ -335,13 +337,14 @@ func TestServerExpect100Continue(t *testing.T) { // Wait until server responds with HTTP status 100 before sending the // body. synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"100"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"100"}}) body := []byte("body that will be echoed back if we get status 100") reqStream.writeData(body) reqStream.stream.stream.CloseWrite() // Receive the server's response after sending the body. - reqStream.wantHeaders(http.Header{":status": {"200"}}) + synctest.Wait() + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(body) reqStream.wantClosed("request is complete") }) @@ -365,12 +368,43 @@ func TestServerExpect100ContinueRejected(t *testing.T) { // Server rejects it. synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"403"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"403"}}) reqStream.wantData(rejectBody) reqStream.wantClosed("request is complete") }) } +func TestServerNoExpect100ContinueAfterNormalResponse(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.(http.Flusher).Flush() + // This should not cause an HTTP 100 status to be sent since we + // have sent an HTTP 200 response already. + io.ReadAll(r.Body) + })) + tc := ts.connect() + tc.greet() + + // Client sends an Expect: 100-continue request. + reqStream := tc.newStream(streamTypeRequest) + reqStream.writeHeaders(requestHeader(http.Header{ + "Expect": {"100-continue"}, + })) + // Client sends a body prematurely. This should not happen, unless a + // client misbehaves. We do so here anyways so the server handler can + // read the request body without hanging, which would normally cause an + // HTTP 100 to be sent. + reqStream.writeData([]byte("some body")) + reqStream.stream.stream.CloseWrite() + + // Verify that no HTTP 100 was sent. + synctest.Wait() + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) + reqStream.wantClosed("request is complete") + }) +} + func TestServerHandlerReadReqWithNoBody(t *testing.T) { synctest.Test(t, func(t *testing.T) { serverBody := []byte("hello from server!") @@ -389,7 +423,7 @@ func TestServerHandlerReadReqWithNoBody(t *testing.T) { reqStream.writeHeaders(requestHeader(nil)) reqStream.stream.stream.CloseWrite() synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(serverBody) reqStream.wantClosed("request is complete") @@ -400,7 +434,7 @@ func TestServerHandlerReadReqWithNoBody(t *testing.T) { "Content-Length": {"0"}, })) synctest.Wait() - reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(serverBody) reqStream.wantClosed("request is complete") }) @@ -508,12 +542,12 @@ func TestServerHandlerWriteTrailer(t *testing.T) { reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() - reqStream.wantHeaders(http.Header{ + reqStream.wantSomeHeaders(http.Header{ ":status": {"200"}, "Trailer": {"Server-Trailer-A, Server-Trailer-B, Server-Trailer-C"}, }) reqStream.wantData(body) - reqStream.wantHeaders(http.Header{ + reqStream.wantSomeHeaders(http.Header{ "Server-Trailer-A": {"valuea"}, "Server-Trailer-C": {"valuec"}, }) @@ -539,11 +573,11 @@ func TestServerHandlerWriteTrailerNoBody(t *testing.T) { reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() - reqStream.wantHeaders(http.Header{ + reqStream.wantSomeHeaders(http.Header{ ":status": {"200"}, "Trailer": {"Server-Trailer-A, Server-Trailer-B, Server-Trailer-C"}, }) - reqStream.wantHeaders(http.Header{ + reqStream.wantSomeHeaders(http.Header{ "Server-Trailer-A": {"valuea"}, "Server-Trailer-C": {"valuec"}, }) @@ -551,6 +585,211 @@ func TestServerHandlerWriteTrailerNoBody(t *testing.T) { }) } +func TestServerInfersHeaders(t *testing.T) { + tests := []struct { + name string + flushedEarly bool + responseStatus int + does100Continue bool + declaredHeader http.Header + want http.Header + }{ + { + name: "infers undeclared headers", + responseStatus: 200, + declaredHeader: http.Header{ + "Some-Other-Header": {"some value"}, + }, + want: http.Header{ + "Date": {"Sat, 01 Jan 2000 00:00:00 GMT"}, // Synctest starting time. + "Content-Type": {"text/html; charset=utf-8"}, + "Some-Other-Header": {"some value"}, + }, + }, + { + name: "does not write over declared header", + responseStatus: 200, + declaredHeader: http.Header{ + "Date": {"some date"}, + "Content-Type": {"some content type"}, + "Some-Other-Header": {"some value"}, + }, + want: http.Header{ + "Date": {"some date"}, + "Content-Type": {"some content type"}, + "Some-Other-Header": {"some value"}, + }, + }, + { + name: "does not infer content type for response with no body", + responseStatus: 304, // 304 status response has no body. + declaredHeader: http.Header{ + "Some-Other-Header": {"some value"}, + }, + want: http.Header{ + "Date": {"Sat, 01 Jan 2000 00:00:00 GMT"}, // Synctest starting time. + "Some-Other-Header": {"some value"}, + }, + }, + { + // See golang.org/issue/31753. + name: "does not infer content type for response with declared content encoding", + responseStatus: 200, + declaredHeader: http.Header{ + "Content-Encoding": {"some encoding"}, + "Some-Other-Header": {"some value"}, + }, + want: http.Header{ + "Date": {"Sat, 01 Jan 2000 00:00:00 GMT"}, // Synctest starting time. + "Content-Encoding": {"some encoding"}, + "Some-Other-Header": {"some value"}, + }, + }, + { + name: "does not infer content type when header is flushed before body is written", + responseStatus: 200, + flushedEarly: true, + declaredHeader: http.Header{ + "Some-Other-Header": {"some value"}, + }, + want: http.Header{ + "Date": {"Sat, 01 Jan 2000 00:00:00 GMT"}, // Synctest starting time. + "Some-Other-Header": {"some value"}, + }, + }, + { + name: "infers header for the header that comes after 100 continue", + responseStatus: 200, + does100Continue: true, + declaredHeader: http.Header{ + "Some-Other-Header": {"some value"}, + }, + want: http.Header{ + "Date": {"Sat, 01 Jan 2000 00:00:00 GMT"}, // Synctest starting time. + "Content-Type": {"text/html; charset=utf-8"}, + "Some-Other-Header": {"some value"}, + }, + }, + } + + for _, tt := range tests { + synctestSubtest(t, tt.name, func(t *testing.T) { + body := []byte("some html content") + streamIdle := make(chan bool) + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tt.does100Continue { + <-streamIdle + io.ReadAll(r.Body) + } + for name, values := range tt.declaredHeader { + for _, value := range values { + w.Header().Add(name, value) + } + } + w.WriteHeader(tt.responseStatus) + if tt.flushedEarly { + w.(http.Flusher).Flush() + } + // Write the body one byte at a time. To confirm that body + // writes are buffered and that Content-Type will not be + // wrongly identified as text/plain rather than text/html. + for _, b := range body { + w.Write([]byte{b}) + } + })) + tc := ts.connect() + tc.greet() + + reqStream := tc.newStream(streamTypeRequest) + + if tt.does100Continue { + reqStream.writeHeaders(requestHeader(http.Header{ + "Expect": {"100-continue"}, + })) + reqStream.wantIdle("stream is idle until server sends an HTTP 100 status") + streamIdle <- true + synctest.Wait() + reqStream.wantHeaders(http.Header{":status": {"100"}}) + } + + reqStream.writeHeaders(requestHeader(nil)) + synctest.Wait() + tt.want.Add(":status", strconv.Itoa(tt.responseStatus)) + reqStream.wantHeaders(tt.want) + if responseCanHaveBody(tt.responseStatus) { + reqStream.wantData(body) + } + reqStream.wantClosed("request is complete") + }) + } +} + +func TestServerBuffersBodyWrite(t *testing.T) { + tests := []struct { + name string + bodyLen int + writeSize int + flushes bool + }{ + { + name: "buffers small body content", + bodyLen: defaultBodyBufferCap * 10, + writeSize: 5, + flushes: false, + }, + { + name: "does not buffer large body content", + bodyLen: defaultBodyBufferCap * 10, + writeSize: defaultBodyBufferCap * 2, + flushes: false, + }, + { + name: "does not buffer flushed body content", + bodyLen: defaultBodyBufferCap * 10, + writeSize: 10, + flushes: true, + }, + } + for _, tt := range tests { + synctestSubtest(t, tt.name, func(t *testing.T) { + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for n := 0; n < tt.bodyLen; n += tt.writeSize { + data := slices.Repeat([]byte("a"), min(tt.writeSize, tt.bodyLen-n)) + w.Write(data) + if tt.flushes { + w.(http.Flusher).Flush() + } + } + })) + tc := ts.connect() + tc.greet() + + reqStream := tc.newStream(streamTypeRequest) + reqStream.writeHeaders(requestHeader(nil)) + synctest.Wait() + reqStream.wantHeaders(nil) + switch { + case tt.writeSize > defaultBodyBufferCap: + // After using the buffer once, it is no longer used since the + // writeSize is larger than the buffer. + for n := 0; n < tt.bodyLen; n += tt.writeSize { + reqStream.wantData(slices.Repeat([]byte("a"), min(tt.writeSize, tt.bodyLen-n))) + } + case tt.flushes: + for n := 0; n < tt.bodyLen; n += tt.writeSize { + reqStream.wantData(slices.Repeat([]byte("a"), min(tt.writeSize, tt.bodyLen-n))) + } + case tt.writeSize <= defaultBodyBufferCap: + dataLen := defaultBodyBufferCap + tt.writeSize - (defaultBodyBufferCap % tt.writeSize) + for n := 0; n < tt.bodyLen; n += dataLen { + reqStream.wantData(slices.Repeat([]byte("a"), min(dataLen, tt.bodyLen-n))) + } + } + reqStream.wantClosed("request is complete") + }) + } +} + type testServer struct { t testing.TB s *Server diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go index d0f1f2ca..71c5aeca 100644 --- a/internal/http3/transport_test.go +++ b/internal/http3/transport_test.go @@ -215,6 +215,43 @@ func (ts *testQUICStream) wantHeaders(want http.Header) { } } +// wantSomeHeaders reads a HEADERS frame and asserts that want is a subset of +// the read HEADERS frame. +// This is like wantHeaders, but headers that are in the HEADERS frame but not +// in want are ignored. +func (ts *testQUICStream) wantSomeHeaders(want http.Header) { + ts.t.Helper() + ftype, err := ts.readFrameHeader() + if err != nil { + ts.t.Fatalf("want HEADERS frame, got error: %v", err) + } + if ftype != frameTypeHeaders { + ts.t.Fatalf("want HEADERS frame, got: %v", ftype) + } + + if want == nil { + panic("use wantHeaders(nil) instead to ignore the content of the frame") + } + + got := make(http.Header) + var dec qpackDecoder + err = dec.decode(ts.stream, func(_ indexType, name, value string) error { + got.Add(name, value) + return nil + }) + for name := range got { + if _, ok := want[name]; !ok { + delete(got, name) + } + } + if diff := diffHeaders(got, want); diff != "" { + ts.t.Fatalf("unexpected response headers:\n%v", diff) + } + if err := ts.endFrame(); err != nil { + ts.t.Fatalf("endFrame: %v", err) + } +} + func (ts *testQUICStream) encodeHeaders(h http.Header) []byte { ts.t.Helper() var enc qpackEncoder