mirror of
https://github.com/golang/net.git
synced 2026-03-31 18:37:08 +09:00
Decoding QUIC packet numbers requires keeping track of the largest packet number received so far from the peer. Our tests haven't bothered doing that so far, so tests can't work with packet numbers past 255. Fix that so we can write tests that use more packets. Change-Id: Icb795e5cf69794381c12a3a03b0da6bcf47a69c0 Reviewed-on: https://go-review.googlesource.com/c/net/+/664296 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> Auto-Submit: Damien Neil <dneil@google.com>
1182 lines
32 KiB
Go
1182 lines
32 KiB
Go
// Copyright 2023 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"log/slog"
|
|
"math"
|
|
"net/netip"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/quic/qlog"
|
|
)
|
|
|
|
// testVV turns on extra-verbose output, such as the per-datagram logging
// done by logDatagram.
var testVV = flag.Bool("vv", false, "even more verbose test output")

// qlogdir, when non-empty, is the directory that qlog logs are written to.
var qlogdir = flag.String("qlog", "", "write qlog logs to directory")
|
|
|
|
func TestConnTestConn(t *testing.T) {
|
|
tc := newTestConn(t, serverSide)
|
|
tc.handshake()
|
|
if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
|
|
t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
|
|
}
|
|
|
|
ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
|
|
tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
|
|
when = now
|
|
})
|
|
return
|
|
}).result()
|
|
if !ranAt.Equal(tc.endpoint.now) {
|
|
t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
|
|
}
|
|
tc.wait()
|
|
|
|
nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
|
|
tc.advanceTo(nextTime)
|
|
ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
|
|
tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
|
|
when = now
|
|
})
|
|
return
|
|
}).result()
|
|
if !ranAt.Equal(nextTime) {
|
|
t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
|
|
}
|
|
tc.wait()
|
|
|
|
tc.advanceToTimer()
|
|
if got := tc.conn.lifetime.state; got != connStateDone {
|
|
t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
|
|
}
|
|
}
|
|
|
|
type testDatagram struct {
|
|
packets []*testPacket
|
|
paddedSize int
|
|
addr netip.AddrPort
|
|
}
|
|
|
|
func (d testDatagram) String() string {
|
|
var b strings.Builder
|
|
fmt.Fprintf(&b, "datagram with %v packets", len(d.packets))
|
|
if d.paddedSize > 0 {
|
|
fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize)
|
|
}
|
|
b.WriteString(":")
|
|
for _, p := range d.packets {
|
|
b.WriteString("\n")
|
|
b.WriteString(p.String())
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
type testPacket struct {
|
|
ptype packetType
|
|
header byte
|
|
version uint32
|
|
num packetNumber
|
|
keyPhaseBit bool
|
|
keyNumber int
|
|
dstConnID []byte
|
|
srcConnID []byte
|
|
token []byte
|
|
originalDstConnID []byte // used for encoding Retry packets
|
|
frames []debugFrame
|
|
}
|
|
|
|
func (p testPacket) String() string {
|
|
var b strings.Builder
|
|
fmt.Fprintf(&b, " %v %v", p.ptype, p.num)
|
|
if p.version != 0 {
|
|
fmt.Fprintf(&b, " version=%v", p.version)
|
|
}
|
|
if p.srcConnID != nil {
|
|
fmt.Fprintf(&b, " src={%x}", p.srcConnID)
|
|
}
|
|
if p.dstConnID != nil {
|
|
fmt.Fprintf(&b, " dst={%x}", p.dstConnID)
|
|
}
|
|
if p.token != nil {
|
|
fmt.Fprintf(&b, " token={%x}", p.token)
|
|
}
|
|
for _, f := range p.frames {
|
|
fmt.Fprintf(&b, "\n %v", f)
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
const (
	// maxTestKeyPhases bounds how many 1-RTT key phases a test derives keys
	// for. It sizes test1RTTKeys.pkt and limits the phases that
	// parseTestDatagram tries when decrypting.
	maxTestKeyPhases = 3
)
|
|
|
|
// A testConn is a Conn whose external interactions (sending and receiving packets,
|
|
// setting timers) can be manipulated in tests.
|
|
type testConn struct {
|
|
t *testing.T
|
|
conn *Conn
|
|
endpoint *testEndpoint
|
|
timer time.Time
|
|
timerLastFired time.Time
|
|
idlec chan struct{} // only accessed on the conn's loop
|
|
|
|
// Keys are distinct from the conn's keys,
|
|
// because the test may know about keys before the conn does.
|
|
// For example, when sending a datagram with coalesced
|
|
// Initial and Handshake packets to a client conn,
|
|
// we use Handshake keys to encrypt the packet.
|
|
// The client only acquires those keys when it processes
|
|
// the Initial packet.
|
|
keysInitial fixedKeyPair
|
|
keysHandshake fixedKeyPair
|
|
rkeyAppData test1RTTKeys
|
|
wkeyAppData test1RTTKeys
|
|
rsecrets [numberSpaceCount]keySecret
|
|
wsecrets [numberSpaceCount]keySecret
|
|
|
|
// testConn uses a test hook to snoop on the conn's TLS events.
|
|
// CRYPTO data produced by the conn's QUICConn is placed in
|
|
// cryptoDataOut.
|
|
//
|
|
// The peerTLSConn is is a QUICConn representing the peer.
|
|
// CRYPTO data produced by the conn is written to peerTLSConn,
|
|
// and data produced by peerTLSConn is placed in cryptoDataIn.
|
|
cryptoDataOut map[tls.QUICEncryptionLevel][]byte
|
|
cryptoDataIn map[tls.QUICEncryptionLevel][]byte
|
|
peerTLSConn *tls.QUICConn
|
|
|
|
// Information about the conn's (fake) peer.
|
|
peerConnID []byte // source conn id of peer's packets
|
|
peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use
|
|
|
|
// Maximum packet number received from the conn.
|
|
pnumMax [numberSpaceCount]packetNumber
|
|
|
|
// Datagrams, packets, and frames sent by the conn,
|
|
// but not yet processed by the test.
|
|
sentDatagrams [][]byte
|
|
sentPackets []*testPacket
|
|
sentFrames []debugFrame
|
|
lastDatagram *testDatagram
|
|
lastPacket *testPacket
|
|
|
|
recvDatagram chan *datagram
|
|
|
|
// Transport parameters sent by the conn.
|
|
sentTransportParameters *transportParameters
|
|
|
|
// Frame types to ignore in tests.
|
|
ignoreFrames map[byte]bool
|
|
|
|
// Values to set in packets sent to the conn.
|
|
sendKeyNumber int
|
|
sendKeyPhaseBit bool
|
|
|
|
asyncTestState
|
|
}
|
|
|
|
type test1RTTKeys struct {
|
|
hdr headerKey
|
|
pkt [maxTestKeyPhases]packetKey
|
|
}
|
|
|
|
// keySecret records a traffic secret and its cipher suite as reported in a
// TLS event. handleTLSEvent uses it to check that both TLS endpoints
// report the same secrets.
type keySecret struct {
	suite  uint16 // TLS cipher suite
	secret []byte
}
|
|
|
|
// newTestConn creates a Conn for testing.
|
|
//
|
|
// The Conn's event loop is controlled by the test,
|
|
// allowing test code to access Conn state directly
|
|
// by first ensuring the loop goroutine is idle.
|
|
func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
|
|
t.Helper()
|
|
config := &Config{
|
|
TLSConfig: newTestTLSConfig(side),
|
|
StatelessResetKey: testStatelessResetKey,
|
|
QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
|
|
Level: QLogLevelFrame,
|
|
Dir: *qlogdir,
|
|
})),
|
|
}
|
|
var cids newServerConnIDs
|
|
if side == serverSide {
|
|
// The initial connection ID for the server is chosen by the client.
|
|
cids.srcConnID = testPeerConnID(0)
|
|
cids.dstConnID = testPeerConnID(-1)
|
|
cids.originalDstConnID = cids.dstConnID
|
|
}
|
|
var configTransportParams []func(*transportParameters)
|
|
var configTestConn []func(*testConn)
|
|
for _, o := range opts {
|
|
switch o := o.(type) {
|
|
case func(*Config):
|
|
o(config)
|
|
case func(*tls.Config):
|
|
o(config.TLSConfig)
|
|
case func(cids *newServerConnIDs):
|
|
o(&cids)
|
|
case func(p *transportParameters):
|
|
configTransportParams = append(configTransportParams, o)
|
|
case func(p *testConn):
|
|
configTestConn = append(configTestConn, o)
|
|
default:
|
|
t.Fatalf("unknown newTestConn option %T", o)
|
|
}
|
|
}
|
|
|
|
endpoint := newTestEndpoint(t, config)
|
|
endpoint.configTransportParams = configTransportParams
|
|
endpoint.configTestConn = configTestConn
|
|
conn, err := endpoint.e.newConn(
|
|
endpoint.now,
|
|
config,
|
|
side,
|
|
cids,
|
|
"",
|
|
netip.MustParseAddrPort("127.0.0.1:443"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc := endpoint.conns[conn]
|
|
tc.wait()
|
|
return tc
|
|
}
|
|
|
|
func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
|
|
t.Helper()
|
|
tc := &testConn{
|
|
t: t,
|
|
endpoint: endpoint,
|
|
conn: conn,
|
|
peerConnID: testPeerConnID(0),
|
|
ignoreFrames: map[byte]bool{
|
|
frameTypePadding: true, // ignore PADDING by default
|
|
},
|
|
cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
|
|
cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte),
|
|
recvDatagram: make(chan *datagram),
|
|
}
|
|
t.Cleanup(tc.cleanup)
|
|
for _, f := range endpoint.configTestConn {
|
|
f(tc)
|
|
}
|
|
conn.testHooks = (*testConnHooks)(tc)
|
|
|
|
if endpoint.peerTLSConn != nil {
|
|
tc.peerTLSConn = endpoint.peerTLSConn
|
|
endpoint.peerTLSConn = nil
|
|
return tc
|
|
}
|
|
|
|
peerProvidedParams := defaultTransportParameters()
|
|
peerProvidedParams.initialSrcConnID = testPeerConnID(0)
|
|
if conn.side == clientSide {
|
|
peerProvidedParams.originalDstConnID = testLocalConnID(-1)
|
|
}
|
|
for _, f := range endpoint.configTransportParams {
|
|
f(&peerProvidedParams)
|
|
}
|
|
|
|
peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
|
|
if conn.side == clientSide {
|
|
tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
|
|
} else {
|
|
tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
|
|
}
|
|
tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
|
|
tc.peerTLSConn.Start(context.Background())
|
|
t.Cleanup(func() {
|
|
tc.peerTLSConn.Close()
|
|
})
|
|
|
|
return tc
|
|
}
|
|
|
|
// advance causes time to pass.
|
|
func (tc *testConn) advance(d time.Duration) {
|
|
tc.t.Helper()
|
|
tc.endpoint.advance(d)
|
|
}
|
|
|
|
// advanceTo sets the current time.
|
|
func (tc *testConn) advanceTo(now time.Time) {
|
|
tc.t.Helper()
|
|
tc.endpoint.advanceTo(now)
|
|
}
|
|
|
|
// advanceToTimer sets the current time to the time of the Conn's next timer event.
|
|
func (tc *testConn) advanceToTimer() {
|
|
if tc.timer.IsZero() {
|
|
tc.t.Fatalf("advancing to timer, but timer is not set")
|
|
}
|
|
tc.advanceTo(tc.timer)
|
|
}
|
|
|
|
func (tc *testConn) timerDelay() time.Duration {
|
|
if tc.timer.IsZero() {
|
|
return math.MaxInt64 // infinite
|
|
}
|
|
if tc.timer.Before(tc.endpoint.now) {
|
|
return 0
|
|
}
|
|
return tc.timer.Sub(tc.endpoint.now)
|
|
}
|
|
|
|
// infiniteDuration is the largest representable Duration. It stands for
// "no event pending".
const infiniteDuration time.Duration = math.MaxInt64
|
|
|
|
// timeUntilEvent returns the amount of time until the next connection event.
|
|
func (tc *testConn) timeUntilEvent() time.Duration {
|
|
if tc.timer.IsZero() {
|
|
return infiniteDuration
|
|
}
|
|
if tc.timer.Before(tc.endpoint.now) {
|
|
return 0
|
|
}
|
|
return tc.timer.Sub(tc.endpoint.now)
|
|
}
|
|
|
|
// wait blocks until the conn becomes idle.
|
|
// The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire.
|
|
// Tests shouldn't need to call wait directly.
|
|
// testConn methods that wake the Conn event loop will call wait for them.
|
|
func (tc *testConn) wait() {
|
|
tc.t.Helper()
|
|
idlec := make(chan struct{})
|
|
fail := false
|
|
tc.conn.sendMsg(func(now time.Time, c *Conn) {
|
|
if tc.idlec != nil {
|
|
tc.t.Errorf("testConn.wait called concurrently")
|
|
fail = true
|
|
close(idlec)
|
|
} else {
|
|
// nextMessage will close idlec.
|
|
tc.idlec = idlec
|
|
}
|
|
})
|
|
select {
|
|
case <-idlec:
|
|
case <-tc.conn.donec:
|
|
// We may have async ops that can proceed now that the conn is done.
|
|
tc.wakeAsync()
|
|
}
|
|
if fail {
|
|
panic(fail)
|
|
}
|
|
}
|
|
|
|
func (tc *testConn) cleanup() {
|
|
if tc.conn == nil {
|
|
return
|
|
}
|
|
tc.conn.exit()
|
|
<-tc.conn.donec
|
|
}
|
|
|
|
func (tc *testConn) acceptStream() *Stream {
|
|
tc.t.Helper()
|
|
s, err := tc.conn.AcceptStream(canceledContext())
|
|
if err != nil {
|
|
tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err)
|
|
}
|
|
s.SetReadContext(canceledContext())
|
|
s.SetWriteContext(canceledContext())
|
|
return s
|
|
}
|
|
|
|
func logDatagram(t *testing.T, text string, d *testDatagram) {
|
|
t.Helper()
|
|
if !*testVV {
|
|
return
|
|
}
|
|
pad := ""
|
|
if d.paddedSize > 0 {
|
|
pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
|
|
}
|
|
t.Logf("%v datagram%v", text, pad)
|
|
for _, p := range d.packets {
|
|
var s string
|
|
switch p.ptype {
|
|
case packetType1RTT:
|
|
s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num)
|
|
default:
|
|
s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
|
|
}
|
|
if p.token != nil {
|
|
s += fmt.Sprintf(" token={%x}", p.token)
|
|
}
|
|
if p.keyPhaseBit {
|
|
s += fmt.Sprintf(" KeyPhase")
|
|
}
|
|
if p.keyNumber != 0 {
|
|
s += fmt.Sprintf(" keynum=%v", p.keyNumber)
|
|
}
|
|
t.Log(s)
|
|
for _, f := range p.frames {
|
|
t.Logf(" %v", f)
|
|
}
|
|
}
|
|
}
|
|
|
|
// write sends the Conn a datagram.
|
|
func (tc *testConn) write(d *testDatagram) {
|
|
tc.t.Helper()
|
|
tc.endpoint.writeDatagram(d)
|
|
}
|
|
|
|
// writeFrames sends the Conn a datagram containing the given frames.
|
|
func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
|
|
tc.t.Helper()
|
|
space := spaceForPacketType(ptype)
|
|
dstConnID := tc.conn.connIDState.local[0].cid
|
|
if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial {
|
|
// Only use the transient connection ID in Initial packets.
|
|
dstConnID = tc.conn.connIDState.local[1].cid
|
|
}
|
|
d := &testDatagram{
|
|
packets: []*testPacket{{
|
|
ptype: ptype,
|
|
num: tc.peerNextPacketNum[space],
|
|
keyNumber: tc.sendKeyNumber,
|
|
keyPhaseBit: tc.sendKeyPhaseBit,
|
|
frames: frames,
|
|
version: quicVersion1,
|
|
dstConnID: dstConnID,
|
|
srcConnID: tc.peerConnID,
|
|
}},
|
|
addr: tc.conn.peerAddr,
|
|
}
|
|
if ptype == packetTypeInitial && tc.conn.side == serverSide {
|
|
d.paddedSize = 1200
|
|
}
|
|
tc.write(d)
|
|
}
|
|
|
|
// writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
|
|
// last one received.
|
|
func (tc *testConn) writeAckForAll() {
|
|
tc.t.Helper()
|
|
if tc.lastPacket == nil {
|
|
return
|
|
}
|
|
tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
|
|
ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
|
|
})
|
|
}
|
|
|
|
// writeAckForLatest sends the Conn a datagram containing an ack for the
|
|
// most recent packet received.
|
|
func (tc *testConn) writeAckForLatest() {
|
|
tc.t.Helper()
|
|
if tc.lastPacket == nil {
|
|
return
|
|
}
|
|
tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
|
|
ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}},
|
|
})
|
|
}
|
|
|
|
// ignoreFrame hides frames of the given type sent by the Conn.
|
|
func (tc *testConn) ignoreFrame(frameType byte) {
|
|
tc.ignoreFrames[frameType] = true
|
|
}
|
|
|
|
// readDatagram reads the next datagram sent by the Conn.
|
|
// It returns nil if the Conn has no more datagrams to send at this time.
|
|
func (tc *testConn) readDatagram() *testDatagram {
|
|
tc.t.Helper()
|
|
tc.wait()
|
|
tc.sentPackets = nil
|
|
tc.sentFrames = nil
|
|
buf := tc.endpoint.read()
|
|
if buf == nil {
|
|
return nil
|
|
}
|
|
d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
|
|
// Log the datagram before removing ignored frames.
|
|
// When things go wrong, it's useful to see all the frames.
|
|
logDatagram(tc.t, "-> conn under test sends", d)
|
|
typeForFrame := func(f debugFrame) byte {
|
|
// This is very clunky, and points at a problem
|
|
// in how we specify what frames to ignore in tests.
|
|
//
|
|
// We mark frames to ignore using the frame type,
|
|
// but we've got a debugFrame data structure here.
|
|
// Perhaps we should be ignoring frames by debugFrame
|
|
// type instead: tc.ignoreFrame[debugFrameAck]().
|
|
switch f := f.(type) {
|
|
case debugFramePadding:
|
|
return frameTypePadding
|
|
case debugFramePing:
|
|
return frameTypePing
|
|
case debugFrameAck:
|
|
return frameTypeAck
|
|
case debugFrameResetStream:
|
|
return frameTypeResetStream
|
|
case debugFrameStopSending:
|
|
return frameTypeStopSending
|
|
case debugFrameCrypto:
|
|
return frameTypeCrypto
|
|
case debugFrameNewToken:
|
|
return frameTypeNewToken
|
|
case debugFrameStream:
|
|
return frameTypeStreamBase
|
|
case debugFrameMaxData:
|
|
return frameTypeMaxData
|
|
case debugFrameMaxStreamData:
|
|
return frameTypeMaxStreamData
|
|
case debugFrameMaxStreams:
|
|
if f.streamType == bidiStream {
|
|
return frameTypeMaxStreamsBidi
|
|
} else {
|
|
return frameTypeMaxStreamsUni
|
|
}
|
|
case debugFrameDataBlocked:
|
|
return frameTypeDataBlocked
|
|
case debugFrameStreamDataBlocked:
|
|
return frameTypeStreamDataBlocked
|
|
case debugFrameStreamsBlocked:
|
|
if f.streamType == bidiStream {
|
|
return frameTypeStreamsBlockedBidi
|
|
} else {
|
|
return frameTypeStreamsBlockedUni
|
|
}
|
|
case debugFrameNewConnectionID:
|
|
return frameTypeNewConnectionID
|
|
case debugFrameRetireConnectionID:
|
|
return frameTypeRetireConnectionID
|
|
case debugFramePathChallenge:
|
|
return frameTypePathChallenge
|
|
case debugFramePathResponse:
|
|
return frameTypePathResponse
|
|
case debugFrameConnectionCloseTransport:
|
|
return frameTypeConnectionCloseTransport
|
|
case debugFrameConnectionCloseApplication:
|
|
return frameTypeConnectionCloseApplication
|
|
case debugFrameHandshakeDone:
|
|
return frameTypeHandshakeDone
|
|
}
|
|
panic(fmt.Errorf("unhandled frame type %T", f))
|
|
}
|
|
for _, p := range d.packets {
|
|
var frames []debugFrame
|
|
for _, f := range p.frames {
|
|
if !tc.ignoreFrames[typeForFrame(f)] {
|
|
frames = append(frames, f)
|
|
}
|
|
}
|
|
p.frames = frames
|
|
}
|
|
tc.lastDatagram = d
|
|
return d
|
|
}
|
|
|
|
// readPacket reads the next packet sent by the Conn.
|
|
// It returns nil if the Conn has no more packets to send at this time.
|
|
func (tc *testConn) readPacket() *testPacket {
|
|
tc.t.Helper()
|
|
for len(tc.sentPackets) == 0 {
|
|
d := tc.readDatagram()
|
|
if d == nil {
|
|
return nil
|
|
}
|
|
for _, p := range d.packets {
|
|
if len(p.frames) == 0 {
|
|
tc.lastPacket = p
|
|
continue
|
|
}
|
|
tc.sentPackets = append(tc.sentPackets, p)
|
|
}
|
|
}
|
|
p := tc.sentPackets[0]
|
|
tc.sentPackets = tc.sentPackets[1:]
|
|
tc.lastPacket = p
|
|
return p
|
|
}
|
|
|
|
// readFrame reads the next frame sent by the Conn.
|
|
// It returns nil if the Conn has no more frames to send at this time.
|
|
func (tc *testConn) readFrame() (debugFrame, packetType) {
|
|
tc.t.Helper()
|
|
for len(tc.sentFrames) == 0 {
|
|
p := tc.readPacket()
|
|
if p == nil {
|
|
return nil, packetTypeInvalid
|
|
}
|
|
tc.sentFrames = p.frames
|
|
}
|
|
f := tc.sentFrames[0]
|
|
tc.sentFrames = tc.sentFrames[1:]
|
|
return f, tc.lastPacket.ptype
|
|
}
|
|
|
|
// wantDatagram indicates that we expect the Conn to send a datagram.
|
|
func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
|
|
tc.t.Helper()
|
|
got := tc.readDatagram()
|
|
if !datagramEqual(got, want) {
|
|
tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
|
|
}
|
|
}
|
|
|
|
func datagramEqual(a, b *testDatagram) bool {
|
|
if a == nil && b == nil {
|
|
return true
|
|
}
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
if a.paddedSize != b.paddedSize ||
|
|
a.addr != b.addr ||
|
|
len(a.packets) != len(b.packets) {
|
|
return false
|
|
}
|
|
for i := range a.packets {
|
|
if !packetEqual(a.packets[i], b.packets[i]) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// wantPacket indicates that we expect the Conn to send a packet.
|
|
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
|
|
tc.t.Helper()
|
|
got := tc.readPacket()
|
|
if !packetEqual(got, want) {
|
|
tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want)
|
|
}
|
|
}
|
|
|
|
func packetEqual(a, b *testPacket) bool {
|
|
if a == nil && b == nil {
|
|
return true
|
|
}
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
ac := *a
|
|
ac.frames = nil
|
|
ac.header = 0
|
|
bc := *b
|
|
bc.frames = nil
|
|
bc.header = 0
|
|
if !reflect.DeepEqual(ac, bc) {
|
|
return false
|
|
}
|
|
if len(a.frames) != len(b.frames) {
|
|
return false
|
|
}
|
|
for i := range a.frames {
|
|
if !frameEqual(a.frames[i], b.frames[i]) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// wantFrame indicates that we expect the Conn to send a frame.
|
|
func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
|
|
tc.t.Helper()
|
|
got, gotType := tc.readFrame()
|
|
if got == nil {
|
|
tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
|
|
}
|
|
if gotType != wantType {
|
|
tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
|
|
}
|
|
if !frameEqual(got, want) {
|
|
tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want)
|
|
}
|
|
}
|
|
|
|
func frameEqual(a, b debugFrame) bool {
|
|
switch af := a.(type) {
|
|
case debugFrameConnectionCloseTransport:
|
|
bf, ok := b.(debugFrameConnectionCloseTransport)
|
|
return ok && af.code == bf.code
|
|
}
|
|
return reflect.DeepEqual(a, b)
|
|
}
|
|
|
|
// wantFrameType indicates that we expect the Conn to send a frame,
|
|
// although we don't care about the contents.
|
|
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
|
|
tc.t.Helper()
|
|
got, gotType := tc.readFrame()
|
|
if got == nil {
|
|
tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
|
|
}
|
|
if gotType != wantType {
|
|
tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
|
|
}
|
|
if reflect.TypeOf(got) != reflect.TypeOf(want) {
|
|
tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want)
|
|
}
|
|
}
|
|
|
|
// wantIdle indicates that we expect the Conn to not send any more frames.
|
|
func (tc *testConn) wantIdle(expectation string) {
|
|
tc.t.Helper()
|
|
switch {
|
|
case len(tc.sentFrames) > 0:
|
|
tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
|
|
case len(tc.sentPackets) > 0:
|
|
tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
|
|
}
|
|
if f, _ := tc.readFrame(); f != nil {
|
|
tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f)
|
|
}
|
|
}
|
|
|
|
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
|
|
t.Helper()
|
|
var w packetWriter
|
|
w.reset(1200)
|
|
var pnumMaxAcked packetNumber
|
|
switch p.ptype {
|
|
case packetTypeRetry:
|
|
return encodeRetryPacket(p.originalDstConnID, retryPacket{
|
|
srcConnID: p.srcConnID,
|
|
dstConnID: p.dstConnID,
|
|
token: p.token,
|
|
})
|
|
case packetType1RTT:
|
|
w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
|
|
default:
|
|
w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{
|
|
ptype: p.ptype,
|
|
version: p.version,
|
|
num: p.num,
|
|
dstConnID: p.dstConnID,
|
|
srcConnID: p.srcConnID,
|
|
extra: p.token,
|
|
})
|
|
}
|
|
for _, f := range p.frames {
|
|
f.write(&w)
|
|
}
|
|
w.appendPaddingTo(pad)
|
|
if p.ptype != packetType1RTT {
|
|
var k fixedKeys
|
|
if tc == nil {
|
|
if p.ptype == packetTypeInitial {
|
|
k = initialKeys(p.dstConnID, serverSide).r
|
|
} else {
|
|
t.Fatalf("sending %v packet with no conn", p.ptype)
|
|
}
|
|
} else {
|
|
switch p.ptype {
|
|
case packetTypeInitial:
|
|
k = tc.keysInitial.w
|
|
case packetTypeHandshake:
|
|
k = tc.keysHandshake.w
|
|
}
|
|
}
|
|
if !k.isSet() {
|
|
t.Fatalf("sending %v packet with no write key", p.ptype)
|
|
}
|
|
w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
|
|
ptype: p.ptype,
|
|
version: p.version,
|
|
num: p.num,
|
|
dstConnID: p.dstConnID,
|
|
srcConnID: p.srcConnID,
|
|
extra: p.token,
|
|
})
|
|
} else {
|
|
if tc == nil || !tc.wkeyAppData.hdr.isSet() {
|
|
t.Fatalf("sending 1-RTT packet with no write key")
|
|
}
|
|
// Somewhat hackish: Generate a temporary updatingKeyPair that will
|
|
// always use our desired key phase.
|
|
k := &updatingKeyPair{
|
|
w: updatingKeys{
|
|
hdr: tc.wkeyAppData.hdr,
|
|
pkt: [2]packetKey{
|
|
tc.wkeyAppData.pkt[p.keyNumber],
|
|
tc.wkeyAppData.pkt[p.keyNumber],
|
|
},
|
|
},
|
|
updateAfter: maxPacketNumber,
|
|
}
|
|
if p.keyPhaseBit {
|
|
k.phase |= keyPhaseBit
|
|
}
|
|
w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
|
|
}
|
|
return w.datagram()
|
|
}
|
|
|
|
func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
|
|
t.Helper()
|
|
bufSize := len(buf)
|
|
d := &testDatagram{}
|
|
size := len(buf)
|
|
for len(buf) > 0 {
|
|
if buf[0] == 0 {
|
|
d.paddedSize = bufSize
|
|
break
|
|
}
|
|
ptype := getPacketType(buf)
|
|
switch ptype {
|
|
case packetTypeRetry:
|
|
retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
|
|
if !ok {
|
|
t.Fatalf("could not parse %v packet", ptype)
|
|
}
|
|
return &testDatagram{
|
|
packets: []*testPacket{{
|
|
ptype: packetTypeRetry,
|
|
dstConnID: retry.dstConnID,
|
|
srcConnID: retry.srcConnID,
|
|
token: retry.token,
|
|
}},
|
|
}
|
|
case packetTypeInitial, packetTypeHandshake:
|
|
var k fixedKeys
|
|
var pnumMax packetNumber
|
|
if tc == nil {
|
|
if ptype == packetTypeInitial {
|
|
p, _ := parseGenericLongHeaderPacket(buf)
|
|
k = initialKeys(p.srcConnID, serverSide).w
|
|
} else {
|
|
t.Fatalf("reading %v packet with no conn", ptype)
|
|
}
|
|
} else {
|
|
switch ptype {
|
|
case packetTypeInitial:
|
|
k = tc.keysInitial.r
|
|
pnumMax = tc.pnumMax[initialSpace]
|
|
case packetTypeHandshake:
|
|
k = tc.keysHandshake.r
|
|
pnumMax = tc.pnumMax[handshakeSpace]
|
|
}
|
|
}
|
|
if !k.isSet() {
|
|
t.Fatalf("reading %v packet with no read key", ptype)
|
|
}
|
|
p, n := parseLongHeaderPacket(buf, k, pnumMax)
|
|
if n < 0 {
|
|
t.Fatalf("packet parse error")
|
|
}
|
|
if tc != nil {
|
|
switch ptype {
|
|
case packetTypeInitial:
|
|
tc.pnumMax[initialSpace] = max(pnumMax, p.num)
|
|
case packetTypeHandshake:
|
|
tc.pnumMax[handshakeSpace] = max(pnumMax, p.num)
|
|
}
|
|
}
|
|
frames, err := parseTestFrames(t, p.payload)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var token []byte
|
|
if ptype == packetTypeInitial && len(p.extra) > 0 {
|
|
token = p.extra
|
|
}
|
|
d.packets = append(d.packets, &testPacket{
|
|
ptype: p.ptype,
|
|
header: buf[0],
|
|
version: p.version,
|
|
num: p.num,
|
|
dstConnID: p.dstConnID,
|
|
srcConnID: p.srcConnID,
|
|
token: token,
|
|
frames: frames,
|
|
})
|
|
buf = buf[n:]
|
|
case packetType1RTT:
|
|
if tc == nil || !tc.rkeyAppData.hdr.isSet() {
|
|
t.Fatalf("reading 1-RTT packet with no read key")
|
|
}
|
|
var pnumMax packetNumber
|
|
if tc != nil {
|
|
pnumMax = tc.pnumMax[appDataSpace]
|
|
}
|
|
pnumOff := 1 + len(tc.peerConnID)
|
|
// Try unprotecting the packet with the first maxTestKeyPhases keys.
|
|
var phase int
|
|
var pnum packetNumber
|
|
var hdr []byte
|
|
var pay []byte
|
|
var err error
|
|
for phase = 0; phase < maxTestKeyPhases; phase++ {
|
|
b := append([]byte{}, buf...)
|
|
hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
|
|
if err != nil {
|
|
t.Fatalf("1-RTT packet header parse error")
|
|
}
|
|
k := tc.rkeyAppData.pkt[phase]
|
|
pay, err = k.unprotect(hdr, pay, pnum)
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("1-RTT packet payload parse error")
|
|
}
|
|
if tc != nil {
|
|
tc.pnumMax[appDataSpace] = max(pnumMax, pnum)
|
|
}
|
|
frames, err := parseTestFrames(t, pay)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
d.packets = append(d.packets, &testPacket{
|
|
ptype: packetType1RTT,
|
|
header: hdr[0],
|
|
num: pnum,
|
|
dstConnID: hdr[1:][:len(tc.peerConnID)],
|
|
keyPhaseBit: hdr[0]&keyPhaseBit != 0,
|
|
keyNumber: phase,
|
|
frames: frames,
|
|
})
|
|
buf = buf[len(buf):]
|
|
default:
|
|
t.Fatalf("unhandled packet type %v", ptype)
|
|
}
|
|
}
|
|
// This is rather hackish: If the last frame in the last packet
|
|
// in the datagram is PADDING, then remove it and record
|
|
// the padded size in the testDatagram.paddedSize.
|
|
//
|
|
// This makes it easier to write a test that expects a datagram
|
|
// padded to 1200 bytes.
|
|
if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
|
|
p := d.packets[len(d.packets)-1]
|
|
f := p.frames[len(p.frames)-1]
|
|
if _, ok := f.(debugFramePadding); ok {
|
|
p.frames = p.frames[:len(p.frames)-1]
|
|
d.paddedSize = size
|
|
}
|
|
}
|
|
return d
|
|
}
|
|
|
|
func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
|
|
t.Helper()
|
|
var frames []debugFrame
|
|
for len(payload) > 0 {
|
|
f, n := parseDebugFrame(payload)
|
|
if n < 0 {
|
|
return nil, errors.New("error parsing frames")
|
|
}
|
|
frames = append(frames, f)
|
|
payload = payload[n:]
|
|
}
|
|
return frames, nil
|
|
}
|
|
|
|
func spaceForPacketType(ptype packetType) numberSpace {
|
|
switch ptype {
|
|
case packetTypeInitial:
|
|
return initialSpace
|
|
case packetType0RTT:
|
|
panic("TODO: packetType0RTT")
|
|
case packetTypeHandshake:
|
|
return handshakeSpace
|
|
case packetTypeRetry:
|
|
panic("retry packets have no number space")
|
|
case packetType1RTT:
|
|
return appDataSpace
|
|
}
|
|
panic("unknown packet type")
|
|
}
|
|
|
|
// testConnHooks implements connTestHooks.
|
|
type testConnHooks testConn
|
|
|
|
func (tc *testConnHooks) init() {
|
|
tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
|
|
tc.keysInitial.r = tc.conn.keysInitial.w
|
|
tc.keysInitial.w = tc.conn.keysInitial.r
|
|
if tc.conn.side == serverSide {
|
|
tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
|
|
}
|
|
}
|
|
|
|
// handleTLSEvent processes TLS events generated by
|
|
// the connection under test's tls.QUICConn.
|
|
//
|
|
// We maintain a second tls.QUICConn representing the peer,
|
|
// and feed the TLS handshake data into it.
|
|
//
|
|
// We stash TLS handshake data from both sides in the testConn,
|
|
// where it can be used by tests.
|
|
//
|
|
// We snoop packet protection keys out of the tls.QUICConns,
|
|
// and verify that both sides of the connection are getting
|
|
// matching keys.
|
|
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
|
|
checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
|
|
var space numberSpace
|
|
switch {
|
|
case e.Level == tls.QUICEncryptionLevelHandshake:
|
|
space = handshakeSpace
|
|
case e.Level == tls.QUICEncryptionLevelApplication:
|
|
space = appDataSpace
|
|
default:
|
|
tc.t.Errorf("unexpected encryption level %v", e.Level)
|
|
return
|
|
}
|
|
if secrets[space].secret == nil {
|
|
secrets[space].suite = e.Suite
|
|
secrets[space].secret = append([]byte{}, e.Data...)
|
|
} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
|
|
tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
|
|
}
|
|
}
|
|
setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
|
|
k.hdr.init(suite, secret)
|
|
for i := 0; i < len(k.pkt); i++ {
|
|
k.pkt[i].init(suite, secret)
|
|
secret = updateSecret(suite, secret)
|
|
}
|
|
}
|
|
switch e.Kind {
|
|
case tls.QUICSetReadSecret:
|
|
checkKey("write", &tc.wsecrets, e)
|
|
switch e.Level {
|
|
case tls.QUICEncryptionLevelHandshake:
|
|
tc.keysHandshake.w.init(e.Suite, e.Data)
|
|
case tls.QUICEncryptionLevelApplication:
|
|
setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
|
|
}
|
|
case tls.QUICSetWriteSecret:
|
|
checkKey("read", &tc.rsecrets, e)
|
|
switch e.Level {
|
|
case tls.QUICEncryptionLevelHandshake:
|
|
tc.keysHandshake.r.init(e.Suite, e.Data)
|
|
case tls.QUICEncryptionLevelApplication:
|
|
setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
|
|
}
|
|
case tls.QUICWriteData:
|
|
tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
|
|
tc.peerTLSConn.HandleData(e.Level, e.Data)
|
|
}
|
|
for {
|
|
e := tc.peerTLSConn.NextEvent()
|
|
switch e.Kind {
|
|
case tls.QUICNoEvent:
|
|
return
|
|
case tls.QUICSetReadSecret:
|
|
checkKey("write", &tc.rsecrets, e)
|
|
switch e.Level {
|
|
case tls.QUICEncryptionLevelHandshake:
|
|
tc.keysHandshake.r.init(e.Suite, e.Data)
|
|
case tls.QUICEncryptionLevelApplication:
|
|
setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
|
|
}
|
|
case tls.QUICSetWriteSecret:
|
|
checkKey("read", &tc.wsecrets, e)
|
|
switch e.Level {
|
|
case tls.QUICEncryptionLevelHandshake:
|
|
tc.keysHandshake.w.init(e.Suite, e.Data)
|
|
case tls.QUICEncryptionLevelApplication:
|
|
setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
|
|
}
|
|
case tls.QUICWriteData:
|
|
tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
|
|
case tls.QUICTransportParameters:
|
|
p, err := unmarshalTransportParams(e.Data)
|
|
if err != nil {
|
|
tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
|
|
} else {
|
|
tc.sentTransportParameters = &p
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// nextMessage is called by the Conn's event loop to request its next event.
|
|
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
|
|
tc.timer = timer
|
|
for {
|
|
if !timer.IsZero() && !timer.After(tc.endpoint.now) {
|
|
if timer.Equal(tc.timerLastFired) {
|
|
// If the connection timer fires at time T, the Conn should take some
|
|
// action to advance the timer into the future. If the Conn reschedules
|
|
// the timer for the same time, it isn't making progress and we have a bug.
|
|
tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
|
|
} else {
|
|
tc.timerLastFired = timer
|
|
return tc.endpoint.now, timerEvent{}
|
|
}
|
|
}
|
|
select {
|
|
case m := <-msgc:
|
|
return tc.endpoint.now, m
|
|
default:
|
|
}
|
|
if !tc.wakeAsync() {
|
|
break
|
|
}
|
|
}
|
|
// If the message queue is empty, then the conn is idle.
|
|
if tc.idlec != nil {
|
|
idlec := tc.idlec
|
|
tc.idlec = nil
|
|
close(idlec)
|
|
}
|
|
m = <-msgc
|
|
return tc.endpoint.now, m
|
|
}
|
|
|
|
func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
|
|
return testLocalConnID(seq), nil
|
|
}
|
|
|
|
func (tc *testConnHooks) timeNow() time.Time {
|
|
return tc.endpoint.now
|
|
}
|
|
|
|
// testLocalConnID returns the connection ID with a given sequence number
|
|
// used by a Conn under test.
|
|
func testLocalConnID(seq int64) []byte {
|
|
cid := make([]byte, connIDLen)
|
|
copy(cid, []byte{0xc0, 0xff, 0xee})
|
|
cid[len(cid)-1] = byte(seq)
|
|
return cid
|
|
}
|
|
|
|
// testPeerConnID returns the connection ID with sequence number seq used
// by the fake peer of a Conn under test: be ee ff followed by the low byte
// of seq.
//
// Its length (4 bytes) deliberately differs from the length we use for our
// own conn IDs, to help catch bad assumptions.
func testPeerConnID(seq int64) []byte {
	id := []byte{0xbe, 0xee, 0xff, 0}
	id[3] = byte(seq)
	return id
}
|
|
|
|
func testPeerStatelessResetToken(seq int64) statelessResetToken {
|
|
return statelessResetToken{
|
|
0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
|
|
0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
|
|
}
|
|
}
|
|
|
|
// canceledContext returns a Context that has already been canceled.
//
// Functions that take a context prefer making progress over honoring
// cancellation. For example, a read with a canceled context still returns
// data if any is available. Tests therefore use canceled contexts to
// perform non-blocking operations.
func canceledContext() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // runs before the caller receives ctx
	return ctx
}
|