diff --git a/internal/http3/body.go b/internal/http3/body.go index fc758bd9..df06432c 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -57,6 +57,10 @@ type bodyReader struct { mu sync.Mutex remain int64 err error + // If not nil, the body contains an "Expect: 100-continue" header, and + // send100Continue should be called when Read is invoked for the first + // time. + send100Continue func() } func (r *bodyReader) Read(p []byte) (n int, err error) { @@ -65,6 +69,10 @@ func (r *bodyReader) Read(p []byte) (n int, err error) { // Use a mutex here to provide the same behavior. r.mu.Lock() defer r.mu.Unlock() + if r.send100Continue != nil { + r.send100Continue() + r.send100Continue = nil + } if r.err != nil { return 0, r.err } diff --git a/internal/http3/server.go b/internal/http3/server.go index a6976c08..5285d4c7 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -7,11 +7,11 @@ package http3 import ( "context" "net/http" - "net/url" "strconv" "sync" "golang.org/x/net/http/httpguts" + "golang.org/x/net/internal/httpcommon" "golang.org/x/net/quic" ) @@ -157,54 +157,81 @@ func (sc *serverConn) handlePushStream(*stream) error { } } -func (sc *serverConn) parseRequest(st *stream) (*http.Request, error) { - req := &http.Request{ - URL: &url.URL{}, - Proto: "HTTP/3.0", - ProtoMajor: 3, - RemoteAddr: sc.qconn.RemoteAddr().String(), - } +type pseudoHeader struct { + method string + scheme string + path string + authority string +} + +func (sc *serverConn) parseHeader(st *stream) (http.Header, pseudoHeader, error) { ftype, err := st.readFrameHeader() if err != nil { - return nil, err + return nil, pseudoHeader{}, err } if ftype != frameTypeHeaders { - return nil, err + return nil, pseudoHeader{}, err } - req.Header = make(http.Header) + header := make(http.Header) + var pHeader pseudoHeader var dec qpackDecoder if err := dec.decode(st, func(_ indexType, name, value string) error { switch name { case ":method": - req.Method = value + pHeader.method = value case ":scheme": - req.URL.Scheme = value + pHeader.scheme = value case ":path": - req.URL.Path = value + pHeader.path = value case ":authority": - req.URL.Host = value + pHeader.authority = value default: - req.Header.Add(name, value) + header.Add(name, value) } return nil }); err != nil { - return nil, err + return nil, pseudoHeader{}, err } if err := st.endFrame(); err != nil { - return nil, err + return nil, pseudoHeader{}, err } - req.Body = &bodyReader{ - st: st, - remain: -1, - } - return req, nil + return header, pHeader, nil } func (sc *serverConn) handleRequestStream(st *stream) error { - req, err := sc.parseRequest(st) + header, pHeader, err := sc.parseHeader(st) if err != nil { return err } + + reqInfo := httpcommon.NewServerRequest(httpcommon.ServerRequestParam{ + Method: pHeader.method, + Scheme: pHeader.scheme, + Authority: pHeader.authority, + Path: pHeader.path, + Header: header, + }) + if reqInfo.InvalidReason != "" { + return &streamError{ + code: errH3MessageError, + message: reqInfo.InvalidReason, + } + } + req := &http.Request{ + Proto: "HTTP/3.0", + Method: pHeader.method, + Host: pHeader.authority, + URL: reqInfo.URL, + RequestURI: reqInfo.RequestURI, + Trailer: reqInfo.Trailer, + ProtoMajor: 3, + RemoteAddr: sc.qconn.RemoteAddr().String(), + Body: &bodyReader{ + st: st, + remain: -1, + }, + Header: header, + } defer req.Body.Close() rw := &responseWriter{ @@ -219,6 +246,12 @@ func (sc *serverConn) handleRequestStream(st *stream) error { }, } defer rw.close() + if reqInfo.NeedsContinue { + req.Body.(*bodyReader).send100Continue = func() { + rw.WriteHeader(http.StatusContinue) + rw.Flush() + } + } // TODO: handle panic coming from the HTTP handler. sc.handler.ServeHTTP(rw, req) @@ -238,11 +271,10 @@ func (sc *serverConn) abort(err error) { } type responseWriter struct { - st *stream - bw *bodyWriter - mu sync.Mutex - headers http.Header - // TODO: support 1xx status + st *stream + bw *bodyWriter + mu sync.Mutex + headers http.Header wroteHeader bool // Non-1xx header has been (logically) written. isHeadResp bool // response is for a HEAD request. } @@ -278,7 +310,9 @@ func (rw *responseWriter) writeHeaderLockedOnce(statusCode int) { rw.st.writeVarint(int64(frameTypeHeaders)) rw.st.writeVarint(int64(len(encHeaders))) rw.st.Write(encHeaders) - rw.wroteHeader = true + if statusCode >= http.StatusOK { + rw.wroteHeader = true + } } func (rw *responseWriter) WriteHeader(statusCode int) { diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index b7f195da..2b6de06d 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go @@ -8,8 +8,11 @@ package http3 import ( "io" + "maps" "net/http" "net/netip" + "net/url" + "reflect" "testing" "testing/synctest" "time" @@ -18,6 +21,22 @@ import ( "golang.org/x/net/quic" ) +// requestHeader is a helper function to make sure that all required +// pseudo-headers exist in an http.Header used for a request. Per +// https://www.rfc-editor.org/rfc/rfc9114.html#name-request-pseudo-header-field: +// "All HTTP/3 requests MUST include exactly one value for the :method, +// :scheme, and :path pseudo-header fields, unless the request is a CONNECT +// request;" +func requestHeader(h http.Header) http.Header { + minimalHeader := http.Header{ + ":method": {"GET"}, + ":scheme": {"https"}, + ":path": {"/"}, + } + maps.Copy(minimalHeader, h) + return minimalHeader +} + func TestServerReceivePushStream(t *testing.T) { // "[...] if a server receives a client-initiated push stream, // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR." @@ -61,9 +80,9 @@ func TestServerHeader(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{ + reqStream.writeHeaders(requestHeader(http.Header{ "header-from-client": {"that", "should", "be", "echoed"}, - }) + })) synctest.Wait() reqStream.wantHeaders(http.Header{ ":status": {"204"}, @@ -78,9 +97,23 @@ func TestServerPseudoHeader(t *testing.T) { ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Pseudo-headers from client request should populate a specific // field in http.Request, and should not be part of http.Request.Header. - if r.Header.Get(":method") != "" || r.Method != "GET" { - t.Error("want pseudo-headers from client to be reflected in appropriate fields in http.Request, not in http.Request.Header") + if len(r.Header) != 0 { + t.Errorf("got %v, want request header to be empty", r.Header) } + if r.Method != "GET" { + t.Errorf("got %v, want GET method", r.Method) + } + if r.Host != "fake.tld:1234" { + t.Errorf("got %v, want fake.tld:1234", r.Host) + } + wantURL := &url.URL{ + Path: "/some/path", + RawQuery: "query=value&query2=value2#fragment", + } + if !reflect.DeepEqual(r.URL, wantURL) { + t.Errorf("got %v, want URL to be %v", r.URL, wantURL) + } + // Conversely, server should not be able to set pseudo-headers by // writing to the ResponseWriter's Header. header := w.Header() @@ -91,10 +124,20 @@ func TestServerPseudoHeader(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{":method": {"GET"}}) + reqStream.writeHeaders(http.Header{ + ":method": {"GET"}, + ":authority": {"fake.tld:1234"}, + ":scheme": {"https"}, + ":path": {"/some/path?query=value&query2=value2#fragment"}, + }) synctest.Wait() reqStream.wantHeaders(http.Header{":status": {"321"}}) reqStream.wantClosed("request is complete") + + reqStream = tc.newStream(streamTypeRequest) + reqStream.writeHeaders(http.Header{}) // Missing pseudo-header. + synctest.Wait() + reqStream.wantError(quic.StreamErrorCode(errH3MessageError)) }) } @@ -112,7 +155,7 @@ func TestServerInvalidHeader(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{}) + reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() reqStream.wantHeaders(http.Header{ ":status": {"200"}, @@ -137,9 +180,7 @@ func TestServerBody(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{ - ":path": {"/"}, - }) + reqStream.writeHeaders(requestHeader(nil)) bodyContent := []byte("some body content that should be echoed") reqStream.writeData(bodyContent) reqStream.stream.stream.CloseWrite() @@ -161,14 +202,14 @@ func TestServerHeadResponseNoBody(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{":method": {http.MethodGet}}) + reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() reqStream.wantHeaders(http.Header{":status": {"200"}}) reqStream.wantData(bodyContent) reqStream.wantClosed("request is complete") reqStream = tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{":method": {http.MethodHead}}) + reqStream.writeHeaders(requestHeader(http.Header{":method": {http.MethodHead}})) synctest.Wait() reqStream.wantHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") @@ -184,7 +225,7 @@ func TestServerHandlerEmpty(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{":method": {http.MethodGet}}) + reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() reqStream.wantHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") @@ -208,7 +249,7 @@ func TestServerHandlerFlushing(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{":method": {http.MethodGet}}) + reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() respBody := make([]byte, 100) @@ -216,7 +257,7 @@ func TestServerHandlerFlushing(t *testing.T) { time.Sleep(time.Second) synctest.Wait() if n, err := reqStream.Read(respBody); err == nil { - t.Errorf("want no message yet, got %v bytes read", n) + t.Errorf("got %v bytes read, want no message yet", n) } time.Sleep(time.Second) @@ -228,7 +269,7 @@ func TestServerHandlerFlushing(t *testing.T) { time.Sleep(time.Second) synctest.Wait() if _, err := reqStream.Read(respBody); err != io.EOF { - t.Errorf("expected EOF, got err: %v", err) + t.Errorf("got err %v, want EOF", err) } reqStream.wantClosed("request is complete") }) @@ -250,7 +291,7 @@ func TestServerHandlerStreaming(t *testing.T) { tc.greet() reqStream := tc.newStream(streamTypeRequest) - reqStream.writeHeaders(http.Header{":method": {http.MethodGet}}) + reqStream.writeHeaders(requestHeader(nil)) synctest.Wait() reqStream.wantHeaders(http.Header{":status": {"200"}}) @@ -263,6 +304,75 @@ func TestServerHandlerStreaming(t *testing.T) { }) } +func TestServerExpect100Continue(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + streamIdle := make(chan bool) + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Expect: 100-continue header should not be accessible from the + // server handler. + if len(r.Header) > 0 { + t.Errorf("got %v, want request header to be empty", r.Header) + } + // Reading the body will cause the server to call w.WriteHeader(100). + <-streamIdle + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + // Implicitly calls w.WriteHeader(200) since non-1XX status code + // has been sent yet so far. + w.Write(body) + })) + tc := ts.connect() + tc.greet() + + // Client sends an Expect: 100-continue request. + reqStream := tc.newStream(streamTypeRequest) + reqStream.writeHeaders(requestHeader(http.Header{ + "Expect": {"100-continue"}, + })) + + reqStream.wantIdle("stream is idle until server sends an HTTP 100 status") + streamIdle <- true + // Wait until server responds with HTTP status 100 before sending the + // body. + synctest.Wait() + reqStream.wantHeaders(http.Header{":status": {"100"}}) + body := []byte("body that will be echoed back if we get status 100") + reqStream.writeData(body) + reqStream.stream.stream.CloseWrite() + + // Receive the server's response after sending the body. + reqStream.wantHeaders(http.Header{":status": {"200"}}) + reqStream.wantData(body) + reqStream.wantClosed("request is complete") + }) +} + +func TestServerExpect100ContinueRejected(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + rejectBody := []byte("not allowed") + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(403) + w.Write(rejectBody) + })) + tc := ts.connect() + tc.greet() + + // Client sends an Expect: 100-continue request. + reqStream := tc.newStream(streamTypeRequest) + reqStream.writeHeaders(requestHeader(http.Header{ + "Expect": {"100-continue"}, + })) + + // Server rejects it. + synctest.Wait() + reqStream.wantHeaders(http.Header{":status": {"403"}}) + reqStream.wantData(rejectBody) + reqStream.wantClosed("request is complete") + }) +} + type testServer struct { t testing.TB s *Server diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go index 8d12ecfe..271cf262 100644 --- a/internal/http3/transport_test.go +++ b/internal/http3/transport_test.go @@ -158,6 +158,19 @@ func newTestQUICStream(t testing.TB, st *stream) *testQUICStream { } } +func (ts *testQUICStream) wantIdle(reason string) { + ts.t.Helper() + synctest.Wait() + qs := ts.stream.stream + ctx, cancel := context.WithCancel(context.Background()) + cancel() + qs.SetReadContext(ctx) + if _, err := qs.Read(make([]byte, 1)); !errors.Is(err, context.Canceled) { + ts.t.Fatalf("%v: want stream to be idle, but stream has content", reason) + } + qs.SetReadContext(nil) +} + // wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type. func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) { ts.t.Helper()