diff --git a/quic/endpoint.go b/quic/endpoint.go index bf021751..b9ababe6 100644 --- a/quic/endpoint.go +++ b/quic/endpoint.go @@ -73,6 +73,25 @@ func Listen(network, address string, listenConfig *Config) (*Endpoint, error) { return newEndpoint(pc, listenConfig, nil) } +// NewEndpoint creates an endpoint using a net.PacketConn as the underlying transport. +// +// If the PacketConn is not a *net.UDPConn, the endpoint may be slower and lack +// access to some features of the network. +func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) { + var pc packetConn + var err error + switch conn := conn.(type) { + case *net.UDPConn: + pc, err = newNetUDPConn(conn) + default: + pc, err = newNetPacketConn(conn) + } + if err != nil { + return nil, err + } + return newEndpoint(pc, config, nil) +} + func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { e := &Endpoint{ listenConfig: config, diff --git a/quic/udp_packetconn.go b/quic/udp_packetconn.go new file mode 100644 index 00000000..85ce349f --- /dev/null +++ b/quic/udp_packetconn.go @@ -0,0 +1,69 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "net" + "net/netip" +) + +// netPacketConn is a packetConn implementation wrapping a net.PacketConn. +// +// This is mostly useful for tests, since PacketConn doesn't provide access to +// important features such as identifying the local address packets were received on. +type netPacketConn struct { + c net.PacketConn + localAddr netip.AddrPort +} + +func newNetPacketConn(pc net.PacketConn) (*netPacketConn, error) { + addr, err := addrPortFromAddr(pc.LocalAddr()) + if err != nil { + return nil, err + } + return &netPacketConn{ + c: pc, + localAddr: addr, + }, nil +} + +func (c *netPacketConn) Close() error { + return c.c.Close() +} + +func (c *netPacketConn) LocalAddr() netip.AddrPort { + return c.localAddr +} + +func (c *netPacketConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, peerAddr, err := c.c.ReadFrom(dgram.b) + if err != nil { + return + } + dgram.peerAddr, err = addrPortFromAddr(peerAddr) + if err != nil { + continue + } + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netPacketConn) Write(dgram datagram) error { + _, err := c.c.WriteTo(dgram.b, net.UDPAddrFromAddrPort(dgram.peerAddr)) + return err +} + +func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) { + switch a := addr.(type) { + case *net.UDPAddr: + return a.AddrPort(), nil + } + return netip.ParseAddrPort(addr.String()) +}