mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
http2/h2c: Respect the req.Context()
When using h2c.NewHandler, the *http.Request object for h2c requests has a .Context() that doesn't inherit from the *http.Server's BaseContext. This is surprising for users of vanilla net/http, and is surprising to users of http2.ConfigureServer; HTTP/1 requests inherit from that BaseContext, and TLS h2 requests inherit from that BaseContext, but cleartext h2c requests don't. So, modify h2c.NewHander to respect that base Context, by way of the hijacked Request's .Context().
This commit is contained in:
@@ -84,14 +84,20 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
s.s.ServeConn(conn, &http2.ServeConnOpts{Handler: s.Handler})
|
||||
s.s.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Context: r.Context(),
|
||||
Handler: s.Handler,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Handle Upgrade to h2c (RFC 7540 Section 3.2)
|
||||
if conn, err := h2cUpgrade(w, r); err == nil {
|
||||
defer conn.Close()
|
||||
|
||||
s.s.ServeConn(conn, &http2.ServeConnOpts{Handler: s.Handler})
|
||||
s.s.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Context: r.Context(),
|
||||
Handler: s.Handler,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -7,9 +7,14 @@ package h2c
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
@@ -56,3 +61,46 @@ func ExampleNewHandler() {
|
||||
}
|
||||
log.Fatal(h1s.ListenAndServe())
|
||||
}
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ProtoMajor != 2 {
|
||||
t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
|
||||
}
|
||||
if r.Context().Value("testkey") != "testvalue" {
|
||||
t.Errorf("Request doesn't have expected base context: %v", r.Context())
|
||||
}
|
||||
fmt.Fprint(w, "Hello world")
|
||||
})
|
||||
|
||||
h2s := &http2.Server{}
|
||||
h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
|
||||
h1s.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||
return baseCtx
|
||||
}
|
||||
h1s.Start()
|
||||
defer h1s.Close()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http2.Transport{
|
||||
AllowHTTP: true,
|
||||
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||
return net.Dial(network, addr)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(h1s.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user