http2: use synthetic timers for ping timeouts in tests

Change-Id: I642890519b066937ade3c13e8387c31d29e912f4
Reviewed-on: https://go-review.googlesource.com/c/net/+/572377
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
This commit is contained in:
Damien Neil
2024-03-18 13:06:45 -07:00
parent 31d9683ed0
commit 9e0498de4d
4 changed files with 240 additions and 70 deletions

View File

@@ -123,6 +123,7 @@ func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
tc.fr.SetMaxReadFrameSize(10 << 20)
t.Cleanup(func() {
tc.sync()
if tc.rerr == nil {
tc.rerr = io.EOF
}
@@ -459,6 +460,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he
tc.sync()
}
func (tc *testClientConn) writePing(ack bool, data [8]byte) {
tc.t.Helper()
if err := tc.fr.WritePing(ack, data); err != nil {
tc.t.Fatal(err)
}
tc.sync()
}
func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
tc.t.Helper()
if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {

View File

@@ -4,6 +4,7 @@
package http2
import (
"context"
"sync"
"time"
)
@@ -173,18 +174,56 @@ func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.unlock()
}
// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests.
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
when: h.now.Add(d),
c: make(chan time.Time),
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
@@ -192,6 +231,7 @@ func (h *testSyncHooks) advance(d time.Duration) {
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
@@ -200,7 +240,20 @@ func (h *testSyncHooks) advance(d time.Duration) {
// stopped timer
default:
t.when = time.Time{}
close(t.c)
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
@@ -212,13 +265,16 @@ func (h *testSyncHooks) advance(d time.Duration) {
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
@@ -227,16 +283,29 @@ func newTimeTimer(d time.Duration) timer {
return &timeTimer{t, ch}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time
c chan time.Time
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
@@ -244,3 +313,19 @@ func (t *fakeTimer) Stop() bool {
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

View File

@@ -391,6 +391,21 @@ func (cc *ClientConn) newTimer(d time.Duration) timer {
return newTimeTimer(d)
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
if cc.syncHooks != nil {
return cc.syncHooks.afterFunc(d, f)
}
return newTimeAfterFunc(d, f)
}
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if cc.syncHooks != nil {
return cc.syncHooks.contextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call.
type clientStream struct {
@@ -875,7 +890,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -1432,6 +1447,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
var newStreamHook func(*clientStream)
if cc.syncHooks != nil {
newStreamHook = cc.syncHooks.newstream
cc.syncHooks.blockUntil(func() bool {
select {
case cc.reqHeaderMu <- struct{}{}:
<-cc.reqHeaderMu
case <-cs.reqCancel:
case <-ctx.Done():
default:
return false
}
return true
})
}
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -1456,8 +1486,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}
cc.mu.Unlock()
if cc.syncHooks != nil {
cc.syncHooks.newstream(cs)
if newStreamHook != nil {
newStreamHook(cs)
}
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
@@ -2369,10 +2399,9 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer
var t timer
if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@@ -3067,24 +3096,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
}
cc.mu.Unlock()
}
errc := make(chan error, 1)
var pingError error
errc := make(chan struct{})
cc.goRun(func() {
cc.wmu.Lock()
defer cc.wmu.Unlock()
if err := cc.fr.WritePing(false, p); err != nil {
errc <- err
if pingError = cc.fr.WritePing(false, p); pingError != nil {
close(errc)
return
}
if err := cc.bw.Flush(); err != nil {
errc <- err
if pingError = cc.bw.Flush(); pingError != nil {
close(errc)
return
}
})
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-c:
case <-errc:
case <-ctx.Done():
case <-cc.readerDone:
default:
return false
}
return true
})
}
select {
case <-c:
return nil
case err := <-errc:
return err
case <-errc:
return pingError
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:

View File

@@ -3310,26 +3310,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
}
func TestTransportCloseAfterLostPing(t *testing.T) {
clientDone := make(chan struct{})
ct := newClientTester(t)
ct.tr.PingTimeout = 1 * time.Second
ct.tr.ReadIdleTimeout = 1 * time.Second
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
defer close(clientDone)
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
_, err := ct.tr.RoundTrip(req)
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
}
return nil
tc := newTestClientConn(t, func(tr *Transport) {
tr.PingTimeout = 1 * time.Second
tr.ReadIdleTimeout = 1 * time.Second
})
tc.greet()
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
rt := tc.roundTrip(req)
tc.wantFrameType(FrameHeaders)
tc.advance(1 * time.Second)
tc.wantFrameType(FramePing)
tc.advance(1 * time.Second)
err := rt.err()
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
t.Fatalf("expected to get error about \"connection lost\", got %v", err)
}
ct.server = func() error {
ct.greet()
<-clientDone
return nil
}
ct.run()
}
func TestTransportPingWriteBlocks(t *testing.T) {
@@ -3362,38 +3360,73 @@ func TestTransportPingWriteBlocks(t *testing.T) {
}
}
func TestTransportPingWhenReading(t *testing.T) {
testCases := []struct {
name string
readIdleTimeout time.Duration
deadline time.Duration
expectedPingCount int
}{
{
name: "two pings",
readIdleTimeout: 100 * time.Millisecond,
deadline: time.Second,
expectedPingCount: 2,
},
{
name: "zero ping",
readIdleTimeout: time.Second,
deadline: 200 * time.Millisecond,
expectedPingCount: 0,
},
{
name: "0 readIdleTimeout means no ping",
readIdleTimeout: 0 * time.Millisecond,
deadline: 500 * time.Millisecond,
expectedPingCount: 0,
},
func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
tc := newTestClientConn(t, func(tr *Transport) {
tr.ReadIdleTimeout = 1000 * time.Millisecond
})
tc.greet()
ctx, cancel := context.WithCancel(context.Background())
req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
rt := tc.roundTrip(req)
tc.wantFrameType(FrameHeaders)
tc.writeHeaders(HeadersFrameParam{
StreamID: rt.streamID(),
EndHeaders: true,
EndStream: false,
BlockFragment: tc.makeHeaderBlockFragment(
":status", "200",
),
})
for i := 0; i < 5; i++ {
// No ping yet...
tc.advance(999 * time.Millisecond)
if f := tc.readFrame(); f != nil {
t.Fatalf("unexpected frame: %v", f)
}
// ...ping now.
tc.advance(1 * time.Millisecond)
f := testClientConnReadFrame[*PingFrame](tc)
tc.writePing(true, f.Data)
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
})
// Cancel the request, Transport resets it and returns an error from body reads.
cancel()
tc.sync()
tc.wantFrameType(FrameRSTStream)
_, err := rt.readBody()
if err == nil {
t.Fatalf("Response.Body.Read() = %v, want error", err)
}
}
func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
tc := newTestClientConn(t, func(tr *Transport) {
tr.ReadIdleTimeout = 0 // PINGs disabled
})
tc.greet()
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
rt := tc.roundTrip(req)
tc.wantFrameType(FrameHeaders)
tc.writeHeaders(HeadersFrameParam{
StreamID: rt.streamID(),
EndHeaders: true,
EndStream: false,
BlockFragment: tc.makeHeaderBlockFragment(
":status", "200",
),
})
// No PING is sent, even after a long delay.
tc.advance(1 * time.Minute)
if f := tc.readFrame(); f != nil {
t.Fatalf("unexpected frame: %v", f)
}
}