diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 0f57d3e3..e587421e 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -437,7 +437,7 @@ func (rt *testRoundTrip) wantStatus(want int) { } } -// body reads the contents of the response body. +// readBody reads the contents of the response body. func (rt *testRoundTrip) readBody() ([]byte, error) { t := rt.t t.Helper() diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index d52c8455..58253d0b 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -11,6 +11,7 @@ import ( "strconv" "sync" + "golang.org/x/net/http/httpguts" "golang.org/x/net/internal/httpcommon" ) @@ -113,14 +114,18 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) return nil, err } + is100ContinueReq := httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") if encr.HasBody { - // TODO: Defer sending the request body when "Expect: 100-continue" is set. rt.reqBody = req.Body rt.reqBodyWriter.st = st rt.reqBodyWriter.remain = contentLength rt.reqBodyWriter.flush = true rt.reqBodyWriter.name = "request" - go copyRequestBody(rt) + + if !is100ContinueReq { + encr.HasBody = false + go copyRequestBody(rt) + } } // Read the response headers. @@ -138,7 +143,19 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) if statusCode >= 100 && statusCode < 199 { // TODO: Handle 1xx responses. - continue + switch statusCode { + case 100: + if encr.HasBody && is100ContinueReq { + encr.HasBody = false + go copyRequestBody(rt) + continue + } + // If we did not send "Expect: 100-continue" request but + // received status 100 anyways, just continue per usual and + // let the caller decide what to do with the response. + default: + continue + } } // We have the response headers. diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index 230ff82c..efbe1052 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -352,3 +352,70 @@ func TestRoundTripRequestBodyErrorAfterHeaders(t *testing.T) { } }) } + +func TestRoundTripExpect100Continue(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + tc := newTestClientConn(t) + tc.greet() + clientBody := []byte("client's body that will be sent later") + serverBody := []byte("server's body") + + // Client sends an Expect: 100-continue request. + req, _ := http.NewRequest("PUT", "https://example.tld/", bytes.NewBuffer(clientBody)) + req.Header = http.Header{"Expect": {"100-continue"}} + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + + // Server reads the header. + st.wantHeaders(nil) + st.wantIdle("client has yet to send its body") + + // Server responds with HTTP status 100. + st.writeHeaders(http.Header{ + ":status": []string{"100"}, + }) + + // Client sends its body after receiving HTTP status 100 response. + st.wantData(clientBody) + + // The server sends its response after getting the client's body. + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + }) + st.writeData(serverBody) + st.stream.stream.CloseWrite() + + // Client receives the response from server. + rt.wantStatus(200) + rt.wantBody(serverBody) + }) +} + +func TestRoundTripExpect100ContinueRejected(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + tc := newTestClientConn(t) + tc.greet() + + // Client sends an Expect: 100-continue request. + req, _ := http.NewRequest("PUT", "https://example.tld/", bytes.NewBufferString("client's body")) + req.Header = http.Header{"Expect": {"100-continue"}} + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + + // Server reads the header. + st.wantHeaders(nil) + st.wantIdle("client has yet to send its body") + + // Server rejects it. + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + }) + st.wantIdle("client does not send its body without getting status 100") + serverBody := []byte("server's body") + st.writeData(serverBody) + st.stream.stream.CloseWrite() + + rt.wantStatus(200) + rt.wantBody(serverBody) + }) +} diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go index 271cf262..c9858e4a 100644 --- a/internal/http3/transport_test.go +++ b/internal/http3/transport_test.go @@ -454,6 +454,26 @@ func (rt *testRoundTrip) wantHeaders(want http.Header) { } } +// readBody reads the contents of the response body. +func (rt *testRoundTrip) readBody() ([]byte, error) { + t := rt.t + t.Helper() + return io.ReadAll(rt.response().Body) +} + +// wantBody consumes the a body and asserts that it is as expected. +func (rt *testRoundTrip) wantBody(want []byte) { + t := rt.t + t.Helper() + got, err := rt.readBody() + if err != nil { + t.Fatalf("unexpected error reading response body: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want) + } +} + func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { rt := &testRoundTrip{t: tc.t} go func() {