mirror of
https://github.com/golang/net.git
synced 2026-04-01 02:47:08 +09:00
http2: use synthetic timers for ping timeouts in tests
Change-Id: I642890519b066937ade3c13e8387c31d29e912f4 Reviewed-on: https://go-review.googlesource.com/c/net/+/572377 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com>
This commit is contained in:
@@ -123,6 +123,7 @@ func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
|
||||
tc.fr.SetMaxReadFrameSize(10 << 20)
|
||||
|
||||
t.Cleanup(func() {
|
||||
tc.sync()
|
||||
if tc.rerr == nil {
|
||||
tc.rerr = io.EOF
|
||||
}
|
||||
@@ -459,6 +460,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he
|
||||
tc.sync()
|
||||
}
|
||||
|
||||
func (tc *testClientConn) writePing(ack bool, data [8]byte) {
|
||||
tc.t.Helper()
|
||||
if err := tc.fr.WritePing(ack, data); err != nil {
|
||||
tc.t.Fatal(err)
|
||||
}
|
||||
tc.sync()
|
||||
}
|
||||
|
||||
func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
|
||||
tc.t.Helper()
|
||||
if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -173,18 +174,56 @@ func (h *testSyncHooks) condWait(cond *sync.Cond) {
|
||||
h.unlock()
|
||||
}
|
||||
|
||||
// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests.
|
||||
// newTimer creates a new fake timer.
|
||||
func (h *testSyncHooks) newTimer(d time.Duration) timer {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
t := &fakeTimer{
|
||||
when: h.now.Add(d),
|
||||
c: make(chan time.Time),
|
||||
hooks: h,
|
||||
when: h.now.Add(d),
|
||||
c: make(chan time.Time),
|
||||
}
|
||||
h.timers = append(h.timers, t)
|
||||
return t
|
||||
}
|
||||
|
||||
// afterFunc creates a new fake AfterFunc timer.
|
||||
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
t := &fakeTimer{
|
||||
hooks: h,
|
||||
when: h.now.Add(d),
|
||||
f: f,
|
||||
}
|
||||
h.timers = append(h.timers, t)
|
||||
return t
|
||||
}
|
||||
|
||||
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
t := h.afterFunc(d, cancel)
|
||||
return ctx, func() {
|
||||
t.Stop()
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *testSyncHooks) timeUntilEvent() time.Duration {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
var next time.Time
|
||||
for _, t := range h.timers {
|
||||
if next.IsZero() || t.when.Before(next) {
|
||||
next = t.when
|
||||
}
|
||||
}
|
||||
if d := next.Sub(h.now); d > 0 {
|
||||
return d
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// advance advances time and causes synthetic timers to fire.
|
||||
func (h *testSyncHooks) advance(d time.Duration) {
|
||||
h.lock()
|
||||
@@ -192,6 +231,7 @@ func (h *testSyncHooks) advance(d time.Duration) {
|
||||
h.now = h.now.Add(d)
|
||||
timers := h.timers[:0]
|
||||
for _, t := range h.timers {
|
||||
t := t // remove after go.mod depends on go1.22
|
||||
t.mu.Lock()
|
||||
switch {
|
||||
case t.when.After(h.now):
|
||||
@@ -200,7 +240,20 @@ func (h *testSyncHooks) advance(d time.Duration) {
|
||||
// stopped timer
|
||||
default:
|
||||
t.when = time.Time{}
|
||||
close(t.c)
|
||||
if t.c != nil {
|
||||
close(t.c)
|
||||
}
|
||||
if t.f != nil {
|
||||
h.total++
|
||||
go func() {
|
||||
defer func() {
|
||||
h.lock()
|
||||
h.total--
|
||||
h.unlock()
|
||||
}()
|
||||
t.f()
|
||||
}()
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
@@ -212,13 +265,16 @@ func (h *testSyncHooks) advance(d time.Duration) {
|
||||
type timer interface {
|
||||
C() <-chan time.Time
|
||||
Stop() bool
|
||||
Reset(d time.Duration) bool
|
||||
}
|
||||
|
||||
// timeTimer implements timer using real time.
|
||||
type timeTimer struct {
|
||||
t *time.Timer
|
||||
c chan time.Time
|
||||
}
|
||||
|
||||
// newTimeTimer creates a new timer using real time.
|
||||
func newTimeTimer(d time.Duration) timer {
|
||||
ch := make(chan time.Time)
|
||||
t := time.AfterFunc(d, func() {
|
||||
@@ -227,16 +283,29 @@ func newTimeTimer(d time.Duration) timer {
|
||||
return &timeTimer{t, ch}
|
||||
}
|
||||
|
||||
func (t timeTimer) C() <-chan time.Time { return t.c }
|
||||
func (t timeTimer) Stop() bool { return t.t.Stop() }
|
||||
// newTimeAfterFunc creates an AfterFunc timer using real time.
|
||||
func newTimeAfterFunc(d time.Duration, f func()) timer {
|
||||
return &timeTimer{
|
||||
t: time.AfterFunc(d, f),
|
||||
}
|
||||
}
|
||||
|
||||
func (t timeTimer) C() <-chan time.Time { return t.c }
|
||||
func (t timeTimer) Stop() bool { return t.t.Stop() }
|
||||
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
|
||||
|
||||
// fakeTimer implements timer using fake time.
|
||||
type fakeTimer struct {
|
||||
hooks *testSyncHooks
|
||||
|
||||
mu sync.Mutex
|
||||
when time.Time
|
||||
c chan time.Time
|
||||
when time.Time // when the timer will fire
|
||||
c chan time.Time // closed when the timer fires; mutually exclusive with f
|
||||
f func() // called when the timer fires; mutually exclusive with c
|
||||
}
|
||||
|
||||
func (t *fakeTimer) C() <-chan time.Time { return t.c }
|
||||
|
||||
func (t *fakeTimer) Stop() bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
@@ -244,3 +313,19 @@ func (t *fakeTimer) Stop() bool {
|
||||
t.when = time.Time{}
|
||||
return stopped
|
||||
}
|
||||
|
||||
func (t *fakeTimer) Reset(d time.Duration) bool {
|
||||
if t.c != nil || t.f == nil {
|
||||
panic("fakeTimer only supports Reset on AfterFunc timers")
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.hooks.lock()
|
||||
defer t.hooks.unlock()
|
||||
active := !t.when.IsZero()
|
||||
t.when = t.hooks.now.Add(d)
|
||||
if !active {
|
||||
t.hooks.timers = append(t.hooks.timers, t)
|
||||
}
|
||||
return active
|
||||
}
|
||||
|
||||
@@ -391,6 +391,21 @@ func (cc *ClientConn) newTimer(d time.Duration) timer {
|
||||
return newTimeTimer(d)
|
||||
}
|
||||
|
||||
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
|
||||
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
|
||||
if cc.syncHooks != nil {
|
||||
return cc.syncHooks.afterFunc(d, f)
|
||||
}
|
||||
return newTimeAfterFunc(d, f)
|
||||
}
|
||||
|
||||
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
|
||||
if cc.syncHooks != nil {
|
||||
return cc.syncHooks.contextWithTimeout(ctx, d)
|
||||
}
|
||||
return context.WithTimeout(ctx, d)
|
||||
}
|
||||
|
||||
// clientStream is the state for a single HTTP/2 stream. One of these
|
||||
// is created for each Transport.RoundTrip call.
|
||||
type clientStream struct {
|
||||
@@ -875,7 +890,7 @@ func (cc *ClientConn) healthCheck() {
|
||||
pingTimeout := cc.t.pingTimeout()
|
||||
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
|
||||
// trigger the healthCheck again if there is no frame received.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
cc.vlogf("http2: Transport sending health check")
|
||||
err := cc.Ping(ctx)
|
||||
@@ -1432,6 +1447,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
|
||||
if cc.reqHeaderMu == nil {
|
||||
panic("RoundTrip on uninitialized ClientConn") // for tests
|
||||
}
|
||||
var newStreamHook func(*clientStream)
|
||||
if cc.syncHooks != nil {
|
||||
newStreamHook = cc.syncHooks.newstream
|
||||
cc.syncHooks.blockUntil(func() bool {
|
||||
select {
|
||||
case cc.reqHeaderMu <- struct{}{}:
|
||||
<-cc.reqHeaderMu
|
||||
case <-cs.reqCancel:
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
select {
|
||||
case cc.reqHeaderMu <- struct{}{}:
|
||||
case <-cs.reqCancel:
|
||||
@@ -1456,8 +1486,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
|
||||
if cc.syncHooks != nil {
|
||||
cc.syncHooks.newstream(cs)
|
||||
if newStreamHook != nil {
|
||||
newStreamHook(cs)
|
||||
}
|
||||
|
||||
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
|
||||
@@ -2369,10 +2399,9 @@ func (rl *clientConnReadLoop) run() error {
|
||||
cc := rl.cc
|
||||
gotSettings := false
|
||||
readIdleTimeout := cc.t.ReadIdleTimeout
|
||||
var t *time.Timer
|
||||
var t timer
|
||||
if readIdleTimeout != 0 {
|
||||
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
|
||||
defer t.Stop()
|
||||
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
|
||||
}
|
||||
for {
|
||||
f, err := cc.fr.ReadFrame()
|
||||
@@ -3067,24 +3096,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
}
|
||||
errc := make(chan error, 1)
|
||||
var pingError error
|
||||
errc := make(chan struct{})
|
||||
cc.goRun(func() {
|
||||
cc.wmu.Lock()
|
||||
defer cc.wmu.Unlock()
|
||||
if err := cc.fr.WritePing(false, p); err != nil {
|
||||
errc <- err
|
||||
if pingError = cc.fr.WritePing(false, p); pingError != nil {
|
||||
close(errc)
|
||||
return
|
||||
}
|
||||
if err := cc.bw.Flush(); err != nil {
|
||||
errc <- err
|
||||
if pingError = cc.bw.Flush(); pingError != nil {
|
||||
close(errc)
|
||||
return
|
||||
}
|
||||
})
|
||||
if cc.syncHooks != nil {
|
||||
cc.syncHooks.blockUntil(func() bool {
|
||||
select {
|
||||
case <-c:
|
||||
case <-errc:
|
||||
case <-ctx.Done():
|
||||
case <-cc.readerDone:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-c:
|
||||
return nil
|
||||
case err := <-errc:
|
||||
return err
|
||||
case <-errc:
|
||||
return pingError
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-cc.readerDone:
|
||||
|
||||
@@ -3310,26 +3310,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTransportCloseAfterLostPing(t *testing.T) {
|
||||
clientDone := make(chan struct{})
|
||||
ct := newClientTester(t)
|
||||
ct.tr.PingTimeout = 1 * time.Second
|
||||
ct.tr.ReadIdleTimeout = 1 * time.Second
|
||||
ct.client = func() error {
|
||||
defer ct.cc.(*net.TCPConn).CloseWrite()
|
||||
defer close(clientDone)
|
||||
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
||||
_, err := ct.tr.RoundTrip(req)
|
||||
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
|
||||
return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
|
||||
}
|
||||
return nil
|
||||
tc := newTestClientConn(t, func(tr *Transport) {
|
||||
tr.PingTimeout = 1 * time.Second
|
||||
tr.ReadIdleTimeout = 1 * time.Second
|
||||
})
|
||||
tc.greet()
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
||||
rt := tc.roundTrip(req)
|
||||
tc.wantFrameType(FrameHeaders)
|
||||
|
||||
tc.advance(1 * time.Second)
|
||||
tc.wantFrameType(FramePing)
|
||||
|
||||
tc.advance(1 * time.Second)
|
||||
err := rt.err()
|
||||
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
|
||||
t.Fatalf("expected to get error about \"connection lost\", got %v", err)
|
||||
}
|
||||
ct.server = func() error {
|
||||
ct.greet()
|
||||
<-clientDone
|
||||
return nil
|
||||
}
|
||||
ct.run()
|
||||
}
|
||||
|
||||
func TestTransportPingWriteBlocks(t *testing.T) {
|
||||
@@ -3362,38 +3360,73 @@ func TestTransportPingWriteBlocks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportPingWhenReading(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
readIdleTimeout time.Duration
|
||||
deadline time.Duration
|
||||
expectedPingCount int
|
||||
}{
|
||||
{
|
||||
name: "two pings",
|
||||
readIdleTimeout: 100 * time.Millisecond,
|
||||
deadline: time.Second,
|
||||
expectedPingCount: 2,
|
||||
},
|
||||
{
|
||||
name: "zero ping",
|
||||
readIdleTimeout: time.Second,
|
||||
deadline: 200 * time.Millisecond,
|
||||
expectedPingCount: 0,
|
||||
},
|
||||
{
|
||||
name: "0 readIdleTimeout means no ping",
|
||||
readIdleTimeout: 0 * time.Millisecond,
|
||||
deadline: 500 * time.Millisecond,
|
||||
expectedPingCount: 0,
|
||||
},
|
||||
func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
|
||||
tc := newTestClientConn(t, func(tr *Transport) {
|
||||
tr.ReadIdleTimeout = 1000 * time.Millisecond
|
||||
})
|
||||
tc.greet()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
|
||||
rt := tc.roundTrip(req)
|
||||
|
||||
tc.wantFrameType(FrameHeaders)
|
||||
tc.writeHeaders(HeadersFrameParam{
|
||||
StreamID: rt.streamID(),
|
||||
EndHeaders: true,
|
||||
EndStream: false,
|
||||
BlockFragment: tc.makeHeaderBlockFragment(
|
||||
":status", "200",
|
||||
),
|
||||
})
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
// No ping yet...
|
||||
tc.advance(999 * time.Millisecond)
|
||||
if f := tc.readFrame(); f != nil {
|
||||
t.Fatalf("unexpected frame: %v", f)
|
||||
}
|
||||
|
||||
// ...ping now.
|
||||
tc.advance(1 * time.Millisecond)
|
||||
f := testClientConnReadFrame[*PingFrame](tc)
|
||||
tc.writePing(true, f.Data)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount)
|
||||
})
|
||||
// Cancel the request, Transport resets it and returns an error from body reads.
|
||||
cancel()
|
||||
tc.sync()
|
||||
|
||||
tc.wantFrameType(FrameRSTStream)
|
||||
_, err := rt.readBody()
|
||||
if err == nil {
|
||||
t.Fatalf("Response.Body.Read() = %v, want error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
|
||||
tc := newTestClientConn(t, func(tr *Transport) {
|
||||
tr.ReadIdleTimeout = 0 // PINGs disabled
|
||||
})
|
||||
tc.greet()
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
|
||||
rt := tc.roundTrip(req)
|
||||
|
||||
tc.wantFrameType(FrameHeaders)
|
||||
tc.writeHeaders(HeadersFrameParam{
|
||||
StreamID: rt.streamID(),
|
||||
EndHeaders: true,
|
||||
EndStream: false,
|
||||
BlockFragment: tc.makeHeaderBlockFragment(
|
||||
":status", "200",
|
||||
),
|
||||
})
|
||||
|
||||
// No PING is sent, even after a long delay.
|
||||
tc.advance(1 * time.Minute)
|
||||
if f := tc.readFrame(); f != nil {
|
||||
t.Fatalf("unexpected frame: %v", f)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user