http2/h2c: Respect the req.Context()

When using h2c.NewHandler, the *http.Request object for h2c requests has a
.Context() that doesn't inherit from the *http.Server's BaseContext.  This
is surprising for users of vanilla net/http, and is surprising to users of
http2.ConfigureServer; HTTP/1 requests inherit from that BaseContext, and
TLS h2 requests inherit from that BaseContext, but cleartext h2c requests
don't.

So, modify h2c.NewHander to respect that base Context, by way of the
hijacked Request's .Context().
This commit is contained in:
Luke Shumaker
2020-12-15 15:41:26 -07:00
parent 5f4716e947
commit 821b2070f7
2 changed files with 56 additions and 2 deletions

View File

@@ -84,14 +84,20 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()
s.s.ServeConn(conn, &http2.ServeConnOpts{Handler: s.Handler})
s.s.ServeConn(conn, &http2.ServeConnOpts{
Context: r.Context(),
Handler: s.Handler,
})
return
}
// Handle Upgrade to h2c (RFC 7540 Section 3.2)
if conn, err := h2cUpgrade(w, r); err == nil {
defer conn.Close()
s.s.ServeConn(conn, &http2.ServeConnOpts{Handler: s.Handler})
s.s.ServeConn(conn, &http2.ServeConnOpts{
Context: r.Context(),
Handler: s.Handler,
})
return
}

View File

@@ -7,9 +7,14 @@ package h2c
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"testing"
"golang.org/x/net/http2"
@@ -56,3 +61,46 @@ func ExampleNewHandler() {
}
log.Fatal(h1s.ListenAndServe())
}
func TestContext(t *testing.T) {
baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ProtoMajor != 2 {
t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
}
if r.Context().Value("testkey") != "testvalue" {
t.Errorf("Request doesn't have expected base context: %v", r.Context())
}
fmt.Fprint(w, "Hello world")
})
h2s := &http2.Server{}
h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
h1s.Config.BaseContext = func(_ net.Listener) context.Context {
return baseCtx
}
h1s.Start()
defer h1s.Close()
client := &http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
return net.Dial(network, addr)
},
},
}
resp, err := client.Get(h1s.URL)
if err != nil {
t.Fatal(err)
}
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if err := resp.Body.Close(); err != nil {
t.Fatal(err)
}
}