http2: add server-side trailer support

Change-Id: I39dbf0cdeee0123b6c6efff1fc6854bcedb94753
Reviewed-on: https://go-review.googlesource.com/17878
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Blake Mizerany
2015-12-15 17:33:14 -08:00
committed by Brad Fitzpatrick
parent c24de9d546
commit b4be494138
3 changed files with 121 additions and 29 deletions

View File

@@ -46,6 +46,7 @@ import (
"log"
"net"
"net/http"
"net/textproto"
"net/url"
"runtime"
"strconv"
@@ -1877,6 +1878,7 @@ type responseWriterState struct {
// mutated by http.Handler goroutine:
handlerHeader http.Header // nil until called
snapHeader http.Header // snapshot of handlerHeader at WriteHeader time
trailers []string // set in writeChunk
status int // status code passed to WriteHeader
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame?
@@ -1893,6 +1895,21 @@ type chunkWriter struct{ rws *responseWriterState }
func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }
func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 }
// declareTrailer is called for each Trailer header when the
// response header is written. It notes that a header will need to be
// written in the trailers at the end of the response.
func (rws *responseWriterState) declareTrailer(k string) {
k = http.CanonicalHeaderKey(k)
switch k {
case "Transfer-Encoding", "Content-Length", "Trailer":
// Forbidden by RFC 2616 14.40.
return
}
rws.trailers = append(rws.trailers, k)
}
// writeChunk writes chunks from the bufio.Writer. But because
// bufio.Writer may bypass its chunking, sometimes p may be
// arbitrarily large.
@@ -1903,6 +1920,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
if !rws.wroteHeader {
rws.writeHeader(200)
}
isHeadResp := rws.req.Method == "HEAD"
if !rws.sentHeader {
rws.sentHeader = true
@@ -1928,7 +1946,12 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
// TODO(bradfitz): be faster here, like net/http? measure.
date = time.Now().UTC().Format(http.TimeFormat)
}
endStream := (rws.handlerDone && len(p) == 0) || isHeadResp
for _, v := range rws.snapHeader["Trailer"] {
foreachHeaderElement(v, rws.declareTrailer)
}
endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
streamID: rws.stream.id,
httpResCode: rws.status,
@@ -1952,8 +1975,22 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
return 0, nil
}
if err := rws.conn.writeDataFromHandler(rws.stream, p, rws.handlerDone); err != nil {
return 0, err
endStream := rws.handlerDone && !rws.hasTrailers()
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
return 0, err
}
}
if rws.handlerDone && rws.hasTrailers() {
err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
streamID: rws.stream.id,
h: rws.handlerHeader,
trailers: rws.trailers,
endStream: true,
})
return len(p), err
}
return len(p), nil
}
@@ -2083,3 +2120,21 @@ func (w *responseWriter) handlerDone() {
w.rws = nil
responseWriterStatePool.Put(rws)
}
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}

View File

@@ -2515,17 +2515,32 @@ func TestServerReadsTrailers(t *testing.T) {
}
// test that a server handler can send trailers
func TestServerWritesTrailers(t *testing.T) {
t.Skip("known failing test; see golang.org/issue/13557")
func TestServerWritesTrailers_WithFlush(t *testing.T) { testServerWritesTrailers(t, true) }
func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
func testServerWritesTrailers(t *testing.T, withFlush bool) {
// See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
w.Header().Add("Trailer", "Server-Trailer-C")
// TODO: decide if the server should filter these while
// writing the Trailer header in the response. Currently it
// appears net/http doesn't do this for http/1.1
w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered
w.Header().Set("Foo", "Bar")
w.Header().Set("Content-Length", "5")
io.WriteString(w, "Hello")
w.(http.Flusher).Flush()
if withFlush {
w.(http.Flusher).Flush()
}
w.Header().Set("Server-Trailer-A", "valuea")
w.Header().Set("Server-Trailer-C", "valuec") // skipping B
w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40")
w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40")
w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40")
return nil
}, func(st *serverTester) {
getSlash(st)
@@ -2542,7 +2557,9 @@ func TestServerWritesTrailers(t *testing.T) {
{"foo", "Bar"},
{"trailer", "Server-Trailer-A, Server-Trailer-B"},
{"trailer", "Server-Trailer-C"},
{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "5"},
}
if !reflect.DeepEqual(goth, wanth) {
t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
@@ -2561,8 +2578,14 @@ func TestServerWritesTrailers(t *testing.T) {
if !tf.HeadersEnded() {
t.Fatalf("trailers HEADERS lacked END_HEADERS")
}
pairs := st.decodeHeader(tf.HeaderBlockFragment())
t.Logf("Got: %v", pairs)
wanth = [][2]string{
{"server-trailer-a", "valuea"},
{"server-trailer-c", "valuec"},
}
goth = st.decodeHeader(tf.HeaderBlockFragment())
if !reflect.DeepEqual(goth, wanth) {
t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
}
})
}

View File

@@ -123,11 +123,12 @@ func (writeSettingsAck) writeFrame(ctx writeContext) error {
}
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
// for HTTP response headers from a server handler.
// for HTTP response headers or trailers from a server handler.
type writeResHeaders struct {
streamID uint32
httpResCode int
httpResCode int // 0 means no ":status" line
h http.Header // may be nil
trailers []string // if non-nil, which keys of h to write. nil means all.
endStream bool
date string
@@ -138,26 +139,16 @@ type writeResHeaders struct {
func (w *writeResHeaders) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)})
// TODO: garbage. pool sorters like http1? hot path for 1 key?
keys := make([]string, 0, len(w.h))
for k := range w.h {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vv := w.h[k]
k = lowerHeader(k)
isTE := k == "transfer-encoding"
for _, v := range vv {
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if isTE && v != "trailers" {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
if w.httpResCode != 0 {
enc.WriteField(hpack.HeaderField{
Name: ":status",
Value: httpCodeString(w.httpResCode),
})
}
encodeHeaders(enc, w.h, w.trailers)
if w.contentType != "" {
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType})
}
@@ -169,7 +160,7 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error {
}
headerBlock := buf.Bytes()
if len(headerBlock) == 0 {
if len(headerBlock) == 0 && w.trailers == nil {
panic("unexpected empty hpack")
}
@@ -232,3 +223,26 @@ type writeWindowUpdate struct {
func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
}
func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
// TODO: garbage. pool sorters like http1? hot path for 1 key?
if keys == nil {
keys = make([]string, 0, len(h))
for k := range h {
keys = append(keys, k)
}
sort.Strings(keys)
}
for _, k := range keys {
vv := h[k]
k = lowerHeader(k)
isTE := k == "transfer-encoding"
for _, v := range vv {
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if isTE && v != "trailers" {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
}
}