mirror of
https://github.com/golang/net.git
synced 2026-03-31 10:27:08 +09:00
quic: skip packet numbers for optimistic ack defense
An "optimistic ACK attack" involves an attacker sending ACKs for packets it hasn't received, causing the victim's congestion controller to improperly send at a higher rate. The standard defense against this attack is to skip the occasional packet number, and to close the connection with an error if the peer ACKs an unsent packet. Implement this defense, increasing the gap between skipped packet numbers as a connection's lifetime grows and correspondingly the amount of work required on the part of the attacker. Change-Id: I01f44f13367821b86af6535ffb69d380e2b4d7b7 Reviewed-on: https://go-review.googlesource.com/c/net/+/664298 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> Auto-Submit: Damien Neil <dneil@google.com>
This commit is contained in:
committed by
Gopher Robot
parent
3f563d3b0d
commit
3e7a445bf4
13
quic/conn.go
13
quic/conn.go
@@ -6,10 +6,12 @@ package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
@@ -24,6 +26,7 @@ type Conn struct {
|
||||
testHooks connTestHooks
|
||||
peerAddr netip.AddrPort
|
||||
localAddr netip.AddrPort
|
||||
prng *rand.Rand
|
||||
|
||||
msgc chan any
|
||||
donec chan struct{} // closed when conn loop exits
|
||||
@@ -36,6 +39,7 @@ type Conn struct {
|
||||
loss lossState
|
||||
streams streamsState
|
||||
path pathState
|
||||
skip skipState
|
||||
|
||||
// Packet protection keys, CRYPTO streams, and TLS state.
|
||||
keysInitial fixedKeyPair
|
||||
@@ -136,6 +140,14 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname s
|
||||
}
|
||||
}
|
||||
|
||||
// A per-conn ChaCha8 PRNG is probably more than we need,
|
||||
// but at least it's fairly small.
|
||||
var seed [32]byte
|
||||
if _, err := cryptorand.Read(seed[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
c.prng = rand.New(rand.NewChaCha8(seed))
|
||||
|
||||
// TODO: PMTU discovery.
|
||||
c.logConnectionStarted(cids.originalDstConnID, peerAddr)
|
||||
c.keysAppData.init()
|
||||
@@ -143,6 +155,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname s
|
||||
c.streamsInit()
|
||||
c.lifetimeInit()
|
||||
c.restartIdleTimer(now)
|
||||
c.skip.init(c)
|
||||
|
||||
if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
|
||||
initialSrcConnID: c.connIDState.srcConnID(),
|
||||
|
||||
@@ -421,15 +421,10 @@ func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, sp
|
||||
func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int {
|
||||
c.loss.receiveAckStart()
|
||||
largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
|
||||
if end > c.loss.nextNumber(space) {
|
||||
// Acknowledgement of a packet we never sent.
|
||||
c.abort(now, localTransportError{
|
||||
code: errProtocolViolation,
|
||||
reason: "acknowledgement for unsent packet",
|
||||
})
|
||||
if err := c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss); err != nil {
|
||||
c.abort(now, err)
|
||||
return
|
||||
}
|
||||
c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss)
|
||||
})
|
||||
// Prior to receiving the peer's transport parameters, we cannot
|
||||
// interpret the ACK Delay field because we don't know the ack_delay_exponent
|
||||
|
||||
@@ -142,6 +142,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
|
||||
}
|
||||
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
|
||||
c.packetSent(now, appDataSpace, sent)
|
||||
if c.skip.shouldSkip(pnum + 1) {
|
||||
c.loss.skipNumber(now, appDataSpace)
|
||||
c.skip.updateNumberSkip(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,11 @@ func TestSendPacketNumberSize(t *testing.T) {
|
||||
// current packet and the max acked one is sufficiently large.
|
||||
for want := maxAcked + 1; want < maxAcked+0x100; want++ {
|
||||
p := recvPing()
|
||||
if p.num != want {
|
||||
if p.num == want+1 {
|
||||
// The conn skipped a packet number
|
||||
// (defense against optimistic ACK attacks).
|
||||
want++
|
||||
} else if p.num != want {
|
||||
t.Fatalf("received packet number %v, want %v", p.num, want)
|
||||
}
|
||||
gotPnumLen := int(p.header&0x03) + 1
|
||||
|
||||
@@ -242,9 +242,7 @@ func TestStreamsWriteQueueFairness(t *testing.T) {
|
||||
if p == nil {
|
||||
break
|
||||
}
|
||||
tc.writeFrames(packetType1RTT, debugFrameAck{
|
||||
ranges: []i64range[packetNumber]{{0, p.num}},
|
||||
})
|
||||
tc.writeAckForLatest()
|
||||
for _, f := range p.frames {
|
||||
sf, ok := f.(debugFrameStream)
|
||||
if !ok {
|
||||
|
||||
25
quic/loss.go
25
quic/loss.go
@@ -178,6 +178,15 @@ func (c *lossState) nextNumber(space numberSpace) packetNumber {
|
||||
return c.spaces[space].nextNum
|
||||
}
|
||||
|
||||
// skipPacketNumber skips a packet number as a defense against optimistic ACK attacks.
|
||||
func (c *lossState) skipNumber(now time.Time, space numberSpace) {
|
||||
sent := newSentPacket()
|
||||
sent.num = c.spaces[space].nextNum
|
||||
sent.time = now
|
||||
sent.state = sentPacketUnsent
|
||||
c.spaces[space].add(sent)
|
||||
}
|
||||
|
||||
// packetSent records a sent packet.
|
||||
func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
|
||||
sent.time = now
|
||||
@@ -230,17 +239,20 @@ func (c *lossState) receiveAckStart() {
|
||||
|
||||
// receiveAckRange processes a range within an ACK frame.
|
||||
// The ackf function is called for each newly-acknowledged packet.
|
||||
func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex int, start, end packetNumber, ackf func(numberSpace, *sentPacket, packetFate)) {
|
||||
func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex int, start, end packetNumber, ackf func(numberSpace, *sentPacket, packetFate)) error {
|
||||
// Limit our range to the intersection of the ACK range and
|
||||
// the in-flight packets we have state for.
|
||||
if s := c.spaces[space].start(); start < s {
|
||||
start = s
|
||||
}
|
||||
if e := c.spaces[space].end(); end > e {
|
||||
end = e
|
||||
return localTransportError{
|
||||
code: errProtocolViolation,
|
||||
reason: "acknowledgement for unsent packet",
|
||||
}
|
||||
}
|
||||
if start >= end {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
if rangeIndex == 0 {
|
||||
// If the latest packet in the ACK frame is newly-acked,
|
||||
@@ -252,6 +264,12 @@ func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex
|
||||
}
|
||||
for pnum := start; pnum < end; pnum++ {
|
||||
sent := c.spaces[space].num(pnum)
|
||||
if sent.state == sentPacketUnsent {
|
||||
return localTransportError{
|
||||
code: errProtocolViolation,
|
||||
reason: "acknowledgement for unsent packet",
|
||||
}
|
||||
}
|
||||
if sent.state != sentPacketSent {
|
||||
continue
|
||||
}
|
||||
@@ -266,6 +284,7 @@ func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex
|
||||
c.ackFrameContainsAckEliciting = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// receiveAckEnd finishes processing an ack frame.
|
||||
|
||||
@@ -38,9 +38,10 @@ type sentPacket struct {
|
||||
type sentPacketState uint8
|
||||
|
||||
const (
|
||||
sentPacketSent = sentPacketState(iota) // sent but neither acked nor lost
|
||||
sentPacketAcked // acked
|
||||
sentPacketLost // declared lost
|
||||
sentPacketSent = sentPacketState(iota) // sent but neither acked nor lost
|
||||
sentPacketAcked // acked
|
||||
sentPacketLost // declared lost
|
||||
sentPacketUnsent // never sent
|
||||
)
|
||||
|
||||
var sentPool = sync.Pool{
|
||||
|
||||
62
quic/skip.go
Normal file
62
quic/skip.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright 2025 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package quic
|
||||
|
||||
// skipState is state for optimistic ACK defenses.
|
||||
//
|
||||
// An endpoint performs an optimistic ACK attack by sending acknowledgements for packets
|
||||
// which it has not received, potentially convincing the sender's congestion controller to
|
||||
// send at rates beyond what the network supports.
|
||||
//
|
||||
// We defend against this by periodically skipping packet numbers.
|
||||
// Receiving an ACK for an unsent packet number is a PROTOCOL_VIOLATION error.
|
||||
//
|
||||
// We only skip packet numbers in the Application Data number space.
|
||||
// The total data sent in the Initial/Handshake spaces should generally fit into
|
||||
// the initial congestion window.
|
||||
//
|
||||
// https://www.rfc-editor.org/rfc/rfc9000.html#section-21.4
|
||||
type skipState struct {
|
||||
// skip is the next packet number (in the Application Data space) we should skip.
|
||||
skip packetNumber
|
||||
|
||||
// maxSkip is the maximum number of packets to send before skipping another number.
|
||||
// Increases over time.
|
||||
maxSkip int64
|
||||
}
|
||||
|
||||
func (ss *skipState) init(c *Conn) {
|
||||
ss.maxSkip = 256 // skip our first packet number within this range
|
||||
ss.updateNumberSkip(c)
|
||||
}
|
||||
|
||||
// shouldSkipAfter returns whether we should skip the given packet number.
|
||||
func (ss *skipState) shouldSkip(num packetNumber) bool {
|
||||
return ss.skip == num
|
||||
}
|
||||
|
||||
// updateNumberSkip schedules a packet to be skipped after skipping lastSkipped.
|
||||
func (ss *skipState) updateNumberSkip(c *Conn) {
|
||||
// Send at least this many packets before skipping.
|
||||
// Limits the impact of skipping a little,
|
||||
// plus allows most tests to ignore skipping.
|
||||
const minSkip = 64
|
||||
|
||||
skip := minSkip + c.prng.Int64N(ss.maxSkip-minSkip)
|
||||
ss.skip += packetNumber(skip)
|
||||
|
||||
// Double the size of the skip each time until we reach 128k.
|
||||
// The idea here is that an attacker needs to correctly ack ~N packets in order
|
||||
// to send an optimistic ack for another ~N packets.
|
||||
// Skipping packet numbers comes with a small cost (it causes the receiver to
|
||||
// send an immediate ACK rather than the usual delayed ACK), so we increase the
|
||||
// time between skips as a connection's lifetime grows.
|
||||
//
|
||||
// The 128k cap is arbitrary, chosen so that we skip a packet number
|
||||
// about once a second when sending full-size datagrams at 1Gbps.
|
||||
if ss.maxSkip < 128*1024 {
|
||||
ss.maxSkip *= 2
|
||||
}
|
||||
}
|
||||
81
quic/skip_test.go
Normal file
81
quic/skip_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright 2025 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package quic
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSkipPackets(t *testing.T) {
|
||||
tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters)
|
||||
connWritesPacket := func() {
|
||||
s.WriteByte(0)
|
||||
s.Flush()
|
||||
tc.wantFrameType("conn sends STREAM data",
|
||||
packetType1RTT, debugFrameStream{})
|
||||
tc.writeAckForLatest()
|
||||
tc.wantIdle("conn is idle")
|
||||
}
|
||||
connWritesPacket()
|
||||
|
||||
expectSkip:
|
||||
for maxUntilSkip := 256; maxUntilSkip <= 1024; maxUntilSkip *= 2 {
|
||||
for range maxUntilSkip + 1 {
|
||||
nextNum := tc.lastPacket.num + 1
|
||||
|
||||
connWritesPacket()
|
||||
|
||||
if tc.lastPacket.num == nextNum+1 {
|
||||
// A packet number was skipped, as expected.
|
||||
continue expectSkip
|
||||
}
|
||||
if tc.lastPacket.num != nextNum {
|
||||
t.Fatalf("got packet number %v, want %v or %v+1", tc.lastPacket.num, nextNum, nextNum)
|
||||
}
|
||||
|
||||
}
|
||||
t.Fatalf("no numbers skipped after %v packets", maxUntilSkip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipAckForSkippedPacket(t *testing.T) {
|
||||
tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters)
|
||||
|
||||
// Cause the connection to send packets until it skips a packet number.
|
||||
for {
|
||||
// Cause the connection to send a packet.
|
||||
last := tc.lastPacket
|
||||
s.WriteByte(0)
|
||||
s.Flush()
|
||||
tc.wantFrameType("conn sends STREAM data",
|
||||
packetType1RTT, debugFrameStream{})
|
||||
|
||||
if tc.lastPacket.num > 256 {
|
||||
t.Fatalf("no numbers skipped after 256 packets")
|
||||
}
|
||||
|
||||
// Acknowledge everything up to the packet before the one we just received.
|
||||
// We don't acknowledge the most-recently-received packet, because doing
|
||||
// so will cause the connection to drop state for the skipped packet number.
|
||||
// (We only retain state up to the oldest in-flight packet.)
|
||||
//
|
||||
// If the conn has skipped a packet number, then this ack will improperly
|
||||
// acknowledge the unsent packet.
|
||||
t.Log(tc.lastPacket.num)
|
||||
tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
|
||||
ranges: []i64range[packetNumber]{{0, tc.lastPacket.num}},
|
||||
})
|
||||
|
||||
if last != nil && tc.lastPacket.num == last.num+2 {
|
||||
// The connection has skipped a packet number.
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// We wrote an ACK for a skipped packet number.
|
||||
// The connection should close.
|
||||
tc.wantFrame("ACK for skipped packet causes CONNECTION_CLOSE",
|
||||
packetType1RTT, debugFrameConnectionCloseTransport{
|
||||
code: errProtocolViolation,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user