mirror of
https://github.com/golang/net.git
synced 2026-03-31 10:27:08 +09:00
internal/http3: add Expect: 100-continue support to Server
When serving a request containing the "Expect: 100-continue" header, Server will now send an HTTP 100 status automatically if the request body is read from within the server handler. For golang/go#70914 Change-Id: Ib8185170deabf777a02487a1ded6671db720df51 Reviewed-on: https://go-review.googlesource.com/c/net/+/742520 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Nicholas Husin <husin@google.com>
This commit is contained in:
committed by
Nicholas Husin
parent
af0c9df79d
commit
73fe7011ad
@@ -57,6 +57,10 @@ type bodyReader struct {
|
||||
mu sync.Mutex
|
||||
remain int64
|
||||
err error
|
||||
// If not nil, the body contains an "Expect: 100-continue" header, and
|
||||
// send100Continue should be called when Read is invoked for the first
|
||||
// time.
|
||||
send100Continue func()
|
||||
}
|
||||
|
||||
func (r *bodyReader) Read(p []byte) (n int, err error) {
|
||||
@@ -65,6 +69,10 @@ func (r *bodyReader) Read(p []byte) (n int, err error) {
|
||||
// Use a mutex here to provide the same behavior.
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.send100Continue != nil {
|
||||
r.send100Continue()
|
||||
r.send100Continue = nil
|
||||
}
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
@@ -7,11 +7,11 @@ package http3
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/internal/httpcommon"
|
||||
"golang.org/x/net/quic"
|
||||
)
|
||||
|
||||
@@ -157,54 +157,81 @@ func (sc *serverConn) handlePushStream(*stream) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *serverConn) parseRequest(st *stream) (*http.Request, error) {
|
||||
req := &http.Request{
|
||||
URL: &url.URL{},
|
||||
Proto: "HTTP/3.0",
|
||||
ProtoMajor: 3,
|
||||
RemoteAddr: sc.qconn.RemoteAddr().String(),
|
||||
}
|
||||
type pseudoHeader struct {
|
||||
method string
|
||||
scheme string
|
||||
path string
|
||||
authority string
|
||||
}
|
||||
|
||||
func (sc *serverConn) parseHeader(st *stream) (http.Header, pseudoHeader, error) {
|
||||
ftype, err := st.readFrameHeader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, pseudoHeader{}, err
|
||||
}
|
||||
if ftype != frameTypeHeaders {
|
||||
return nil, err
|
||||
return nil, pseudoHeader{}, err
|
||||
}
|
||||
req.Header = make(http.Header)
|
||||
header := make(http.Header)
|
||||
var pHeader pseudoHeader
|
||||
var dec qpackDecoder
|
||||
if err := dec.decode(st, func(_ indexType, name, value string) error {
|
||||
switch name {
|
||||
case ":method":
|
||||
req.Method = value
|
||||
pHeader.method = value
|
||||
case ":scheme":
|
||||
req.URL.Scheme = value
|
||||
pHeader.scheme = value
|
||||
case ":path":
|
||||
req.URL.Path = value
|
||||
pHeader.path = value
|
||||
case ":authority":
|
||||
req.URL.Host = value
|
||||
pHeader.authority = value
|
||||
default:
|
||||
req.Header.Add(name, value)
|
||||
header.Add(name, value)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
return nil, pseudoHeader{}, err
|
||||
}
|
||||
if err := st.endFrame(); err != nil {
|
||||
return nil, err
|
||||
return nil, pseudoHeader{}, err
|
||||
}
|
||||
req.Body = &bodyReader{
|
||||
st: st,
|
||||
remain: -1,
|
||||
}
|
||||
return req, nil
|
||||
return header, pHeader, nil
|
||||
}
|
||||
|
||||
func (sc *serverConn) handleRequestStream(st *stream) error {
|
||||
req, err := sc.parseRequest(st)
|
||||
header, pHeader, err := sc.parseHeader(st)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqInfo := httpcommon.NewServerRequest(httpcommon.ServerRequestParam{
|
||||
Method: pHeader.method,
|
||||
Scheme: pHeader.scheme,
|
||||
Authority: pHeader.authority,
|
||||
Path: pHeader.path,
|
||||
Header: header,
|
||||
})
|
||||
if reqInfo.InvalidReason != "" {
|
||||
return &streamError{
|
||||
code: errH3MessageError,
|
||||
message: reqInfo.InvalidReason,
|
||||
}
|
||||
}
|
||||
req := &http.Request{
|
||||
Proto: "HTTP/3.0",
|
||||
Method: pHeader.method,
|
||||
Host: pHeader.authority,
|
||||
URL: reqInfo.URL,
|
||||
RequestURI: reqInfo.RequestURI,
|
||||
Trailer: reqInfo.Trailer,
|
||||
ProtoMajor: 3,
|
||||
RemoteAddr: sc.qconn.RemoteAddr().String(),
|
||||
Body: &bodyReader{
|
||||
st: st,
|
||||
remain: -1,
|
||||
},
|
||||
Header: header,
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
rw := &responseWriter{
|
||||
@@ -219,6 +246,12 @@ func (sc *serverConn) handleRequestStream(st *stream) error {
|
||||
},
|
||||
}
|
||||
defer rw.close()
|
||||
if reqInfo.NeedsContinue {
|
||||
req.Body.(*bodyReader).send100Continue = func() {
|
||||
rw.WriteHeader(http.StatusContinue)
|
||||
rw.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle panic coming from the HTTP handler.
|
||||
sc.handler.ServeHTTP(rw, req)
|
||||
@@ -238,11 +271,10 @@ func (sc *serverConn) abort(err error) {
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
st *stream
|
||||
bw *bodyWriter
|
||||
mu sync.Mutex
|
||||
headers http.Header
|
||||
// TODO: support 1xx status
|
||||
st *stream
|
||||
bw *bodyWriter
|
||||
mu sync.Mutex
|
||||
headers http.Header
|
||||
wroteHeader bool // Non-1xx header has been (logically) written.
|
||||
isHeadResp bool // response is for a HEAD request.
|
||||
}
|
||||
@@ -278,7 +310,9 @@ func (rw *responseWriter) writeHeaderLockedOnce(statusCode int) {
|
||||
rw.st.writeVarint(int64(frameTypeHeaders))
|
||||
rw.st.writeVarint(int64(len(encHeaders)))
|
||||
rw.st.Write(encHeaders)
|
||||
rw.wroteHeader = true
|
||||
if statusCode >= http.StatusOK {
|
||||
rw.wroteHeader = true
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
|
||||
@@ -8,8 +8,11 @@ package http3
|
||||
|
||||
import (
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
@@ -18,6 +21,22 @@ import (
|
||||
"golang.org/x/net/quic"
|
||||
)
|
||||
|
||||
// requestHeader is a helper function to make sure that all required
|
||||
// pseudo-headers exist in an http.Header used for a request. Per
|
||||
// https://www.rfc-editor.org/rfc/rfc9114.html#name-request-pseudo-header-field:
|
||||
// "All HTTP/3 requests MUST include exactly one value for the :method,
|
||||
// :scheme, and :path pseudo-header fields, unless the request is a CONNECT
|
||||
// request;"
|
||||
func requestHeader(h http.Header) http.Header {
|
||||
minimalHeader := http.Header{
|
||||
":method": {"GET"},
|
||||
":scheme": {"https"},
|
||||
":path": {"/"},
|
||||
}
|
||||
maps.Copy(minimalHeader, h)
|
||||
return minimalHeader
|
||||
}
|
||||
|
||||
func TestServerReceivePushStream(t *testing.T) {
|
||||
// "[...] if a server receives a client-initiated push stream,
|
||||
// this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
|
||||
@@ -61,9 +80,9 @@ func TestServerHeader(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{
|
||||
reqStream.writeHeaders(requestHeader(http.Header{
|
||||
"header-from-client": {"that", "should", "be", "echoed"},
|
||||
})
|
||||
}))
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{
|
||||
":status": {"204"},
|
||||
@@ -78,9 +97,23 @@ func TestServerPseudoHeader(t *testing.T) {
|
||||
ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Pseudo-headers from client request should populate a specific
|
||||
// field in http.Request, and should not be part of http.Request.Header.
|
||||
if r.Header.Get(":method") != "" || r.Method != "GET" {
|
||||
t.Error("want pseudo-headers from client to be reflected in appropriate fields in http.Request, not in http.Request.Header")
|
||||
if len(r.Header) != 0 {
|
||||
t.Errorf("got %v, want request header to be empty", r.Header)
|
||||
}
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("got %v, want GET method", r.Method)
|
||||
}
|
||||
if r.Host != "fake.tld:1234" {
|
||||
t.Errorf("got %v, want fake.tld:1234", r.Host)
|
||||
}
|
||||
wantURL := &url.URL{
|
||||
Path: "/some/path",
|
||||
RawQuery: "query=value&query2=value2#fragment",
|
||||
}
|
||||
if !reflect.DeepEqual(r.URL, wantURL) {
|
||||
t.Errorf("got %v, want URL to be %v", r.URL, wantURL)
|
||||
}
|
||||
|
||||
// Conversely, server should not be able to set pseudo-headers by
|
||||
// writing to the ResponseWriter's Header.
|
||||
header := w.Header()
|
||||
@@ -91,10 +124,20 @@ func TestServerPseudoHeader(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{":method": {"GET"}})
|
||||
reqStream.writeHeaders(http.Header{
|
||||
":method": {"GET"},
|
||||
":authority": {"fake.tld:1234"},
|
||||
":scheme": {"https"},
|
||||
":path": {"/some/path?query=value&query2=value2#fragment"},
|
||||
})
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"321"}})
|
||||
reqStream.wantClosed("request is complete")
|
||||
|
||||
reqStream = tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{}) // Missing pseudo-header.
|
||||
synctest.Wait()
|
||||
reqStream.wantError(quic.StreamErrorCode(errH3MessageError))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -112,7 +155,7 @@ func TestServerInvalidHeader(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{})
|
||||
reqStream.writeHeaders(requestHeader(nil))
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{
|
||||
":status": {"200"},
|
||||
@@ -137,9 +180,7 @@ func TestServerBody(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{
|
||||
":path": {"/"},
|
||||
})
|
||||
reqStream.writeHeaders(requestHeader(nil))
|
||||
bodyContent := []byte("some body content that should be echoed")
|
||||
reqStream.writeData(bodyContent)
|
||||
reqStream.stream.stream.CloseWrite()
|
||||
@@ -161,14 +202,14 @@ func TestServerHeadResponseNoBody(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
|
||||
reqStream.writeHeaders(requestHeader(nil))
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"200"}})
|
||||
reqStream.wantData(bodyContent)
|
||||
reqStream.wantClosed("request is complete")
|
||||
|
||||
reqStream = tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{":method": {http.MethodHead}})
|
||||
reqStream.writeHeaders(requestHeader(http.Header{":method": {http.MethodHead}}))
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"200"}})
|
||||
reqStream.wantClosed("request is complete")
|
||||
@@ -184,7 +225,7 @@ func TestServerHandlerEmpty(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
|
||||
reqStream.writeHeaders(requestHeader(nil))
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"200"}})
|
||||
reqStream.wantClosed("request is complete")
|
||||
@@ -208,7 +249,7 @@ func TestServerHandlerFlushing(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
|
||||
reqStream.writeHeaders(requestHeader(nil))
|
||||
synctest.Wait()
|
||||
|
||||
respBody := make([]byte, 100)
|
||||
@@ -216,7 +257,7 @@ func TestServerHandlerFlushing(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
synctest.Wait()
|
||||
if n, err := reqStream.Read(respBody); err == nil {
|
||||
t.Errorf("want no message yet, got %v bytes read", n)
|
||||
t.Errorf("got %v bytes read, want no message yet", n)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -228,7 +269,7 @@ func TestServerHandlerFlushing(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
synctest.Wait()
|
||||
if _, err := reqStream.Read(respBody); err != io.EOF {
|
||||
t.Errorf("expected EOF, got err: %v", err)
|
||||
t.Errorf("got err %v, want EOF", err)
|
||||
}
|
||||
reqStream.wantClosed("request is complete")
|
||||
})
|
||||
@@ -250,7 +291,7 @@ func TestServerHandlerStreaming(t *testing.T) {
|
||||
tc.greet()
|
||||
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(http.Header{":method": {http.MethodGet}})
|
||||
reqStream.writeHeaders(requestHeader(nil))
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"200"}})
|
||||
|
||||
@@ -263,6 +304,75 @@ func TestServerHandlerStreaming(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerExpect100Continue(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
streamIdle := make(chan bool)
|
||||
ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Expect: 100-continue header should not be accessible from the
|
||||
// server handler.
|
||||
if len(r.Header) > 0 {
|
||||
t.Errorf("got %v, want request header to be empty", r.Header)
|
||||
}
|
||||
// Reading the body will cause the server to call w.WriteHeader(100).
|
||||
<-streamIdle
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Implicitly calls w.WriteHeader(200) since non-1XX status code
|
||||
// has been sent yet so far.
|
||||
w.Write(body)
|
||||
}))
|
||||
tc := ts.connect()
|
||||
tc.greet()
|
||||
|
||||
// Client sends an Expect: 100-continue request.
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(requestHeader(http.Header{
|
||||
"Expect": {"100-continue"},
|
||||
}))
|
||||
|
||||
reqStream.wantIdle("stream is idle until server sends an HTTP 100 status")
|
||||
streamIdle <- true
|
||||
// Wait until server responds with HTTP status 100 before sending the
|
||||
// body.
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"100"}})
|
||||
body := []byte("body that will be echoed back if we get status 100")
|
||||
reqStream.writeData(body)
|
||||
reqStream.stream.stream.CloseWrite()
|
||||
|
||||
// Receive the server's response after sending the body.
|
||||
reqStream.wantHeaders(http.Header{":status": {"200"}})
|
||||
reqStream.wantData(body)
|
||||
reqStream.wantClosed("request is complete")
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerExpect100ContinueRejected(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
rejectBody := []byte("not allowed")
|
||||
ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(403)
|
||||
w.Write(rejectBody)
|
||||
}))
|
||||
tc := ts.connect()
|
||||
tc.greet()
|
||||
|
||||
// Client sends an Expect: 100-continue request.
|
||||
reqStream := tc.newStream(streamTypeRequest)
|
||||
reqStream.writeHeaders(requestHeader(http.Header{
|
||||
"Expect": {"100-continue"},
|
||||
}))
|
||||
|
||||
// Server rejects it.
|
||||
synctest.Wait()
|
||||
reqStream.wantHeaders(http.Header{":status": {"403"}})
|
||||
reqStream.wantData(rejectBody)
|
||||
reqStream.wantClosed("request is complete")
|
||||
})
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
t testing.TB
|
||||
s *Server
|
||||
|
||||
@@ -158,6 +158,19 @@ func newTestQUICStream(t testing.TB, st *stream) *testQUICStream {
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *testQUICStream) wantIdle(reason string) {
|
||||
ts.t.Helper()
|
||||
synctest.Wait()
|
||||
qs := ts.stream.stream
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
qs.SetReadContext(ctx)
|
||||
if _, err := qs.Read(make([]byte, 1)); !errors.Is(err, context.Canceled) {
|
||||
ts.t.Fatalf("%v: want stream to be idle, but stream has content", reason)
|
||||
}
|
||||
qs.SetReadContext(nil)
|
||||
}
|
||||
|
||||
// wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type.
|
||||
func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) {
|
||||
ts.t.Helper()
|
||||
|
||||
Reference in New Issue
Block a user