internal/http3: add Expect: 100-continue support to Server

When serving a request containing the "Expect: 100-continue" header,
Server will now send an HTTP 100 status automatically if the request
body is read from within the server handler.

For golang/go#70914

Change-Id: Ib8185170deabf777a02487a1ded6671db720df51
Reviewed-on: https://go-review.googlesource.com/c/net/+/742520
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Nicholas Husin <husin@google.com>
This commit is contained in:
Nicholas S. Husin
2026-02-05 16:12:40 -05:00
committed by Nicholas Husin
parent af0c9df79d
commit 73fe7011ad
4 changed files with 211 additions and 46 deletions

View File

@@ -57,6 +57,10 @@ type bodyReader struct {
mu sync.Mutex
remain int64
err error
// If not nil, the body contains an "Expect: 100-continue" header, and
// send100Continue should be called when Read is invoked for the first
// time.
send100Continue func()
}
func (r *bodyReader) Read(p []byte) (n int, err error) {
@@ -65,6 +69,10 @@ func (r *bodyReader) Read(p []byte) (n int, err error) {
// Use a mutex here to provide the same behavior.
r.mu.Lock()
defer r.mu.Unlock()
if r.send100Continue != nil {
r.send100Continue()
r.send100Continue = nil
}
if r.err != nil {
return 0, r.err
}

View File

@@ -7,11 +7,11 @@ package http3
import (
"context"
"net/http"
"net/url"
"strconv"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/internal/httpcommon"
"golang.org/x/net/quic"
)
@@ -157,54 +157,81 @@ func (sc *serverConn) handlePushStream(*stream) error {
}
}
func (sc *serverConn) parseRequest(st *stream) (*http.Request, error) {
req := &http.Request{
URL: &url.URL{},
Proto: "HTTP/3.0",
ProtoMajor: 3,
RemoteAddr: sc.qconn.RemoteAddr().String(),
}
type pseudoHeader struct {
method string
scheme string
path string
authority string
}
func (sc *serverConn) parseHeader(st *stream) (http.Header, pseudoHeader, error) {
ftype, err := st.readFrameHeader()
if err != nil {
return nil, err
return nil, pseudoHeader{}, err
}
if ftype != frameTypeHeaders {
return nil, err
return nil, pseudoHeader{}, err
}
req.Header = make(http.Header)
header := make(http.Header)
var pHeader pseudoHeader
var dec qpackDecoder
if err := dec.decode(st, func(_ indexType, name, value string) error {
switch name {
case ":method":
req.Method = value
pHeader.method = value
case ":scheme":
req.URL.Scheme = value
pHeader.scheme = value
case ":path":
req.URL.Path = value
pHeader.path = value
case ":authority":
req.URL.Host = value
pHeader.authority = value
default:
req.Header.Add(name, value)
header.Add(name, value)
}
return nil
}); err != nil {
return nil, err
return nil, pseudoHeader{}, err
}
if err := st.endFrame(); err != nil {
return nil, err
return nil, pseudoHeader{}, err
}
req.Body = &bodyReader{
st: st,
remain: -1,
}
return req, nil
return header, pHeader, nil
}
func (sc *serverConn) handleRequestStream(st *stream) error {
req, err := sc.parseRequest(st)
header, pHeader, err := sc.parseHeader(st)
if err != nil {
return err
}
reqInfo := httpcommon.NewServerRequest(httpcommon.ServerRequestParam{
Method: pHeader.method,
Scheme: pHeader.scheme,
Authority: pHeader.authority,
Path: pHeader.path,
Header: header,
})
if reqInfo.InvalidReason != "" {
return &streamError{
code: errH3MessageError,
message: reqInfo.InvalidReason,
}
}
req := &http.Request{
Proto: "HTTP/3.0",
Method: pHeader.method,
Host: pHeader.authority,
URL: reqInfo.URL,
RequestURI: reqInfo.RequestURI,
Trailer: reqInfo.Trailer,
ProtoMajor: 3,
RemoteAddr: sc.qconn.RemoteAddr().String(),
Body: &bodyReader{
st: st,
remain: -1,
},
Header: header,
}
defer req.Body.Close()
rw := &responseWriter{
@@ -219,6 +246,12 @@ func (sc *serverConn) handleRequestStream(st *stream) error {
},
}
defer rw.close()
if reqInfo.NeedsContinue {
req.Body.(*bodyReader).send100Continue = func() {
rw.WriteHeader(http.StatusContinue)
rw.Flush()
}
}
// TODO: handle panic coming from the HTTP handler.
sc.handler.ServeHTTP(rw, req)
@@ -238,11 +271,10 @@ func (sc *serverConn) abort(err error) {
}
type responseWriter struct {
st *stream
bw *bodyWriter
mu sync.Mutex
headers http.Header
// TODO: support 1xx status
st *stream
bw *bodyWriter
mu sync.Mutex
headers http.Header
wroteHeader bool // Non-1xx header has been (logically) written.
isHeadResp bool // response is for a HEAD request.
}
@@ -278,7 +310,9 @@ func (rw *responseWriter) writeHeaderLockedOnce(statusCode int) {
rw.st.writeVarint(int64(frameTypeHeaders))
rw.st.writeVarint(int64(len(encHeaders)))
rw.st.Write(encHeaders)
rw.wroteHeader = true
if statusCode >= http.StatusOK {
rw.wroteHeader = true
}
}
func (rw *responseWriter) WriteHeader(statusCode int) {

View File

@@ -8,8 +8,11 @@ package http3
import (
"io"
"maps"
"net/http"
"net/netip"
"net/url"
"reflect"
"testing"
"testing/synctest"
"time"
@@ -18,6 +21,22 @@ import (
"golang.org/x/net/quic"
)
// requestHeader is a helper function to make sure that all required
// pseudo-headers exist in an http.Header used for a request. Per
// https://www.rfc-editor.org/rfc/rfc9114.html#name-request-pseudo-header-field:
// "All HTTP/3 requests MUST include exactly one value for the :method,
// :scheme, and :path pseudo-header fields, unless the request is a CONNECT
// request;"
func requestHeader(h http.Header) http.Header {
minimalHeader := http.Header{
":method": {"GET"},
":scheme": {"https"},
":path": {"/"},
}
maps.Copy(minimalHeader, h)
return minimalHeader
}
func TestServerReceivePushStream(t *testing.T) {
// "[...] if a server receives a client-initiated push stream,
// this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
@@ -61,9 +80,9 @@ func TestServerHeader(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{
reqStream.writeHeaders(requestHeader(http.Header{
"header-from-client": {"that", "should", "be", "echoed"},
})
}))
synctest.Wait()
reqStream.wantHeaders(http.Header{
":status": {"204"},
@@ -78,9 +97,23 @@ func TestServerPseudoHeader(t *testing.T) {
ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Pseudo-headers from client request should populate a specific
// field in http.Request, and should not be part of http.Request.Header.
if r.Header.Get(":method") != "" || r.Method != "GET" {
t.Error("want pseudo-headers from client to be reflected in appropriate fields in http.Request, not in http.Request.Header")
if len(r.Header) != 0 {
t.Errorf("got %v, want request header to be empty", r.Header)
}
if r.Method != "GET" {
t.Errorf("got %v, want GET method", r.Method)
}
if r.Host != "fake.tld:1234" {
t.Errorf("got %v, want fake.tld:1234", r.Host)
}
wantURL := &url.URL{
Path: "/some/path",
RawQuery: "query=value&query2=value2#fragment",
}
if !reflect.DeepEqual(r.URL, wantURL) {
t.Errorf("got %v, want URL to be %v", r.URL, wantURL)
}
// Conversely, server should not be able to set pseudo-headers by
// writing to the ResponseWriter's Header.
header := w.Header()
@@ -91,10 +124,20 @@ func TestServerPseudoHeader(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{":method": {"GET"}})
reqStream.writeHeaders(http.Header{
":method": {"GET"},
":authority": {"fake.tld:1234"},
":scheme": {"https"},
":path": {"/some/path?query=value&query2=value2#fragment"},
})
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"321"}})
reqStream.wantClosed("request is complete")
reqStream = tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{}) // Missing pseudo-header.
synctest.Wait()
reqStream.wantError(quic.StreamErrorCode(errH3MessageError))
})
}
@@ -112,7 +155,7 @@ func TestServerInvalidHeader(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{})
reqStream.writeHeaders(requestHeader(nil))
synctest.Wait()
reqStream.wantHeaders(http.Header{
":status": {"200"},
@@ -137,9 +180,7 @@ func TestServerBody(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{
":path": {"/"},
})
reqStream.writeHeaders(requestHeader(nil))
bodyContent := []byte("some body content that should be echoed")
reqStream.writeData(bodyContent)
reqStream.stream.stream.CloseWrite()
@@ -161,14 +202,14 @@ func TestServerHeadResponseNoBody(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
reqStream.writeHeaders(requestHeader(nil))
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"200"}})
reqStream.wantData(bodyContent)
reqStream.wantClosed("request is complete")
reqStream = tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{":method": {http.MethodHead}})
reqStream.writeHeaders(requestHeader(http.Header{":method": {http.MethodHead}}))
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"200"}})
reqStream.wantClosed("request is complete")
@@ -184,7 +225,7 @@ func TestServerHandlerEmpty(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
reqStream.writeHeaders(requestHeader(nil))
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"200"}})
reqStream.wantClosed("request is complete")
@@ -208,7 +249,7 @@ func TestServerHandlerFlushing(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
reqStream.writeHeaders(requestHeader(nil))
synctest.Wait()
respBody := make([]byte, 100)
@@ -216,7 +257,7 @@ func TestServerHandlerFlushing(t *testing.T) {
time.Sleep(time.Second)
synctest.Wait()
if n, err := reqStream.Read(respBody); err == nil {
t.Errorf("want no message yet, got %v bytes read", n)
t.Errorf("got %v bytes read, want no message yet", n)
}
time.Sleep(time.Second)
@@ -228,7 +269,7 @@ func TestServerHandlerFlushing(t *testing.T) {
time.Sleep(time.Second)
synctest.Wait()
if _, err := reqStream.Read(respBody); err != io.EOF {
t.Errorf("expected EOF, got err: %v", err)
t.Errorf("got err %v, want EOF", err)
}
reqStream.wantClosed("request is complete")
})
@@ -250,7 +291,7 @@ func TestServerHandlerStreaming(t *testing.T) {
tc.greet()
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
reqStream.writeHeaders(requestHeader(nil))
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"200"}})
@@ -263,6 +304,75 @@ func TestServerHandlerStreaming(t *testing.T) {
})
}
func TestServerExpect100Continue(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
streamIdle := make(chan bool)
ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Expect: 100-continue header should not be accessible from the
// server handler.
if len(r.Header) > 0 {
t.Errorf("got %v, want request header to be empty", r.Header)
}
// Reading the body will cause the server to call w.WriteHeader(100).
<-streamIdle
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
// Implicitly calls w.WriteHeader(200) since non-1XX status code
// has been sent yet so far.
w.Write(body)
}))
tc := ts.connect()
tc.greet()
// Client sends an Expect: 100-continue request.
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(requestHeader(http.Header{
"Expect": {"100-continue"},
}))
reqStream.wantIdle("stream is idle until server sends an HTTP 100 status")
streamIdle <- true
// Wait until server responds with HTTP status 100 before sending the
// body.
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"100"}})
body := []byte("body that will be echoed back if we get status 100")
reqStream.writeData(body)
reqStream.stream.stream.CloseWrite()
// Receive the server's response after sending the body.
reqStream.wantHeaders(http.Header{":status": {"200"}})
reqStream.wantData(body)
reqStream.wantClosed("request is complete")
})
}
func TestServerExpect100ContinueRejected(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
rejectBody := []byte("not allowed")
ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(403)
w.Write(rejectBody)
}))
tc := ts.connect()
tc.greet()
// Client sends an Expect: 100-continue request.
reqStream := tc.newStream(streamTypeRequest)
reqStream.writeHeaders(requestHeader(http.Header{
"Expect": {"100-continue"},
}))
// Server rejects it.
synctest.Wait()
reqStream.wantHeaders(http.Header{":status": {"403"}})
reqStream.wantData(rejectBody)
reqStream.wantClosed("request is complete")
})
}
type testServer struct {
t testing.TB
s *Server

View File

@@ -158,6 +158,19 @@ func newTestQUICStream(t testing.TB, st *stream) *testQUICStream {
}
}
func (ts *testQUICStream) wantIdle(reason string) {
ts.t.Helper()
synctest.Wait()
qs := ts.stream.stream
ctx, cancel := context.WithCancel(context.Background())
cancel()
qs.SetReadContext(ctx)
if _, err := qs.Read(make([]byte, 1)); !errors.Is(err, context.Canceled) {
ts.t.Fatalf("%v: want stream to be idle, but stream has content", reason)
}
qs.SetReadContext(nil)
}
// wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type.
func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) {
ts.t.Helper()