http2: avoid race in TestTransportReqBodyAfterResponse_403.

This test sends a request with a 10MiB body, reads a 403 response
while the body is still being written, and closes the response body.
It then verifies that the full request body was not written, since
reading a response code >=300 interrupts body writes.

This can be racy: We process the status code and interrupt the body
write in RoundTrip, but it is possible for the body write to complete
before RoundTrip looks at the response.

Adjust the test to have more control over the request body:
Only provide half the Request.Body until after the response headers
have been received.

Fixes golang/go#48792.

Change-Id: Id4802b04a50f34f6af28f4eb93e37ef70a33a068
Reviewed-on: https://go-review.googlesource.com/c/net/+/354130
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Damien Neil
2021-10-05 12:46:59 -07:00
parent d4b1ae081e
commit d2e5035098

View File

@@ -889,6 +889,7 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
const bodySize = 10 << 20
clientDone := make(chan struct{})
ct := newClientTester(t)
recvLen := make(chan int64, 1)
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
if runtime.GOOS == "plan9" {
@@ -897,8 +898,9 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
}
defer close(clientDone)
var n int64 // atomic
req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
body := &pipe{b: new(bytes.Buffer)}
io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
if err != nil {
return err
}
@@ -906,10 +908,11 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
if err != nil {
return fmt.Errorf("RoundTrip: %v", err)
}
defer res.Body.Close()
if res.StatusCode != status {
return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
}
io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2))
body.CloseWithError(io.EOF)
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("Slurp: %v", err)
@@ -917,12 +920,13 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
if len(slurp) > 0 {
return fmt.Errorf("unexpected body: %q", slurp)
}
res.Body.Close()
if status == 200 {
if got := atomic.LoadInt64(&n); got != bodySize {
if got := <-recvLen; got != bodySize {
return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
}
} else {
if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
if got := <-recvLen; got == 0 || got >= bodySize {
return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
}
}
@@ -948,6 +952,7 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
}
}
//println(fmt.Sprintf("server got frame: %v", f))
ended := false
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
@@ -985,13 +990,24 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
return err
}
}
if f.StreamEnded() {
ended = true
}
case *RSTStreamFrame:
if status == 200 {
return fmt.Errorf("Unexpected client frame %v", f)
}
ended = true
default:
return fmt.Errorf("Unexpected client frame %v", f)
}
if ended {
select {
case recvLen <- dataRecv:
default:
}
}
}
}
ct.run()