mirror of
https://github.com/golang/net.git
synced 2026-03-31 10:27:08 +09:00
internal/http3: add HTTP 103 Early Hints support to ClientConn
RoundTrip will now call httptrace.ClientTrace.Got1xxResponse, if any, when receiving 1xx status response from a peer. This allows our client and server to use HTTP 103 end-to-end. Got100Continue and Wait100Continue have also been added to RoundTrip as they are nearby. The rest of httptrace.ClientTrace will be added in the future. For golang/go#70914 Change-Id: Ia7ef7dd026a5390225149da3d76b06a2a372c009 Reviewed-on: https://go-review.googlesource.com/c/net/+/749265 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Nicholas Husin <husin@google.com>
This commit is contained in:
committed by
Nicholas Husin
parent
591bdf35bc
commit
6267c6c4c8
@@ -8,6 +8,8 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@@ -28,6 +30,8 @@ type roundTripState struct {
|
||||
// Response.Body, provided to the caller.
|
||||
respBody io.ReadCloser
|
||||
|
||||
trace *httptrace.ClientTrace
|
||||
|
||||
errOnce sync.Once
|
||||
err error
|
||||
}
|
||||
@@ -60,6 +64,28 @@ func (rt *roundTripState) closeReqBody() {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Set up the rest of the hooks that might be in rt.trace.
|
||||
func (rt *roundTripState) maybeCallGot1xxResponse(status int, h http.Header) error {
|
||||
if rt.trace == nil || rt.trace.Got1xxResponse == nil {
|
||||
return nil
|
||||
}
|
||||
return rt.trace.Got1xxResponse(status, textproto.MIMEHeader(h))
|
||||
}
|
||||
|
||||
func (rt *roundTripState) maybeCallGot100Continue() {
|
||||
if rt.trace == nil || rt.trace.Got100Continue == nil {
|
||||
return
|
||||
}
|
||||
rt.trace.Got100Continue()
|
||||
}
|
||||
|
||||
func (rt *roundTripState) maybeCallWait100Continue() {
|
||||
if rt.trace == nil || rt.trace.Wait100Continue == nil {
|
||||
return
|
||||
}
|
||||
rt.trace.Wait100Continue()
|
||||
}
|
||||
|
||||
// RoundTrip sends a request on the connection.
|
||||
func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) {
|
||||
// Each request gets its own QUIC stream.
|
||||
@@ -68,8 +94,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error)
|
||||
return nil, err
|
||||
}
|
||||
rt := &roundTripState{
|
||||
cc: cc,
|
||||
st: st,
|
||||
cc: cc,
|
||||
st: st,
|
||||
trace: httptrace.ContextClientTrace(req.Context()),
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -113,7 +140,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error)
|
||||
|
||||
var bodyAndTrailerWritten bool
|
||||
is100ContinueReq := httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue")
|
||||
if !is100ContinueReq && !bodyAndTrailerWritten {
|
||||
if is100ContinueReq {
|
||||
rt.maybeCallWait100Continue()
|
||||
} else {
|
||||
bodyAndTrailerWritten = true
|
||||
go cc.writeBodyAndTrailer(rt, req)
|
||||
}
|
||||
@@ -131,10 +160,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if statusCode >= 100 && statusCode < 199 {
|
||||
// TODO: Handle 1xx responses.
|
||||
// TODO: Handle 1xx responses.
|
||||
if isInfoStatus(statusCode) {
|
||||
if err := rt.maybeCallGot1xxResponse(statusCode, h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch statusCode {
|
||||
case 100:
|
||||
rt.maybeCallGot100Continue()
|
||||
if is100ContinueReq && !bodyAndTrailerWritten {
|
||||
bodyAndTrailerWritten = true
|
||||
go cc.writeBodyAndTrailer(rt, req)
|
||||
|
||||
@@ -9,6 +9,10 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
@@ -354,13 +358,27 @@ func TestRoundTripRequestBodyErrorAfterHeaders(t *testing.T) {
|
||||
|
||||
func TestRoundTripExpect100Continue(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
var callCount1xx, callCount100, callCount100Wait int
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
callCount1xx++
|
||||
return nil
|
||||
},
|
||||
Got100Continue: func() {
|
||||
callCount100++
|
||||
},
|
||||
Wait100Continue: func() {
|
||||
callCount100Wait++
|
||||
},
|
||||
}
|
||||
|
||||
tc := newTestClientConn(t)
|
||||
tc.greet()
|
||||
clientBody := []byte("client's body that will be sent later")
|
||||
serverBody := []byte("server's body")
|
||||
|
||||
// Client sends an Expect: 100-continue request.
|
||||
req, _ := http.NewRequest("PUT", "https://example.tld/", bytes.NewBuffer(clientBody))
|
||||
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(t.Context(), trace), "GET", "https://example.tld/", bytes.NewBuffer(clientBody))
|
||||
req.Header = http.Header{"Expect": {"100-continue"}}
|
||||
rt := tc.roundTrip(req)
|
||||
st := tc.wantStream(streamTypeRequest)
|
||||
@@ -387,16 +405,35 @@ func TestRoundTripExpect100Continue(t *testing.T) {
|
||||
// Client receives the response from server.
|
||||
rt.wantStatus(200)
|
||||
rt.wantBody(serverBody)
|
||||
|
||||
gotCount := []int{callCount1xx, callCount100, callCount100Wait}
|
||||
if !slices.Equal(gotCount, []int{1, 1, 1}) {
|
||||
t.Errorf("Got1xxResponse, Got100Continue, and Wait100Continue was called %v times respectively, want [1 1 1]", gotCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoundTripExpect100ContinueRejected(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
var callCount1xx, callCount100, callCount100Wait int
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
callCount1xx++
|
||||
return nil
|
||||
},
|
||||
Got100Continue: func() {
|
||||
callCount100++
|
||||
},
|
||||
Wait100Continue: func() {
|
||||
callCount100Wait++
|
||||
},
|
||||
}
|
||||
|
||||
tc := newTestClientConn(t)
|
||||
tc.greet()
|
||||
|
||||
// Client sends an Expect: 100-continue request.
|
||||
req, _ := http.NewRequest("PUT", "https://example.tld/", bytes.NewBufferString("client's body"))
|
||||
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(t.Context(), trace), "GET", "https://example.tld/", bytes.NewBufferString("client's body"))
|
||||
req.Header = http.Header{"Expect": {"100-continue"}}
|
||||
rt := tc.roundTrip(req)
|
||||
st := tc.wantStream(streamTypeRequest)
|
||||
@@ -416,6 +453,11 @@ func TestRoundTripExpect100ContinueRejected(t *testing.T) {
|
||||
|
||||
rt.wantStatus(200)
|
||||
rt.wantBody(serverBody)
|
||||
|
||||
gotCount := []int{callCount1xx, callCount100, callCount100Wait}
|
||||
if !slices.Equal(gotCount, []int{0, 0, 1}) {
|
||||
t.Errorf("Got1xxResponse, Got100Continue, and Wait100Continue was called %v times respectively, want [0 0 1]", gotCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -633,3 +675,58 @@ func TestRoundTripReadTrailerNoBody(t *testing.T) {
|
||||
st.wantClosed("request is complete")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoundTrip103EarlyHints(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
firstHeader := http.Header{
|
||||
":status": {"103"},
|
||||
"Link": {"</style.css>; rel=preload; as=style"},
|
||||
}
|
||||
secondHeader := http.Header{
|
||||
":status": {"103"},
|
||||
"Link": {"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"},
|
||||
}
|
||||
|
||||
var respCounter int
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
var wantHeader textproto.MIMEHeader
|
||||
switch respCounter {
|
||||
case 0:
|
||||
wantHeader = textproto.MIMEHeader(firstHeader)
|
||||
case 1:
|
||||
wantHeader = textproto.MIMEHeader(secondHeader)
|
||||
default:
|
||||
t.Error("Unexpected 1xx response")
|
||||
}
|
||||
wantHeader.Del(":status")
|
||||
if !reflect.DeepEqual(header, wantHeader) {
|
||||
t.Errorf("got %v early hints header, want %v", header, wantHeader)
|
||||
}
|
||||
respCounter++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(t.Context(), trace), "GET", "https://example.tld/", nil)
|
||||
|
||||
tc := newTestClientConn(t)
|
||||
tc.greet()
|
||||
rt := tc.roundTrip(req)
|
||||
st := tc.wantStream(streamTypeRequest)
|
||||
|
||||
st.wantHeaders(nil)
|
||||
st.writeHeaders(firstHeader)
|
||||
st.writeHeaders(secondHeader)
|
||||
|
||||
st.writeHeaders(http.Header{
|
||||
":status": {"200"},
|
||||
})
|
||||
body := []byte("some body")
|
||||
st.writeData(body)
|
||||
st.stream.stream.CloseWrite()
|
||||
|
||||
rt.wantStatus(200)
|
||||
rt.wantBody(body)
|
||||
st.wantClosed("request is complete")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user