diff --git a/internal/http3/server.go b/internal/http3/server.go index b9d053a0..9d8937d1 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -270,18 +270,7 @@ func (sc *serverConn) handleRequestStream(st *stream) error { defer rw.close() if reqInfo.NeedsContinue { req.Body.(*bodyReader).send100Continue = func() { - rw.mu.Lock() - defer rw.mu.Unlock() - if rw.wroteHeader { - return - } - encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) { - f(mayIndex, ":status", strconv.Itoa(http.StatusContinue)) - }) - rw.st.writeVarint(int64(frameTypeHeaders)) - rw.st.writeVarint(int64(len(encHeaders))) - rw.st.Write(encHeaders) - rw.st.Flush() + rw.WriteHeader(100) } } @@ -350,8 +339,10 @@ func (rw *responseWriter) prepareTrailerForWriteLocked() { } } -// Caller must hold rw.mu. If rw.wroteHeader is true, calling this method is a -// no-op. +// writeHeaderLockedOnce writes the final response header. If rw.wroteHeader is +// true, calling this method is a no-op. Sending informational status headers +// should be done using writeInfoHeaderLocked, rather than this method. +// Caller must hold rw.mu. func (rw *responseWriter) writeHeaderLockedOnce() { if rw.wroteHeader { return @@ -387,9 +378,42 @@ func (rw *responseWriter) writeHeaderLockedOnce() { rw.st.writeVarint(int64(frameTypeHeaders)) rw.st.writeVarint(int64(len(encHeaders))) rw.st.Write(encHeaders) - if rw.statusCode >= http.StatusOK { - rw.wroteHeader = true + rw.wroteHeader = true +} + +// writeHeaderLocked writes informational status headers (i.e. status 1XX). +// If a non-informational status header has been written via +// writeHeaderLockedOnce, this method is a no-op. +// Caller must hold rw.mu. +func (rw *responseWriter) writeHeaderLocked(statusCode int) { + if rw.wroteHeader { + return } + encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) { + f(mayIndex, ":status", strconv.Itoa(statusCode)) + for name, values := range rw.headers { + if name == "Content-Length" || name == "Transfer-Encoding" { + continue + } + if !httpguts.ValidHeaderFieldName(name) { + continue + } + for _, val := range values { + if !httpguts.ValidHeaderFieldValue(val) { + continue + } + // Issue #71374: Consider supporting never-indexed fields. + f(mayIndex, name, val) + } + } + }) + rw.st.writeVarint(int64(frameTypeHeaders)) + rw.st.writeVarint(int64(len(encHeaders))) + rw.st.Write(encHeaders) +} + +func isInfoStatus(status int) bool { + return status >= 100 && status < 200 } func (rw *responseWriter) WriteHeader(statusCode int) { @@ -399,9 +423,19 @@ func (rw *responseWriter) WriteHeader(statusCode int) { if rw.statusCodeSet { return } + + // Informational headers can be sent multiple times, and should be flushed + // immediately. + if isInfoStatus(statusCode) { + rw.writeHeaderLocked(statusCode) + rw.st.Flush() + return + } + + // Non-informational headers should only be set once, and should be + // buffered. rw.statusCodeSet = true rw.statusCode = statusCode - if n, err := strconv.Atoi(rw.Header().Get("Content-Length")); err == nil { rw.bodyLenLeft = n } else { diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index 0318ab48..7b938a74 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go @@ -865,6 +865,53 @@ func TestServerBuffersBodyWrite(t *testing.T) { } } +func TestServer103EarlyHints(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + body := []byte("some body") + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + + h.Add("Content-Length", "123") // Must be ignored + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + w.Write(body) // Implicitly sends status 200. + w.WriteHeader(http.StatusEarlyHints) // Should be a no-op. + })) + tc := ts.connect() + tc.greet() + + reqStream := tc.newStream(streamTypeRequest) + reqStream.writeHeaders(requestHeader(nil)) + synctest.Wait() + reqStream.wantHeaders(http.Header{ + ":status": {"103"}, + "Link": { + "; rel=preload; as=style", + "; rel=preload; as=script", + }, + }) + reqStream.wantHeaders(http.Header{ + ":status": {"103"}, + "Link": { + "; rel=preload; as=style", + "; rel=preload; as=script", + "; rel=preload; as=script", + }, + }) + reqStream.wantSomeHeaders(http.Header{ + ":status": {"200"}, + "Content-Length": {"123"}, + }) + reqStream.wantData(body) + reqStream.wantClosed("request is complete") + }) +} + type testServer struct { t testing.TB s *Server