Files
golang.net/quic/conn_test.go
Damien Neil 22500a668c quic: decode packet numbers >255 in tests
Decoding QUIC packet numbers requires keeping track of the largest
packet number received so far from the peer. Our tests haven't
bothered doing that so far, so tests can't work with packet
numbers past 255.

Fix that so we can write tests that use more packets.

Change-Id: Icb795e5cf69794381c12a3a03b0da6bcf47a69c0
Reviewed-on: https://go-review.googlesource.com/c/net/+/664296
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
2025-04-09 15:11:48 -07:00

1182 lines
32 KiB
Go

// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"log/slog"
"math"
"net/netip"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/net/quic/qlog"
)
// Command-line flags controlling test diagnostics.
var (
// testVV gates the per-datagram logging performed by logDatagram.
testVV = flag.Bool("vv", false, "even more verbose test output")
// qlogdir is used as the Dir of the qlog handler that newTestConn installs.
qlogdir = flag.String("qlog", "", "write qlog logs to directory")
)
// TestConnTestConn exercises the testConn harness itself:
// the initial idle timer, the time reported to functions run on the
// conn's loop before and after advancing the clock, and the transition
// to the done state once the idle timer fires.
func TestConnTestConn(t *testing.T) {
	tc := newTestConn(t, serverSide)
	tc.handshake()
	if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
		t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
	}

	// loopTime runs a function on the conn's loop and reports
	// the time the loop passed to it.
	loopTime := func() time.Time {
		op := runAsync(tc, func(ctx context.Context) (time.Time, error) {
			var when time.Time
			tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
				when = now
			})
			return when, nil
		})
		ranAt, _ := op.result()
		return ranAt
	}

	ranAt := loopTime()
	if !ranAt.Equal(tc.endpoint.now) {
		t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
	}
	tc.wait()

	// Move halfway to the idle timeout and check the loop sees the new time.
	nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
	tc.advanceTo(nextTime)
	ranAt = loopTime()
	if !ranAt.Equal(nextTime) {
		t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
	}
	tc.wait()

	tc.advanceToTimer()
	if got := tc.conn.lifetime.state; got != connStateDone {
		t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
	}
}
// A testDatagram is a decoded view of one datagram: the packets it carries,
// the size it was padded to (if any), and an address.
type testDatagram struct {
packets []*testPacket
paddedSize int // full datagram size when padding was seen; 0 otherwise
addr netip.AddrPort
}
// String returns a multi-line description of the datagram:
// a summary line (packet count and padded size, when set)
// followed by one entry per packet.
func (d testDatagram) String() string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "datagram with %v packets", len(d.packets))
	if d.paddedSize > 0 {
		fmt.Fprintf(&sb, " (padded to %v bytes)", d.paddedSize)
	}
	sb.WriteByte(':')
	for i := range d.packets {
		sb.WriteByte('\n')
		sb.WriteString(d.packets[i].String())
	}
	return sb.String()
}
// A testPacket is a QUIC packet as seen by a test:
// either one decoded from the conn's output or one to be encoded and sent to it.
type testPacket struct {
ptype packetType
header byte // first byte of the packet as parsed; ignored by packetEqual
version uint32
num packetNumber
keyPhaseBit bool // key phase bit of a 1-RTT packet
keyNumber int // index into test1RTTKeys.pkt of the key protecting a 1-RTT packet
dstConnID []byte
srcConnID []byte
token []byte
originalDstConnID []byte // used for encoding Retry packets
frames []debugFrame
}
// String returns a description of the packet: its type and number,
// then the version, connection IDs, and token when they are set,
// followed by one line per frame.
func (p testPacket) String() string {
	var sb strings.Builder
	fmt.Fprintf(&sb, " %v %v", p.ptype, p.num)
	if p.version != 0 {
		fmt.Fprintf(&sb, " version=%v", p.version)
	}
	// Optional byte-string fields, printed in this order when non-nil.
	optional := []struct {
		name string
		val  []byte
	}{
		{"src", p.srcConnID},
		{"dst", p.dstConnID},
		{"token", p.token},
	}
	for _, field := range optional {
		if field.val != nil {
			fmt.Fprintf(&sb, " %v={%x}", field.name, field.val)
		}
	}
	for _, frame := range p.frames {
		fmt.Fprintf(&sb, "\n %v", frame)
	}
	return sb.String()
}
// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
// It sizes test1RTTKeys.pkt and bounds the number of keys tried
// when decrypting a 1-RTT packet sent by the conn.
const maxTestKeyPhases = 3
// A testConn is a Conn whose external interactions (sending and receiving packets,
// setting timers) can be manipulated in tests.
type testConn struct {
t *testing.T
conn *Conn
endpoint *testEndpoint
timer time.Time // timer most recently passed to nextMessage by the conn's loop
timerLastFired time.Time // timer value for which nextMessage last returned a timer event
idlec chan struct{} // only accessed on the conn's loop
// Keys are distinct from the conn's keys,
// because the test may know about keys before the conn does.
// For example, when sending a datagram with coalesced
// Initial and Handshake packets to a client conn,
// we use Handshake keys to encrypt the packet.
// The client only acquires those keys when it processes
// the Initial packet.
keysInitial fixedKeyPair
keysHandshake fixedKeyPair
rkeyAppData test1RTTKeys
wkeyAppData test1RTTKeys
rsecrets [numberSpaceCount]keySecret
wsecrets [numberSpaceCount]keySecret
// testConn uses a test hook to snoop on the conn's TLS events.
// CRYPTO data produced by the conn's QUICConn is placed in
// cryptoDataOut.
//
// The peerTLSConn is a QUICConn representing the peer.
// CRYPTO data produced by the conn is written to peerTLSConn,
// and data produced by peerTLSConn is placed in cryptoDataIn.
cryptoDataOut map[tls.QUICEncryptionLevel][]byte
cryptoDataIn map[tls.QUICEncryptionLevel][]byte
peerTLSConn *tls.QUICConn
// Information about the conn's (fake) peer.
peerConnID []byte // source conn id of peer's packets
peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use
// Maximum packet number received from the conn.
// Used when decoding the packet numbers of later packets.
pnumMax [numberSpaceCount]packetNumber
// Datagrams, packets, and frames sent by the conn,
// but not yet processed by the test.
sentDatagrams [][]byte
sentPackets []*testPacket
sentFrames []debugFrame
lastDatagram *testDatagram
lastPacket *testPacket
recvDatagram chan *datagram
// Transport parameters sent by the conn.
sentTransportParameters *transportParameters
// Frame types to ignore in tests.
ignoreFrames map[byte]bool
// Values to set in packets sent to the conn.
sendKeyNumber int
sendKeyPhaseBit bool
asyncTestState
}
// test1RTTKeys holds the test's copy of the 1-RTT keys for one direction:
// a header key plus one packet key per key phase, each derived from
// the previous phase's secret (see setAppDataKey in handleTLSEvent).
type test1RTTKeys struct {
hdr headerKey
pkt [maxTestKeyPhases]packetKey
}
// keySecret records the cipher suite and secret reported by a TLS secret event,
// so that a later event for the same number space can be checked against it.
type keySecret struct {
suite uint16
secret []byte
}
// newTestConn creates a Conn for testing.
//
// The Conn's event loop is controlled by the test,
// allowing test code to access Conn state directly
// by first ensuring the loop goroutine is idle.
//
// Each element of opts must be a function taking one of *Config,
// *tls.Config, *newServerConnIDs, *transportParameters, or *testConn.
// The first three are applied immediately; the last two are handed
// to the test endpoint. Any other option type fails the test.
func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
	t.Helper()
	config := &Config{
		TLSConfig:         newTestTLSConfig(side),
		StatelessResetKey: testStatelessResetKey,
		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
			Level: QLogLevelFrame,
			Dir:   *qlogdir,
		})),
	}

	var (
		cids      newServerConnIDs
		tpOptions []func(*transportParameters)
		tcOptions []func(*testConn)
	)
	if side == serverSide {
		// The initial connection ID for the server is chosen by the client.
		cids.srcConnID = testPeerConnID(0)
		cids.dstConnID = testPeerConnID(-1)
		cids.originalDstConnID = cids.dstConnID
	}

	for _, opt := range opts {
		switch fn := opt.(type) {
		case func(*Config):
			fn(config)
		case func(*tls.Config):
			fn(config.TLSConfig)
		case func(*newServerConnIDs):
			fn(&cids)
		case func(*transportParameters):
			tpOptions = append(tpOptions, fn)
		case func(*testConn):
			tcOptions = append(tcOptions, fn)
		default:
			t.Fatalf("unknown newTestConn option %T", fn)
		}
	}

	endpoint := newTestEndpoint(t, config)
	endpoint.configTransportParams = tpOptions
	endpoint.configTestConn = tcOptions

	peerAddr := netip.MustParseAddrPort("127.0.0.1:443")
	conn, err := endpoint.e.newConn(endpoint.now, config, side, cids, "", peerAddr)
	if err != nil {
		t.Fatal(err)
	}

	// Look up the testConn the endpoint has associated with the new conn.
	tc := endpoint.conns[conn]
	tc.wait()
	return tc
}
// newTestConnForConn wraps conn in a testConn and installs the testConn
// as the conn's test hooks.
//
// If the endpoint already holds a peer TLS conn, it is handed over to the
// new testConn. Otherwise a fresh peer tls.QUICConn for the opposite side
// is created, given the (possibly customized) peer transport parameters,
// and started.
func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
	t.Helper()
	tc := &testConn{
		t:          t,
		endpoint:   endpoint,
		conn:       conn,
		peerConnID: testPeerConnID(0),
		ignoreFrames: map[byte]bool{
			frameTypePadding: true, // ignore PADDING by default
		},
		cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
		cryptoDataIn:  make(map[tls.QUICEncryptionLevel][]byte),
		recvDatagram:  make(chan *datagram),
	}
	t.Cleanup(tc.cleanup)
	for _, configure := range endpoint.configTestConn {
		configure(tc)
	}
	conn.testHooks = (*testConnHooks)(tc)

	if peer := endpoint.peerTLSConn; peer != nil {
		tc.peerTLSConn = peer
		endpoint.peerTLSConn = nil
		return tc
	}

	isClient := conn.side == clientSide

	params := defaultTransportParameters()
	params.initialSrcConnID = testPeerConnID(0)
	if isClient {
		params.originalDstConnID = testLocalConnID(-1)
	}
	for _, configure := range endpoint.configTransportParams {
		configure(&params)
	}

	// The peer plays the opposite role to the conn under test.
	newPeer := tls.QUICClient
	if isClient {
		newPeer = tls.QUICServer
	}
	tc.peerTLSConn = newPeer(&tls.QUICConfig{
		TLSConfig: newTestTLSConfig(conn.side.peer()),
	})
	tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(params))
	tc.peerTLSConn.Start(context.Background())
	t.Cleanup(func() {
		tc.peerTLSConn.Close()
	})
	return tc
}
// advance moves the test endpoint's clock forward by d.
func (tc *testConn) advance(d time.Duration) {
	tc.t.Helper()
	endpoint := tc.endpoint
	endpoint.advance(d)
}
// advanceTo sets the test endpoint's current time to now.
func (tc *testConn) advanceTo(now time.Time) {
	tc.t.Helper()
	endpoint := tc.endpoint
	endpoint.advanceTo(now)
}
// advanceToTimer sets the current time to the time of the Conn's next timer event.
// It fails the test if the Conn has no timer set.
func (tc *testConn) advanceToTimer() {
	// Mark this as a helper, like the other advance methods,
	// so the failure below is reported at the caller's line.
	tc.t.Helper()
	if tc.timer.IsZero() {
		tc.t.Fatalf("advancing to timer, but timer is not set")
	}
	tc.advanceTo(tc.timer)
}
// timerDelay returns the amount of time until the conn's timer fires:
// the maximum Duration when no timer is set, and 0 when the timer is
// already in the past.
//
// It is the same computation as timeUntilEvent and delegates to it,
// so the two cannot drift apart.
func (tc *testConn) timerDelay() time.Duration {
	return tc.timeUntilEvent()
}
// infiniteDuration is the largest representable Duration.
// timeUntilEvent returns it when no timer is set.
const infiniteDuration = time.Duration(math.MaxInt64)
// timeUntilEvent returns the amount of time until the next connection event.
// It returns infiniteDuration when no timer is set,
// and 0 when the timer is already in the past.
func (tc *testConn) timeUntilEvent() time.Duration {
	timer := tc.timer
	if timer.IsZero() {
		return infiniteDuration
	}
	if now := tc.endpoint.now; !timer.Before(now) {
		return timer.Sub(now)
	}
	return 0
}
// wait blocks until the conn becomes idle.
// The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire.
// Tests shouldn't need to call wait directly.
// testConn methods that wake the Conn event loop will call wait for them.
//
// wait also returns once the conn's donec channel is ready, waking any
// pending async operations first. If the loop finds an idle channel from
// an earlier wait still installed, the test is marked failed and wait panics.
func (tc *testConn) wait() {
tc.t.Helper()
idlec := make(chan struct{})
fail := false
// Install idlec from the conn's loop, the only place tc.idlec is accessed.
tc.conn.sendMsg(func(now time.Time, c *Conn) {
if tc.idlec != nil {
tc.t.Errorf("testConn.wait called concurrently")
fail = true
// Unblock the select below; nextMessage will not close this channel.
close(idlec)
} else {
// nextMessage will close idlec.
tc.idlec = idlec
}
})
select {
case <-idlec:
case <-tc.conn.donec:
// We may have async ops that can proceed now that the conn is done.
tc.wakeAsync()
}
if fail {
panic(fail)
}
}
// cleanup asks the conn to exit and blocks until its donec channel is ready.
// It does nothing when the testConn has no conn.
func (tc *testConn) cleanup() {
	if tc.conn != nil {
		tc.conn.exit()
		<-tc.conn.donec
	}
}
// acceptStream accepts a stream from the conn using a canceled context,
// failing the test if AcceptStream returns an error.
// The returned stream's read and write contexts are also canceled contexts.
func (tc *testConn) acceptStream() *Stream {
	tc.t.Helper()
	stream, err := tc.conn.AcceptStream(canceledContext())
	if err != nil {
		tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err)
	}
	stream.SetReadContext(canceledContext())
	stream.SetWriteContext(canceledContext())
	return stream
}
// logDatagram logs a description of d to t, prefixed by text.
// It logs nothing unless the -vv flag is set.
//
// Each packet gets a summary line (1-RTT packets omit the version and
// connection IDs) followed by one line per frame.
func logDatagram(t *testing.T, text string, d *testDatagram) {
	t.Helper()
	if !*testVV {
		return
	}
	pad := ""
	if d.paddedSize > 0 {
		pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
	}
	t.Logf("%v datagram%v", text, pad)
	for _, p := range d.packets {
		var s string
		switch p.ptype {
		case packetType1RTT:
			s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num)
		default:
			s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
		}
		if p.token != nil {
			s += fmt.Sprintf(" token={%x}", p.token)
		}
		if p.keyPhaseBit {
			// Plain literal: there is nothing to format here.
			s += " KeyPhase"
		}
		if p.keyNumber != 0 {
			s += fmt.Sprintf(" keynum=%v", p.keyNumber)
		}
		t.Log(s)
		for _, f := range p.frames {
			t.Logf(" %v", f)
		}
	}
}
// write sends the Conn a datagram by way of the test endpoint.
func (tc *testConn) write(d *testDatagram) {
	tc.t.Helper()
	endpoint := tc.endpoint
	endpoint.writeDatagram(d)
}
// writeFrames sends the Conn a datagram containing the given frames
// in a single packet of type ptype.
//
// The packet uses the peer's next packet number for ptype's number space
// and the testConn's current sendKeyNumber and sendKeyPhaseBit.
// Initial packets sent to a server conn are padded to 1200 bytes.
func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
	tc.t.Helper()
	space := spaceForPacketType(ptype)

	local := tc.conn.connIDState.local
	dstConnID := local[0].cid
	if local[0].seq == -1 && ptype != packetTypeInitial {
		// Only use the transient connection ID in Initial packets.
		dstConnID = local[1].cid
	}

	pkt := &testPacket{
		ptype:       ptype,
		num:         tc.peerNextPacketNum[space],
		keyNumber:   tc.sendKeyNumber,
		keyPhaseBit: tc.sendKeyPhaseBit,
		frames:      frames,
		version:     quicVersion1,
		dstConnID:   dstConnID,
		srcConnID:   tc.peerConnID,
	}
	d := &testDatagram{
		packets: []*testPacket{pkt},
		addr:    tc.conn.peerAddr,
	}
	if ptype == packetTypeInitial && tc.conn.side == serverSide {
		d.paddedSize = 1200
	}
	tc.write(d)
}
// writeAckForAll sends the Conn a datagram containing an ack for all packets
// up to and including the last one received.
// It does nothing if no packet has been received yet.
func (tc *testConn) writeAckForAll() {
	tc.t.Helper()
	last := tc.lastPacket
	if last == nil {
		return
	}
	ack := debugFrameAck{
		ranges: []i64range[packetNumber]{{0, last.num + 1}},
	}
	tc.writeFrames(last.ptype, ack)
}
// writeAckForLatest sends the Conn a datagram containing an ack for only
// the most recent packet received.
// It does nothing if no packet has been received yet.
func (tc *testConn) writeAckForLatest() {
	tc.t.Helper()
	last := tc.lastPacket
	if last == nil {
		return
	}
	ack := debugFrameAck{
		ranges: []i64range[packetNumber]{{last.num, last.num + 1}},
	}
	tc.writeFrames(last.ptype, ack)
}
// ignoreFrame marks frameType as ignored: readDatagram drops frames
// of that type from the packets it returns.
func (tc *testConn) ignoreFrame(frameType byte) {
	tc.ignoreFrames[frameType] = true
}
// readDatagram reads the next datagram sent by the Conn.
// It returns nil if the Conn has no more datagrams to send at this time.
//
// Packets and frames buffered from an earlier datagram are discarded.
// Frames whose type is marked in tc.ignoreFrames are removed from the
// returned packets, and the result is remembered in tc.lastDatagram.
func (tc *testConn) readDatagram() *testDatagram {
	tc.t.Helper()
	tc.wait()
	tc.sentPackets, tc.sentFrames = nil, nil
	raw := tc.endpoint.read()
	if raw == nil {
		return nil
	}
	d := parseTestDatagram(tc.t, tc.endpoint, tc, raw)
	// Log the datagram before removing ignored frames.
	// When things go wrong, it's useful to see all the frames.
	logDatagram(tc.t, "-> conn under test sends", d)

	// frameTypeOf maps a debugFrame back to the frame type byte
	// used as the key of tc.ignoreFrames.
	//
	// This is very clunky, and points at a problem
	// in how we specify what frames to ignore in tests.
	//
	// We mark frames to ignore using the frame type,
	// but we've got a debugFrame data structure here.
	// Perhaps we should be ignoring frames by debugFrame
	// type instead: tc.ignoreFrame[debugFrameAck]().
	frameTypeOf := func(f debugFrame) byte {
		switch f := f.(type) {
		case debugFramePadding:
			return frameTypePadding
		case debugFramePing:
			return frameTypePing
		case debugFrameAck:
			return frameTypeAck
		case debugFrameResetStream:
			return frameTypeResetStream
		case debugFrameStopSending:
			return frameTypeStopSending
		case debugFrameCrypto:
			return frameTypeCrypto
		case debugFrameNewToken:
			return frameTypeNewToken
		case debugFrameStream:
			return frameTypeStreamBase
		case debugFrameMaxData:
			return frameTypeMaxData
		case debugFrameMaxStreamData:
			return frameTypeMaxStreamData
		case debugFrameMaxStreams:
			if f.streamType == bidiStream {
				return frameTypeMaxStreamsBidi
			}
			return frameTypeMaxStreamsUni
		case debugFrameDataBlocked:
			return frameTypeDataBlocked
		case debugFrameStreamDataBlocked:
			return frameTypeStreamDataBlocked
		case debugFrameStreamsBlocked:
			if f.streamType == bidiStream {
				return frameTypeStreamsBlockedBidi
			}
			return frameTypeStreamsBlockedUni
		case debugFrameNewConnectionID:
			return frameTypeNewConnectionID
		case debugFrameRetireConnectionID:
			return frameTypeRetireConnectionID
		case debugFramePathChallenge:
			return frameTypePathChallenge
		case debugFramePathResponse:
			return frameTypePathResponse
		case debugFrameConnectionCloseTransport:
			return frameTypeConnectionCloseTransport
		case debugFrameConnectionCloseApplication:
			return frameTypeConnectionCloseApplication
		case debugFrameHandshakeDone:
			return frameTypeHandshakeDone
		}
		panic(fmt.Errorf("unhandled frame type %T", f))
	}

	for _, p := range d.packets {
		// Build a fresh slice; a packet left with no frames gets a nil slice.
		var kept []debugFrame
		for _, f := range p.frames {
			if tc.ignoreFrames[frameTypeOf(f)] {
				continue
			}
			kept = append(kept, f)
		}
		p.frames = kept
	}
	tc.lastDatagram = d
	return d
}
// readPacket reads the next packet sent by the Conn.
// It returns nil if the Conn has no more packets to send at this time.
//
// Packets left with no frames (after ignored frames are removed) are not
// returned, but still update tc.lastPacket.
func (tc *testConn) readPacket() *testPacket {
	tc.t.Helper()
	for len(tc.sentPackets) == 0 {
		d := tc.readDatagram()
		if d == nil {
			return nil
		}
		for _, pkt := range d.packets {
			if len(pkt.frames) > 0 {
				tc.sentPackets = append(tc.sentPackets, pkt)
			} else {
				tc.lastPacket = pkt
			}
		}
	}
	next, rest := tc.sentPackets[0], tc.sentPackets[1:]
	tc.sentPackets = rest
	tc.lastPacket = next
	return next
}
// readFrame reads the next frame sent by the Conn, along with the type
// of the packet that carried it.
// It returns (nil, packetTypeInvalid) if the Conn has no more frames
// to send at this time.
func (tc *testConn) readFrame() (debugFrame, packetType) {
	tc.t.Helper()
	for len(tc.sentFrames) == 0 {
		pkt := tc.readPacket()
		if pkt == nil {
			return nil, packetTypeInvalid
		}
		tc.sentFrames = pkt.frames
	}
	frame, rest := tc.sentFrames[0], tc.sentFrames[1:]
	tc.sentFrames = rest
	return frame, tc.lastPacket.ptype
}
// wantDatagram indicates that we expect the Conn to send a datagram.
// The test fails if the next datagram read does not equal want.
func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
	tc.t.Helper()
	if got := tc.readDatagram(); !datagramEqual(got, want) {
		tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
	}
}
// datagramEqual reports whether a and b describe the same datagram:
// both nil, or equal in padded size, address, and packets
// (compared pairwise with packetEqual).
func datagramEqual(a, b *testDatagram) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == b
	}
	if a.paddedSize != b.paddedSize {
		return false
	}
	if a.addr != b.addr {
		return false
	}
	if len(a.packets) != len(b.packets) {
		return false
	}
	for i, ap := range a.packets {
		if !packetEqual(ap, b.packets[i]) {
			return false
		}
	}
	return true
}
// wantPacket indicates that we expect the Conn to send a packet.
// The test fails if the next packet read does not equal want.
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
	tc.t.Helper()
	if got := tc.readPacket(); !packetEqual(got, want) {
		tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want)
	}
}
// packetEqual reports whether a and b describe the same packet.
// The header byte is not compared; frames are compared pairwise
// with frameEqual; every other field is compared with reflect.DeepEqual.
func packetEqual(a, b *testPacket) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == b
	}
	// withoutFramesAndHeader returns a copy of p with the fields
	// that are excluded from the DeepEqual comparison cleared.
	withoutFramesAndHeader := func(p testPacket) testPacket {
		p.frames = nil
		p.header = 0
		return p
	}
	if !reflect.DeepEqual(withoutFramesAndHeader(*a), withoutFramesAndHeader(*b)) {
		return false
	}
	if len(a.frames) != len(b.frames) {
		return false
	}
	for i, af := range a.frames {
		if !frameEqual(af, b.frames[i]) {
			return false
		}
	}
	return true
}
// wantFrame indicates that we expect the Conn to send a frame.
// The test fails if the conn is idle, if the frame arrives in a packet
// of a type other than wantType, or if the frame does not equal want.
func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
	tc.t.Helper()
	got, gotType := tc.readFrame()
	switch {
	case got == nil:
		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
	case gotType != wantType:
		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
	case !frameEqual(got, want):
		tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want)
	}
}
// frameEqual reports whether two frames are equal.
// Transport CONNECTION_CLOSE frames are compared by error code only;
// all other frames are compared with reflect.DeepEqual.
func frameEqual(a, b debugFrame) bool {
	if af, ok := a.(debugFrameConnectionCloseTransport); ok {
		bf, ok := b.(debugFrameConnectionCloseTransport)
		return ok && af.code == bf.code
	}
	return reflect.DeepEqual(a, b)
}
// wantFrameType indicates that we expect the Conn to send a frame,
// although we don't care about the contents.
// Only the packet type and the dynamic type of the frame are checked.
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
	tc.t.Helper()
	got, gotType := tc.readFrame()
	switch {
	case got == nil:
		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
	case gotType != wantType:
		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
	case reflect.TypeOf(got) != reflect.TypeOf(want):
		tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want)
	}
}
// wantIdle indicates that we expect the Conn to not send any more frames.
// The test fails if a frame or packet is already buffered,
// or if reading from the conn yields another frame.
func (tc *testConn) wantIdle(expectation string) {
	tc.t.Helper()
	if len(tc.sentFrames) > 0 {
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
	}
	if len(tc.sentPackets) > 0 {
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
	}
	frame, _ := tc.readFrame()
	if frame != nil {
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, frame)
	}
}
// encodeTestPacket returns the wire encoding of p, padded to pad bytes.
//
// Retry packets are encoded directly from p's connection IDs and token.
// Initial and Handshake packets are protected with the testConn's write
// keys; when tc is nil only Initial packets can be encoded, using keys
// derived from p.dstConnID. 1-RTT packets require tc and are protected
// with the 1-RTT write key selected by p.keyNumber, with the key phase
// bit set from p.keyPhaseBit.
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
	t.Helper()
	var w packetWriter
	w.reset(1200)
	if p.ptype == packetTypeRetry {
		return encodeRetryPacket(p.originalDstConnID, retryPacket{
			srcConnID: p.srcConnID,
			dstConnID: p.dstConnID,
			token:     p.token,
		})
	}

	var pnumMaxAcked packetNumber
	is1RTT := p.ptype == packetType1RTT
	// Long header description; unused for 1-RTT packets.
	long := longPacket{
		ptype:     p.ptype,
		version:   p.version,
		num:       p.num,
		dstConnID: p.dstConnID,
		srcConnID: p.srcConnID,
		extra:     p.token,
	}

	if is1RTT {
		w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
	} else {
		w.startProtectedLongHeaderPacket(pnumMaxAcked, long)
	}
	for _, frame := range p.frames {
		frame.write(&w)
	}
	w.appendPaddingTo(pad)

	if is1RTT {
		if tc == nil || !tc.wkeyAppData.hdr.isSet() {
			t.Fatalf("sending 1-RTT packet with no write key")
		}
		// Somewhat hackish: Generate a temporary updatingKeyPair that will
		// always use our desired key phase.
		pktKey := tc.wkeyAppData.pkt[p.keyNumber]
		keys := &updatingKeyPair{
			w: updatingKeys{
				hdr: tc.wkeyAppData.hdr,
				pkt: [2]packetKey{pktKey, pktKey},
			},
			updateAfter: maxPacketNumber,
		}
		if p.keyPhaseBit {
			keys.phase |= keyPhaseBit
		}
		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, keys)
		return w.datagram()
	}

	var k fixedKeys
	if tc != nil {
		switch p.ptype {
		case packetTypeInitial:
			k = tc.keysInitial.w
		case packetTypeHandshake:
			k = tc.keysHandshake.w
		}
	} else if p.ptype == packetTypeInitial {
		k = initialKeys(p.dstConnID, serverSide).r
	} else {
		t.Fatalf("sending %v packet with no conn", p.ptype)
	}
	if !k.isSet() {
		t.Fatalf("sending %v packet with no write key", p.ptype)
	}
	w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, long)
	return w.datagram()
}
// parseTestDatagram decodes a datagram sent by the conn under test
// into a testDatagram, failing the test on any parse error.
//
// A Retry packet is returned on its own, validated against
// te.lastInitialDstConnID. Initial and Handshake packets are unprotected
// with the testConn's read keys; when tc is nil only Initial packets can
// be read, using keys derived from the packet's source connection ID.
// 1-RTT packets require tc; each of the first maxTestKeyPhases packet keys
// is tried in turn and the index of the one that works is recorded as the
// packet's keyNumber. A 1-RTT packet consumes the rest of the datagram.
//
// When tc is non-nil, tc.pnumMax is used to decode packet numbers and is
// raised to the largest packet number seen in each number space.
//
// If a zero byte is found where a packet should start, or the final frame
// of the final packet is PADDING, paddedSize is set to the datagram length.
func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
	t.Helper()
	size := len(buf) // full datagram length, recorded as paddedSize when padding is seen
	d := &testDatagram{}
	for len(buf) > 0 {
		if buf[0] == 0 {
			d.paddedSize = size
			break
		}
		ptype := getPacketType(buf)
		switch ptype {
		case packetTypeRetry:
			retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
			if !ok {
				t.Fatalf("could not parse %v packet", ptype)
			}
			return &testDatagram{
				packets: []*testPacket{{
					ptype:     packetTypeRetry,
					dstConnID: retry.dstConnID,
					srcConnID: retry.srcConnID,
					token:     retry.token,
				}},
			}
		case packetTypeInitial, packetTypeHandshake:
			var k fixedKeys
			var pnumMax packetNumber
			if tc == nil {
				if ptype == packetTypeInitial {
					p, _ := parseGenericLongHeaderPacket(buf)
					k = initialKeys(p.srcConnID, serverSide).w
				} else {
					t.Fatalf("reading %v packet with no conn", ptype)
				}
			} else {
				switch ptype {
				case packetTypeInitial:
					k = tc.keysInitial.r
					pnumMax = tc.pnumMax[initialSpace]
				case packetTypeHandshake:
					k = tc.keysHandshake.r
					pnumMax = tc.pnumMax[handshakeSpace]
				}
			}
			if !k.isSet() {
				t.Fatalf("reading %v packet with no read key", ptype)
			}
			p, n := parseLongHeaderPacket(buf, k, pnumMax)
			if n < 0 {
				t.Fatalf("packet parse error")
			}
			if tc != nil {
				switch ptype {
				case packetTypeInitial:
					tc.pnumMax[initialSpace] = max(pnumMax, p.num)
				case packetTypeHandshake:
					tc.pnumMax[handshakeSpace] = max(pnumMax, p.num)
				}
			}
			frames, err := parseTestFrames(t, p.payload)
			if err != nil {
				t.Fatal(err)
			}
			var token []byte
			if ptype == packetTypeInitial && len(p.extra) > 0 {
				token = p.extra
			}
			d.packets = append(d.packets, &testPacket{
				ptype:     p.ptype,
				header:    buf[0],
				version:   p.version,
				num:       p.num,
				dstConnID: p.dstConnID,
				srcConnID: p.srcConnID,
				token:     token,
				frames:    frames,
			})
			buf = buf[n:]
		case packetType1RTT:
			if tc == nil || !tc.rkeyAppData.hdr.isSet() {
				t.Fatalf("reading 1-RTT packet with no read key")
			}
			// t.Fatalf does not return, so tc is non-nil from here on.
			pnumMax := tc.pnumMax[appDataSpace]
			pnumOff := 1 + len(tc.peerConnID)
			// Try unprotecting the packet with the first maxTestKeyPhases keys.
			var phase int
			var pnum packetNumber
			var hdr []byte
			var pay []byte
			var err error
			for phase = 0; phase < maxTestKeyPhases; phase++ {
				// Work on a fresh copy of the datagram for each attempt.
				b := append([]byte{}, buf...)
				hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
				if err != nil {
					t.Fatalf("1-RTT packet header parse error")
				}
				k := tc.rkeyAppData.pkt[phase]
				pay, err = k.unprotect(hdr, pay, pnum)
				if err == nil {
					break
				}
			}
			if err != nil {
				t.Fatalf("1-RTT packet payload parse error")
			}
			tc.pnumMax[appDataSpace] = max(pnumMax, pnum)
			frames, err := parseTestFrames(t, pay)
			if err != nil {
				t.Fatal(err)
			}
			d.packets = append(d.packets, &testPacket{
				ptype:       packetType1RTT,
				header:      hdr[0],
				num:         pnum,
				dstConnID:   hdr[1:][:len(tc.peerConnID)],
				keyPhaseBit: hdr[0]&keyPhaseBit != 0,
				keyNumber:   phase,
				frames:      frames,
			})
			buf = buf[len(buf):]
		default:
			t.Fatalf("unhandled packet type %v", ptype)
		}
	}
	// This is rather hackish: If the last frame in the last packet
	// in the datagram is PADDING, then remove it and record
	// the padded size in the testDatagram.paddedSize.
	//
	// This makes it easier to write a test that expects a datagram
	// padded to 1200 bytes.
	if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
		p := d.packets[len(d.packets)-1]
		f := p.frames[len(p.frames)-1]
		if _, ok := f.(debugFramePadding); ok {
			p.frames = p.frames[:len(p.frames)-1]
			d.paddedSize = size
		}
	}
	return d
}
// parseTestFrames decodes every frame in payload, in order.
// It returns an error if parseDebugFrame reports a negative length
// for any frame.
func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
	t.Helper()
	var frames []debugFrame
	for rest := payload; len(rest) > 0; {
		frame, n := parseDebugFrame(rest)
		if n < 0 {
			return nil, errors.New("error parsing frames")
		}
		frames = append(frames, frame)
		rest = rest[n:]
	}
	return frames, nil
}
// spaceForPacketType returns the packet number space used by packets
// of the given type. It panics for 0-RTT (not yet supported here),
// Retry (which has no number space), and unknown packet types.
func spaceForPacketType(ptype packetType) numberSpace {
	switch ptype {
	case packetTypeInitial:
		return initialSpace
	case packetTypeHandshake:
		return handshakeSpace
	case packetType1RTT:
		return appDataSpace
	case packetType0RTT:
		panic("TODO: packetType0RTT")
	case packetTypeRetry:
		panic("retry packets have no number space")
	default:
		panic("unknown packet type")
	}
}
// testConnHooks implements connTestHooks.
// It has the same underlying type as testConn, so a *testConn is converted
// to a *testConnHooks (and back) rather than wrapped.
type testConnHooks testConn
// init prepares the testConn for a newly created conn: it disables
// key updates on the conn, mirrors the conn's Initial keys (the conn's
// write key is the test's read key and vice versa), and queues
// server-side conns on the endpoint's accept queue.
func (tc *testConnHooks) init() {
	conn := tc.conn
	conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
	tc.keysInitial.r, tc.keysInitial.w = conn.keysInitial.w, conn.keysInitial.r
	if conn.side == serverSide {
		tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
	}
}
// handleTLSEvent processes TLS events generated by
// the connection under test's tls.QUICConn.
//
// We maintain a second tls.QUICConn representing the peer,
// and feed the TLS handshake data into it.
//
// We stash TLS handshake data from both sides in the testConn,
// where it can be used by tests.
//
// We snoop packet protection keys out of the tls.QUICConns,
// and verify that both sides of the connection are getting
// matching keys.
//
// After handling e, all pending events of the peer's tls.QUICConn
// are drained and processed as well.
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
	// checkKey records the secret carried by e the first time a number
	// space is seen, and reports a test error if a later event for the
	// same space carries a different suite or secret.
	checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
		var space numberSpace
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			space = handshakeSpace
		case tls.QUICEncryptionLevelApplication:
			space = appDataSpace
		default:
			tc.t.Errorf("unexpected encryption level %v", e.Level)
			return
		}
		if secrets[space].secret == nil {
			secrets[space].suite = e.Suite
			secrets[space].secret = append([]byte{}, e.Data...)
		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
			tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
		}
	}
	// setAppDataKey initializes the header key and one packet key per
	// key phase, updating the secret between phases.
	setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
		k.hdr.init(suite, secret)
		for i := 0; i < len(k.pkt); i++ {
			k.pkt[i].init(suite, secret)
			secret = updateSecret(suite, secret)
		}
	}
	// installKeys derives keys from the secret in e into the Handshake
	// or 1-RTT destination matching e's level. Other levels are ignored.
	installKeys := func(e tls.QUICEvent, handshake *fixedKeys, appData *test1RTTKeys) {
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			handshake.init(e.Suite, e.Data)
		case tls.QUICEncryptionLevelApplication:
			setAppDataKey(e.Suite, e.Data, appData)
		}
	}

	// Events from the conn under test: what the conn reads with,
	// the test writes with, and vice versa.
	switch e.Kind {
	case tls.QUICSetReadSecret:
		checkKey("write", &tc.wsecrets, e)
		installKeys(e, &tc.keysHandshake.w, &tc.wkeyAppData)
	case tls.QUICSetWriteSecret:
		checkKey("read", &tc.rsecrets, e)
		installKeys(e, &tc.keysHandshake.r, &tc.rkeyAppData)
	case tls.QUICWriteData:
		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
		tc.peerTLSConn.HandleData(e.Level, e.Data)
	}

	// Events from the fake peer: the peer's read secret is
	// the test's read key, its write secret the test's write key.
	for {
		e := tc.peerTLSConn.NextEvent()
		switch e.Kind {
		case tls.QUICNoEvent:
			return
		case tls.QUICSetReadSecret:
			checkKey("write", &tc.rsecrets, e)
			installKeys(e, &tc.keysHandshake.r, &tc.rkeyAppData)
		case tls.QUICSetWriteSecret:
			checkKey("read", &tc.wsecrets, e)
			installKeys(e, &tc.keysHandshake.w, &tc.wkeyAppData)
		case tls.QUICWriteData:
			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
		case tls.QUICTransportParameters:
			p, err := unmarshalTransportParams(e.Data)
			if err != nil {
				tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
			} else {
				tc.sentTransportParameters = &p
			}
		}
	}
}
// nextMessage is called by the Conn's event loop to request its next event.
//
// It records timer in tc.timer, then returns, in order of preference:
// a timerEvent if the timer is due at the endpoint's current time,
// a message already waiting on msgc, or (after signaling idleness to
// any pending wait) the next message to arrive on msgc.
// The returned time is always the endpoint's current time.
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
tc.timer = timer
for {
if !timer.IsZero() && !timer.After(tc.endpoint.now) {
if timer.Equal(tc.timerLastFired) {
// If the connection timer fires at time T, the Conn should take some
// action to advance the timer into the future. If the Conn reschedules
// the timer for the same time, it isn't making progress and we have a bug.
tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
} else {
tc.timerLastFired = timer
return tc.endpoint.now, timerEvent{}
}
}
// Non-blocking check for a queued message.
select {
case m := <-msgc:
return tc.endpoint.now, m
default:
}
// Loop again only if wakeAsync reports true; otherwise leave the for loop.
if !tc.wakeAsync() {
break
}
}
// If the message queue is empty, then the conn is idle.
if tc.idlec != nil {
idlec := tc.idlec
tc.idlec = nil
close(idlec)
}
// Block until the next message arrives.
m = <-msgc
return tc.endpoint.now, m
}
// newConnID returns the deterministic test connection ID
// for sequence number seq. It never returns an error.
func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
	cid := testLocalConnID(seq)
	return cid, nil
}
// timeNow returns the test endpoint's current (fake) time.
func (tc *testConnHooks) timeNow() time.Time {
	endpoint := tc.endpoint
	return endpoint.now
}
// testLocalConnID returns the connection ID with a given sequence number
// used by a Conn under test.
//
// The ID is connIDLen bytes long: it starts with c0 ff ee
// and ends with the low byte of seq.
func testLocalConnID(seq int64) []byte {
	prefix := [...]byte{0xc0, 0xff, 0xee}
	cid := make([]byte, connIDLen)
	copy(cid, prefix[:])
	last := len(cid) - 1
	cid[last] = byte(seq)
	return cid
}
// testPeerConnID returns the connection ID with a given sequence number
// used by the fake peer of a Conn under test.
//
// The ID is the four bytes be ee ff followed by the low byte of seq.
func testPeerConnID(seq int64) []byte {
	// Use a different length than we choose for our own conn ids,
	// to help catch any bad assumptions.
	cid := []byte{0xbe, 0xee, 0xff, 0x00}
	cid[3] = byte(seq)
	return cid
}
// testPeerStatelessResetToken returns a deterministic stateless reset token
// for the given sequence number: fifteen ee bytes followed by the low byte
// of seq.
func testPeerStatelessResetToken(seq int64) statelessResetToken {
	token := statelessResetToken{
		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0x00,
	}
	token[15] = byte(seq)
	return token
}
// canceledContext returns a canceled Context.
//
// Functions which take a context prefer progress over cancelation.
// For example, a read with a canceled context will return data if any is available.
// Tests use canceled contexts to perform non-blocking operations.
func canceledContext() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	// The deferred cancel runs before the caller receives ctx.
	defer cancel()
	return ctx
}