diff --git a/ipv4/batch.go b/ipv4/batch.go new file mode 100644 index 00000000..b4454992 --- /dev/null +++ b/ipv4/batch.go @@ -0,0 +1,191 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.9 + +package ipv4 + +import ( + "net" + "runtime" + "syscall" + + "golang.org/x/net/internal/socket" +) + +// BUG(mikio): On Windows, the ReadBatch and WriteBatch methods of +// PacketConn are not implemented. + +// BUG(mikio): On Windows, the ReadBatch and WriteBatch methods of +// RawConn are not implemented. + +// A Message represents an IO message. +// +// type Message struct { +// Buffers [][]byte +// OOB []byte +// Addr net.Addr +// N int +// NN int +// Flags int +// } +// +// The Buffers fields represents a list of contiguous buffers, which +// can be used for vectored IO, for example, putting a header and a +// payload in each slice. +// When writing, the Buffers field must contain at least one byte to +// write. +// When reading, the Buffers field will always contain a byte to read. +// +// The OOB field contains protocol-specific control or miscellaneous +// ancillary data known as out-of-band data. +// It can be nil when not required. +// +// The Addr field specifies a destination address when writing. +// It can be nil when the underlying protocol of the endpoint uses +// connection-oriented communication. +// After a successful read, it may contain the source address on the +// received packet. +// +// The N field indicates the number of bytes read or written from/to +// Buffers. +// +// The NN field indicates the number of bytes read or written from/to +// OOB. +// +// The Flags field contains protocol-specific information on the +// received message. +type Message = socket.Message + +// ReadBatch reads a batch of messages. +// +// The provided flags is a set of platform-dependent flags, such as +// syscall.MSG_PEEK. +// +// On a successful read it returns the number of messages received, up +// to len(ms). +// +// On Linux, a batch read will be optimized. +// On other platforms, this method will read only a single message. +// +// Unlike the ReadFrom method, it doesn't strip the IPv4 header +// followed by option headers from the received IPv4 datagram when the +// underlying transport is net.IPConn. Each Buffers field of Message +// must be large enough to accommodate an IPv4 header and option +// headers. +func (c *payloadHandler) ReadBatch(ms []Message, flags int) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + switch runtime.GOOS { + case "linux": + n, err := c.RecvMsgs([]socket.Message(ms), flags) + if err != nil { + err = &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + return n, err + default: + n := 1 + err := c.RecvMsg(&ms[0], flags) + if err != nil { + n = 0 + err = &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + return n, err + } +} + +// WriteBatch writes a batch of messages. +// +// The provided flags is a set of platform-dependent flags, such as +// syscall.MSG_DONTROUTE. +// +// It returns the number of messages written on a successful write. +// +// On Linux, a batch write will be optimized. +// On other platforms, this method will write only a single message. +func (c *payloadHandler) WriteBatch(ms []Message, flags int) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + switch runtime.GOOS { + case "linux": + n, err := c.SendMsgs([]socket.Message(ms), flags) + if err != nil { + err = &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + return n, err + default: + n := 1 + err := c.SendMsg(&ms[0], flags) + if err != nil { + n = 0 + err = &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + return n, err + } +} + +// ReadBatch reads a batch of messages. +// +// The provided flags is a set of platform-dependent flags, such as +// syscall.MSG_PEEK. +// +// On a successful read it returns the number of messages received, up +// to len(ms). +// +// On Linux, a batch read will be optimized. +// On other platforms, this method will read only a single message. +func (c *packetHandler) ReadBatch(ms []Message, flags int) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + switch runtime.GOOS { + case "linux": + n, err := c.RecvMsgs([]socket.Message(ms), flags) + if err != nil { + err = &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + return n, err + default: + n := 1 + err := c.RecvMsg(&ms[0], flags) + if err != nil { + n = 0 + err = &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + return n, err + } +} + +// WriteBatch writes a batch of messages. +// +// The provided flags is a set of platform-dependent flags, such as +// syscall.MSG_DONTROUTE. +// +// It returns the number of messages written on a successful write. +// +// On Linux, a batch write will be optimized. +// On other platforms, this method will write only a single message. +func (c *packetHandler) WriteBatch(ms []Message, flags int) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + switch runtime.GOOS { + case "linux": + n, err := c.SendMsgs([]socket.Message(ms), flags) + if err != nil { + err = &net.OpError{Op: "write", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + return n, err + default: + n := 1 + err := c.SendMsg(&ms[0], flags) + if err != nil { + n = 0 + err = &net.OpError{Op: "write", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + return n, err + } +} diff --git a/ipv4/control.go b/ipv4/control.go index da4da2dd..fc99327a 100644 --- a/ipv4/control.go +++ b/ipv4/control.go @@ -8,6 +8,9 @@ import ( "fmt" "net" "sync" + + "golang.org/x/net/internal/iana" + "golang.org/x/net/internal/socket" ) type rawOpt struct { @@ -51,6 +54,77 @@ func (cm *ControlMessage) String() string { return fmt.Sprintf("ttl=%d src=%v dst=%v ifindex=%d", cm.TTL, cm.Src, cm.Dst, cm.IfIndex) } +// Marshal returns the binary encoding of cm. +func (cm *ControlMessage) Marshal() []byte { + if cm == nil { + return nil + } + var m socket.ControlMessage + if ctlOpts[ctlPacketInfo].name > 0 && (cm.Src.To4() != nil || cm.IfIndex > 0) { + m = socket.NewControlMessage([]int{ctlOpts[ctlPacketInfo].length}) + } + if len(m) > 0 { + ctlOpts[ctlPacketInfo].marshal(m, cm) + } + return m +} + +// Parse parses b as a control message and stores the result in cm. +func (cm *ControlMessage) Parse(b []byte) error { + ms, err := socket.ControlMessage(b).Parse() + if err != nil { + return err + } + for _, m := range ms { + lvl, typ, l, err := m.ParseHeader() + if err != nil { + return err + } + if lvl != iana.ProtocolIP { + continue + } + switch typ { + case ctlOpts[ctlTTL].name: + ctlOpts[ctlTTL].parse(cm, m.Data(l)) + case ctlOpts[ctlDst].name: + ctlOpts[ctlDst].parse(cm, m.Data(l)) + case ctlOpts[ctlInterface].name: + ctlOpts[ctlInterface].parse(cm, m.Data(l)) + case ctlOpts[ctlPacketInfo].name: + ctlOpts[ctlPacketInfo].parse(cm, m.Data(l)) + } + } + return nil +} + +// NewControlMessage returns a new control message. +// +// The returned message is large enough for options specified by cf. +func NewControlMessage(cf ControlFlags) []byte { + opt := rawOpt{cflags: cf} + var l int + if opt.isset(FlagTTL) && ctlOpts[ctlTTL].name > 0 { + l += socket.ControlMessageSpace(ctlOpts[ctlTTL].length) + } + if ctlOpts[ctlPacketInfo].name > 0 { + if opt.isset(FlagSrc | FlagDst | FlagInterface) { + l += socket.ControlMessageSpace(ctlOpts[ctlPacketInfo].length) + } + } else { + if opt.isset(FlagDst) && ctlOpts[ctlDst].name > 0 { + l += socket.ControlMessageSpace(ctlOpts[ctlDst].length) + } + if opt.isset(FlagInterface) && ctlOpts[ctlInterface].name > 0 { + l += socket.ControlMessageSpace(ctlOpts[ctlInterface].length) + } + } + var b []byte + if l > 0 { + b = make([]byte, l) + } + return b +} + // Ancillary data socket options const ( ctlTTL = iota // header field diff --git a/ipv4/control_bsd.go b/ipv4/control_bsd.go index 3f27f994..77e7ad5b 100644 --- a/ipv4/control_bsd.go +++ b/ipv4/control_bsd.go @@ -12,26 +12,26 @@ import ( "unsafe" "golang.org/x/net/internal/iana" + "golang.org/x/net/internal/socket" ) func marshalDst(b []byte, cm *ControlMessage) []byte { - m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0])) - m.Level = iana.ProtocolIP - m.Type = sysIP_RECVDSTADDR - m.SetLen(syscall.CmsgLen(net.IPv4len)) - return b[syscall.CmsgSpace(net.IPv4len):] + m := socket.ControlMessage(b) + m.MarshalHeader(iana.ProtocolIP, sysIP_RECVDSTADDR, net.IPv4len) + return m.Next(net.IPv4len) } func parseDst(cm *ControlMessage, b []byte) { - cm.Dst = b[:net.IPv4len] + if len(cm.Dst) < net.IPv4len { + cm.Dst = make(net.IP, net.IPv4len) + } + copy(cm.Dst, b[:net.IPv4len]) } func marshalInterface(b []byte, cm *ControlMessage) []byte { - m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0])) - m.Level = iana.ProtocolIP - m.Type = sysIP_RECVIF - m.SetLen(syscall.CmsgLen(syscall.SizeofSockaddrDatalink)) - return b[syscall.CmsgSpace(syscall.SizeofSockaddrDatalink):] + m := socket.ControlMessage(b) + m.MarshalHeader(iana.ProtocolIP, sysIP_RECVIF, syscall.SizeofSockaddrDatalink) + return m.Next(syscall.SizeofSockaddrDatalink) } func parseInterface(cm *ControlMessage, b []byte) { diff --git a/ipv4/control_pktinfo.go b/ipv4/control_pktinfo.go index 9ed97734..425338f3 100644 --- a/ipv4/control_pktinfo.go +++ b/ipv4/control_pktinfo.go @@ -7,19 +7,18 @@ package ipv4 import ( - "syscall" + "net" "unsafe" "golang.org/x/net/internal/iana" + "golang.org/x/net/internal/socket" ) func marshalPacketInfo(b []byte, cm *ControlMessage) []byte { - m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0])) - m.Level = iana.ProtocolIP - m.Type = sysIP_PKTINFO - m.SetLen(syscall.CmsgLen(sizeofInetPktinfo)) + m := socket.ControlMessage(b) + m.MarshalHeader(iana.ProtocolIP, sysIP_PKTINFO, sizeofInetPktinfo) if cm != nil { - pi := (*inetPktinfo)(unsafe.Pointer(&b[syscall.CmsgLen(0)])) + pi := (*inetPktinfo)(unsafe.Pointer(&m.Data(sizeofInetPktinfo)[0])) if ip := cm.Src.To4(); ip != nil { copy(pi.Spec_dst[:], ip) } @@ -27,11 +26,14 @@ func marshalPacketInfo(b []byte, cm *ControlMessage) []byte { pi.setIfindex(cm.IfIndex) } } - return b[syscall.CmsgSpace(sizeofInetPktinfo):] + return m.Next(sizeofInetPktinfo) } func parsePacketInfo(cm *ControlMessage, b []byte) { pi := (*inetPktinfo)(unsafe.Pointer(&b[0])) cm.IfIndex = int(pi.Ifindex) - cm.Dst = pi.Addr[:] + if len(cm.Dst) < net.IPv4len { + cm.Dst = make(net.IP, net.IPv4len) + } + copy(cm.Dst, pi.Addr[:]) } diff --git a/ipv4/control_stub.go b/ipv4/control_stub.go index de9b1a09..5a2f7d8d 100644 --- a/ipv4/control_stub.go +++ b/ipv4/control_stub.go @@ -11,15 +11,3 @@ import "golang.org/x/net/internal/socket" func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) error { return errOpNoSupport } - -func newControlMessage(opt *rawOpt) []byte { - return nil -} - -func parseControlMessage(b []byte) (*ControlMessage, error) { - return nil, errOpNoSupport -} - -func marshalControlMessage(cm *ControlMessage) []byte { - return nil -} diff --git a/ipv4/control_unix.go b/ipv4/control_unix.go index 91115201..e1ae8167 100644 --- a/ipv4/control_unix.go +++ b/ipv4/control_unix.go @@ -7,8 +7,6 @@ package ipv4 import ( - "os" - "syscall" "unsafe" "golang.org/x/net/internal/iana" @@ -64,84 +62,10 @@ func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) er return nil } -func newControlMessage(opt *rawOpt) (oob []byte) { - opt.RLock() - var l int - if opt.isset(FlagTTL) && ctlOpts[ctlTTL].name > 0 { - l += syscall.CmsgSpace(ctlOpts[ctlTTL].length) - } - if ctlOpts[ctlPacketInfo].name > 0 { - if opt.isset(FlagSrc | FlagDst | FlagInterface) { - l += syscall.CmsgSpace(ctlOpts[ctlPacketInfo].length) - } - } else { - if opt.isset(FlagDst) && ctlOpts[ctlDst].name > 0 { - l += syscall.CmsgSpace(ctlOpts[ctlDst].length) - } - if opt.isset(FlagInterface) && ctlOpts[ctlInterface].name > 0 { - l += syscall.CmsgSpace(ctlOpts[ctlInterface].length) - } - } - if l > 0 { - oob = make([]byte, l) - } - opt.RUnlock() - return -} - -func parseControlMessage(b []byte) (*ControlMessage, error) { - if len(b) == 0 { - return nil, nil - } - cmsgs, err := syscall.ParseSocketControlMessage(b) - if err != nil { - return nil, os.NewSyscallError("parse socket control message", err) - } - cm := &ControlMessage{} - for _, m := range cmsgs { - if m.Header.Level != iana.ProtocolIP { - continue - } - switch int(m.Header.Type) { - case ctlOpts[ctlTTL].name: - ctlOpts[ctlTTL].parse(cm, m.Data[:]) - case ctlOpts[ctlDst].name: - ctlOpts[ctlDst].parse(cm, m.Data[:]) - case ctlOpts[ctlInterface].name: - ctlOpts[ctlInterface].parse(cm, m.Data[:]) - case ctlOpts[ctlPacketInfo].name: - ctlOpts[ctlPacketInfo].parse(cm, m.Data[:]) - } - } - return cm, nil -} - -func marshalControlMessage(cm *ControlMessage) (oob []byte) { - if cm == nil { - return nil - } - var l int - pktinfo := false - if ctlOpts[ctlPacketInfo].name > 0 && (cm.Src.To4() != nil || cm.IfIndex > 0) { - pktinfo = true - l += syscall.CmsgSpace(ctlOpts[ctlPacketInfo].length) - } - if l > 0 { - oob = make([]byte, l) - b := oob - if pktinfo { - b = ctlOpts[ctlPacketInfo].marshal(b, cm) - } - } - return -} - func marshalTTL(b []byte, cm *ControlMessage) []byte { - m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0])) - m.Level = iana.ProtocolIP - m.Type = sysIP_RECVTTL - m.SetLen(syscall.CmsgLen(1)) - return b[syscall.CmsgSpace(1):] + m := socket.ControlMessage(b) + m.MarshalHeader(iana.ProtocolIP, sysIP_RECVTTL, 1) + return m.Next(1) } func parseTTL(cm *ControlMessage, b []byte) { diff --git a/ipv4/control_windows.go b/ipv4/control_windows.go index 5560fcf2..ce55c664 100644 --- a/ipv4/control_windows.go +++ b/ipv4/control_windows.go @@ -14,18 +14,3 @@ func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) er // TODO(mikio): implement this return syscall.EWINDOWS } - -func newControlMessage(opt *rawOpt) []byte { - // TODO(mikio): implement this - return nil -} - -func parseControlMessage(b []byte) (*ControlMessage, error) { - // TODO(mikio): implement this - return nil, syscall.EWINDOWS -} - -func marshalControlMessage(cm *ControlMessage) []byte { - // TODO(mikio): implement this - return nil -} diff --git a/ipv4/endpoint.go b/ipv4/endpoint.go index f173ed40..2ab87736 100644 --- a/ipv4/endpoint.go +++ b/ipv4/endpoint.go @@ -105,12 +105,7 @@ func NewPacketConn(c net.PacketConn) *PacketConn { p := &PacketConn{ genericOpt: genericOpt{Conn: cc}, dgramOpt: dgramOpt{Conn: cc}, - payloadHandler: payloadHandler{PacketConn: c}, - } - if _, ok := c.(*net.IPConn); ok { - if so, ok := sockOpts[ssoStripHeader]; ok { - so.SetInt(p.dgramOpt.Conn, boolint(true)) - } + payloadHandler: payloadHandler{PacketConn: c, Conn: cc}, } return p } @@ -140,7 +135,7 @@ func (c *RawConn) SetDeadline(t time.Time) error { if !c.packetHandler.ok() { return syscall.EINVAL } - return c.packetHandler.c.SetDeadline(t) + return c.packetHandler.IPConn.SetDeadline(t) } // SetReadDeadline sets the read deadline associated with the @@ -149,7 +144,7 @@ func (c *RawConn) SetReadDeadline(t time.Time) error { if !c.packetHandler.ok() { return syscall.EINVAL } - return c.packetHandler.c.SetReadDeadline(t) + return c.packetHandler.IPConn.SetReadDeadline(t) } // SetWriteDeadline sets the write deadline associated with the @@ -158,7 +153,7 @@ func (c *RawConn) SetWriteDeadline(t time.Time) error { if !c.packetHandler.ok() { return syscall.EINVAL } - return c.packetHandler.c.SetWriteDeadline(t) + return c.packetHandler.IPConn.SetWriteDeadline(t) } // Close closes the endpoint. @@ -166,7 +161,7 @@ func (c *RawConn) Close() error { if !c.packetHandler.ok() { return syscall.EINVAL } - return c.packetHandler.c.Close() + return c.packetHandler.IPConn.Close() } // NewRawConn returns a new RawConn using c as its underlying @@ -179,7 +174,7 @@ func NewRawConn(c net.PacketConn) (*RawConn, error) { r := &RawConn{ genericOpt: genericOpt{Conn: cc}, dgramOpt: dgramOpt{Conn: cc}, - packetHandler: packetHandler{c: c.(*net.IPConn)}, + packetHandler: packetHandler{IPConn: c.(*net.IPConn), Conn: cc}, } so, ok := sockOpts[ssoHeaderPrepend] if !ok { diff --git a/ipv4/header.go b/ipv4/header.go index 6480597f..8bb0f0f4 100644 --- a/ipv4/header.go +++ b/ipv4/header.go @@ -51,7 +51,7 @@ func (h *Header) String() string { return fmt.Sprintf("ver=%d hdrlen=%d tos=%#x totallen=%d id=%#x flags=%#x fragoff=%#x ttl=%d proto=%d cksum=%#x src=%v dst=%v", h.Version, h.Len, h.TOS, h.TotalLen, h.ID, h.Flags, h.FragOff, h.TTL, h.Protocol, h.Checksum, h.Src, h.Dst) } -// Marshal returns the binary encoding of the IPv4 header h. +// Marshal returns the binary encoding of h. func (h *Header) Marshal() ([]byte, error) { if h == nil { return nil, syscall.EINVAL @@ -98,26 +98,24 @@ func (h *Header) Marshal() ([]byte, error) { return b, nil } -// ParseHeader parses b as an IPv4 header. -func ParseHeader(b []byte) (*Header, error) { - if len(b) < HeaderLen { - return nil, errHeaderTooShort +// Parse parses b as an IPv4 header and sotres the result in h. +func (h *Header) Parse(b []byte) error { + if h == nil || len(b) < HeaderLen { + return errHeaderTooShort } hdrlen := int(b[0]&0x0f) << 2 if hdrlen > len(b) { - return nil, errBufferTooShort - } - h := &Header{ - Version: int(b[0] >> 4), - Len: hdrlen, - TOS: int(b[1]), - ID: int(binary.BigEndian.Uint16(b[4:6])), - TTL: int(b[8]), - Protocol: int(b[9]), - Checksum: int(binary.BigEndian.Uint16(b[10:12])), - Src: net.IPv4(b[12], b[13], b[14], b[15]), - Dst: net.IPv4(b[16], b[17], b[18], b[19]), + return errBufferTooShort } + h.Version = int(b[0] >> 4) + h.Len = hdrlen + h.TOS = int(b[1]) + h.ID = int(binary.BigEndian.Uint16(b[4:6])) + h.TTL = int(b[8]) + h.Protocol = int(b[9]) + h.Checksum = int(binary.BigEndian.Uint16(b[10:12])) + h.Src = net.IPv4(b[12], b[13], b[14], b[15]) + h.Dst = net.IPv4(b[16], b[17], b[18], b[19]) switch runtime.GOOS { case "darwin", "dragonfly", "netbsd": h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4])) + hdrlen @@ -139,9 +137,23 @@ func ParseHeader(b []byte) (*Header, error) { } h.Flags = HeaderFlags(h.FragOff&0xe000) >> 13 h.FragOff = h.FragOff & 0x1fff - if hdrlen-HeaderLen > 0 { - h.Options = make([]byte, hdrlen-HeaderLen) - copy(h.Options, b[HeaderLen:]) + optlen := hdrlen - HeaderLen + if optlen > 0 && len(b) >= hdrlen { + if cap(h.Options) < optlen { + h.Options = make([]byte, optlen) + } else { + h.Options = h.Options[:optlen] + } + copy(h.Options, b[HeaderLen:hdrlen]) + } + return nil +} + +// ParseHeader parses b as an IPv4 header. +func ParseHeader(b []byte) (*Header, error) { + h := new(Header) + if err := h.Parse(b); err != nil { + return nil, err } return h, nil } diff --git a/ipv4/header_test.go b/ipv4/header_test.go index 8dd6fc60..a246aeea 100644 --- a/ipv4/header_test.go +++ b/ipv4/header_test.go @@ -17,138 +17,212 @@ import ( ) type headerTest struct { - wireHeaderFromKernel [HeaderLen]byte - wireHeaderToKernel [HeaderLen]byte - wireHeaderFromTradBSDKernel [HeaderLen]byte - wireHeaderToTradBSDKernel [HeaderLen]byte - wireHeaderFromFreeBSD10Kernel [HeaderLen]byte - wireHeaderToFreeBSD10Kernel [HeaderLen]byte + wireHeaderFromKernel []byte + wireHeaderToKernel []byte + wireHeaderFromTradBSDKernel []byte + wireHeaderToTradBSDKernel []byte + wireHeaderFromFreeBSD10Kernel []byte + wireHeaderToFreeBSD10Kernel []byte *Header } -var headerLittleEndianTest = headerTest{ +var headerLittleEndianTests = []headerTest{ // TODO(mikio): Add platform dependent wire header formats when // we support new platforms. - wireHeaderFromKernel: [HeaderLen]byte{ - 0x45, 0x01, 0xbe, 0xef, - 0xca, 0xfe, 0x45, 0xdc, - 0xff, 0x01, 0xde, 0xad, - 172, 16, 254, 254, - 192, 168, 0, 1, + { + wireHeaderFromKernel: []byte{ + 0x45, 0x01, 0xbe, 0xef, + 0xca, 0xfe, 0x45, 0xdc, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + }, + wireHeaderToKernel: []byte{ + 0x45, 0x01, 0xbe, 0xef, + 0xca, 0xfe, 0x45, 0xdc, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + }, + wireHeaderFromTradBSDKernel: []byte{ + 0x45, 0x01, 0xdb, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + }, + wireHeaderToTradBSDKernel: []byte{ + 0x45, 0x01, 0xef, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + }, + wireHeaderFromFreeBSD10Kernel: []byte{ + 0x45, 0x01, 0xef, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + }, + wireHeaderToFreeBSD10Kernel: []byte{ + 0x45, 0x01, 0xef, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + }, + Header: &Header{ + Version: Version, + Len: HeaderLen, + TOS: 1, + TotalLen: 0xbeef, + ID: 0xcafe, + Flags: DontFragment, + FragOff: 1500, + TTL: 255, + Protocol: 1, + Checksum: 0xdead, + Src: net.IPv4(172, 16, 254, 254), + Dst: net.IPv4(192, 168, 0, 1), + }, }, - wireHeaderToKernel: [HeaderLen]byte{ - 0x45, 0x01, 0xbe, 0xef, - 0xca, 0xfe, 0x45, 0xdc, - 0xff, 0x01, 0xde, 0xad, - 172, 16, 254, 254, - 192, 168, 0, 1, - }, - wireHeaderFromTradBSDKernel: [HeaderLen]byte{ - 0x45, 0x01, 0xdb, 0xbe, - 0xca, 0xfe, 0xdc, 0x45, - 0xff, 0x01, 0xde, 0xad, - 172, 16, 254, 254, - 192, 168, 0, 1, - }, - wireHeaderToTradBSDKernel: [HeaderLen]byte{ - 0x45, 0x01, 0xef, 0xbe, - 0xca, 0xfe, 0xdc, 0x45, - 0xff, 0x01, 0xde, 0xad, - 172, 16, 254, 254, - 192, 168, 0, 1, - }, - wireHeaderFromFreeBSD10Kernel: [HeaderLen]byte{ - 0x45, 0x01, 0xef, 0xbe, - 0xca, 0xfe, 0xdc, 0x45, - 0xff, 0x01, 0xde, 0xad, - 172, 16, 254, 254, - 192, 168, 0, 1, - }, - wireHeaderToFreeBSD10Kernel: [HeaderLen]byte{ - 0x45, 0x01, 0xef, 0xbe, - 0xca, 0xfe, 0xdc, 0x45, - 0xff, 0x01, 0xde, 0xad, - 172, 16, 254, 254, - 192, 168, 0, 1, - }, - Header: &Header{ - Version: Version, - Len: HeaderLen, - TOS: 1, - TotalLen: 0xbeef, - ID: 0xcafe, - Flags: DontFragment, - FragOff: 1500, - TTL: 255, - Protocol: 1, - Checksum: 0xdead, - Src: net.IPv4(172, 16, 254, 254), - Dst: net.IPv4(192, 168, 0, 1), + + // with option headers + { + wireHeaderFromKernel: []byte{ + 0x46, 0x01, 0xbe, 0xf3, + 0xca, 0xfe, 0x45, 0xdc, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + 0xff, 0xfe, 0xfe, 0xff, + }, + wireHeaderToKernel: []byte{ + 0x46, 0x01, 0xbe, 0xf3, + 0xca, 0xfe, 0x45, 0xdc, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + 0xff, 0xfe, 0xfe, 0xff, + }, + wireHeaderFromTradBSDKernel: []byte{ + 0x46, 0x01, 0xdb, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + 0xff, 0xfe, 0xfe, 0xff, + }, + wireHeaderToTradBSDKernel: []byte{ + 0x46, 0x01, 0xf3, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + 0xff, 0xfe, 0xfe, 0xff, + }, + wireHeaderFromFreeBSD10Kernel: []byte{ + 0x46, 0x01, 0xf3, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + 0xff, 0xfe, 0xfe, 0xff, + }, + wireHeaderToFreeBSD10Kernel: []byte{ + 0x46, 0x01, 0xf3, 0xbe, + 0xca, 0xfe, 0xdc, 0x45, + 0xff, 0x01, 0xde, 0xad, + 172, 16, 254, 254, + 192, 168, 0, 1, + 0xff, 0xfe, 0xfe, 0xff, + }, + Header: &Header{ + Version: Version, + Len: HeaderLen + 4, + TOS: 1, + TotalLen: 0xbef3, + ID: 0xcafe, + Flags: DontFragment, + FragOff: 1500, + TTL: 255, + Protocol: 1, + Checksum: 0xdead, + Src: net.IPv4(172, 16, 254, 254), + Dst: net.IPv4(192, 168, 0, 1), + Options: []byte{0xff, 0xfe, 0xfe, 0xff}, + }, }, } func TestMarshalHeader(t *testing.T) { - tt := &headerLittleEndianTest if socket.NativeEndian != binary.LittleEndian { t.Skip("no test for non-little endian machine yet") } - b, err := tt.Header.Marshal() - if err != nil { - t.Fatal(err) - } - var wh []byte - switch runtime.GOOS { - case "darwin", "dragonfly", "netbsd": - wh = tt.wireHeaderToTradBSDKernel[:] - case "freebsd": - switch { - case freebsdVersion < 1000000: - wh = tt.wireHeaderToTradBSDKernel[:] - case 1000000 <= freebsdVersion && freebsdVersion < 1100000: - wh = tt.wireHeaderToFreeBSD10Kernel[:] - default: - wh = tt.wireHeaderToKernel[:] + for _, tt := range headerLittleEndianTests { + b, err := tt.Header.Marshal() + if err != nil { + t.Fatal(err) + } + var wh []byte + switch runtime.GOOS { + case "darwin", "dragonfly", "netbsd": + wh = tt.wireHeaderToTradBSDKernel + case "freebsd": + switch { + case freebsdVersion < 1000000: + wh = tt.wireHeaderToTradBSDKernel + case 1000000 <= freebsdVersion && freebsdVersion < 1100000: + wh = tt.wireHeaderToFreeBSD10Kernel + default: + wh = tt.wireHeaderToKernel + } + default: + wh = tt.wireHeaderToKernel + } + if !bytes.Equal(b, wh) { + t.Fatalf("got %#v; want %#v", b, wh) } - default: - wh = tt.wireHeaderToKernel[:] - } - if !bytes.Equal(b, wh) { - t.Fatalf("got %#v; want %#v", b, wh) } } func TestParseHeader(t *testing.T) { - tt := &headerLittleEndianTest if socket.NativeEndian != binary.LittleEndian { t.Skip("no test for big endian machine yet") } - var wh []byte - switch runtime.GOOS { - case "darwin", "dragonfly", "netbsd": - wh = tt.wireHeaderFromTradBSDKernel[:] - case "freebsd": - switch { - case freebsdVersion < 1000000: - wh = tt.wireHeaderFromTradBSDKernel[:] - case 1000000 <= freebsdVersion && freebsdVersion < 1100000: - wh = tt.wireHeaderFromFreeBSD10Kernel[:] + for _, tt := range headerLittleEndianTests { + var wh []byte + switch runtime.GOOS { + case "darwin", "dragonfly", "netbsd": + wh = tt.wireHeaderFromTradBSDKernel + case "freebsd": + switch { + case freebsdVersion < 1000000: + wh = tt.wireHeaderFromTradBSDKernel + case 1000000 <= freebsdVersion && freebsdVersion < 1100000: + wh = tt.wireHeaderFromFreeBSD10Kernel + default: + wh = tt.wireHeaderFromKernel + } default: - wh = tt.wireHeaderFromKernel[:] + wh = tt.wireHeaderFromKernel + } + h, err := ParseHeader(wh) + if err != nil { + t.Fatal(err) + } + if err := h.Parse(wh); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(h, tt.Header) { + t.Fatalf("got %#v; want %#v", h, tt.Header) + } + s := h.String() + if strings.Contains(s, ",") { + t.Fatalf("should be space-separated values: %s", s) } - default: - wh = tt.wireHeaderFromKernel[:] - } - h, err := ParseHeader(wh) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(h, tt.Header) { - t.Fatalf("got %#v; want %#v", h, tt.Header) - } - s := h.String() - if strings.Contains(s, ",") { - t.Fatalf("should be space-separated values: %s", s) } } diff --git a/ipv4/packet.go b/ipv4/packet.go index d43723ca..f00f5b05 100644 --- a/ipv4/packet.go +++ b/ipv4/packet.go @@ -7,6 +7,8 @@ package ipv4 import ( "net" "syscall" + + "golang.org/x/net/internal/socket" ) // BUG(mikio): On Windows, the ReadFrom and WriteTo methods of RawConn @@ -14,11 +16,12 @@ import ( // A packetHandler represents the IPv4 datagram handler. type packetHandler struct { - c *net.IPConn + *net.IPConn + *socket.Conn rawOpt } -func (c *packetHandler) ok() bool { return c != nil && c.c != nil } +func (c *packetHandler) ok() bool { return c != nil && c.IPConn != nil && c.Conn != nil } // ReadFrom reads an IPv4 datagram from the endpoint c, copying the // datagram into b. It returns the received datagram as the IPv4 @@ -27,25 +30,7 @@ func (c *packetHandler) ReadFrom(b []byte) (h *Header, p []byte, cm *ControlMess if !c.ok() { return nil, nil, nil, syscall.EINVAL } - oob := newControlMessage(&c.rawOpt) - n, oobn, _, src, err := c.c.ReadMsgIP(b, oob) - if err != nil { - return nil, nil, nil, err - } - var hs []byte - if hs, p, err = slicePacket(b[:n]); err != nil { - return nil, nil, nil, err - } - if h, err = ParseHeader(hs); err != nil { - return nil, nil, nil, err - } - if cm, err = parseControlMessage(oob[:oobn]); err != nil { - return nil, nil, nil, err - } - if src != nil && cm != nil { - cm.Src = src.IP - } - return + return c.readFrom(b) } func slicePacket(b []byte) (h, p []byte, err error) { @@ -80,21 +65,5 @@ func (c *packetHandler) WriteTo(h *Header, p []byte, cm *ControlMessage) error { if !c.ok() { return syscall.EINVAL } - oob := marshalControlMessage(cm) - wh, err := h.Marshal() - if err != nil { - return err - } - dst := &net.IPAddr{} - if cm != nil { - if ip := cm.Dst.To4(); ip != nil { - dst.IP = ip - } - } - if dst.IP == nil { - dst.IP = h.Dst - } - wh = append(wh, p...) - _, _, err = c.c.WriteMsgIP(wh, oob, dst) - return err + return c.writeTo(h, p, cm) } diff --git a/ipv4/packet_go1_8.go b/ipv4/packet_go1_8.go new file mode 100644 index 00000000..b47d1868 --- /dev/null +++ b/ipv4/packet_go1_8.go @@ -0,0 +1,56 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.9 + +package ipv4 + +import "net" + +func (c *packetHandler) readFrom(b []byte) (h *Header, p []byte, cm *ControlMessage, err error) { + c.rawOpt.RLock() + oob := NewControlMessage(c.rawOpt.cflags) + c.rawOpt.RUnlock() + n, nn, _, src, err := c.ReadMsgIP(b, oob) + if err != nil { + return nil, nil, nil, err + } + var hs []byte + if hs, p, err = slicePacket(b[:n]); err != nil { + return nil, nil, nil, err + } + if h, err = ParseHeader(hs); err != nil { + return nil, nil, nil, err + } + if nn > 0 { + cm = new(ControlMessage) + if err := cm.Parse(oob[:nn]); err != nil { + return nil, nil, nil, err + } + } + if src != nil && cm != nil { + cm.Src = src.IP + } + return +} + +func (c *packetHandler) writeTo(h *Header, p []byte, cm *ControlMessage) error { + oob := cm.Marshal() + wh, err := h.Marshal() + if err != nil { + return err + } + dst := new(net.IPAddr) + if cm != nil { + if ip := cm.Dst.To4(); ip != nil { + dst.IP = ip + } + } + if dst.IP == nil { + dst.IP = h.Dst + } + wh = append(wh, p...) + _, _, err = c.WriteMsgIP(wh, oob, dst) + return err +} diff --git a/ipv4/packet_go1_9.go b/ipv4/packet_go1_9.go new file mode 100644 index 00000000..285fdb0e --- /dev/null +++ b/ipv4/packet_go1_9.go @@ -0,0 +1,67 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.9 + +package ipv4 + +import ( + "net" + + "golang.org/x/net/internal/socket" +) + +func (c *packetHandler) readFrom(b []byte) (h *Header, p []byte, cm *ControlMessage, err error) { + c.rawOpt.RLock() + m := socket.Message{ + Buffers: [][]byte{b}, + OOB: NewControlMessage(c.rawOpt.cflags), + } + c.rawOpt.RUnlock() + if err := c.RecvMsg(&m, 0); err != nil { + return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + var hs []byte + if hs, p, err = slicePacket(b[:m.N]); err != nil { + return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + if h, err = ParseHeader(hs); err != nil { + return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + if m.NN > 0 { + cm = new(ControlMessage) + if err := cm.Parse(m.OOB[:m.NN]); err != nil { + return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + } + if src, ok := m.Addr.(*net.IPAddr); ok && cm != nil { + cm.Src = src.IP + } + return +} + +func (c *packetHandler) writeTo(h *Header, p []byte, cm *ControlMessage) error { + m := socket.Message{ + OOB: cm.Marshal(), + } + wh, err := h.Marshal() + if err != nil { + return err + } + m.Buffers = [][]byte{wh, p} + dst := new(net.IPAddr) + if cm != nil { + if ip := cm.Dst.To4(); ip != nil { + dst.IP = ip + } + } + if dst.IP == nil { + dst.IP = h.Dst + } + m.Addr = dst + if err := c.SendMsg(&m, 0); err != nil { + return &net.OpError{Op: "write", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err} + } + return nil +} diff --git a/ipv4/payload.go b/ipv4/payload.go index be130e42..f95f811a 100644 --- a/ipv4/payload.go +++ b/ipv4/payload.go @@ -4,7 +4,11 @@ package ipv4 -import "net" +import ( + "net" + + "golang.org/x/net/internal/socket" +) // BUG(mikio): On Windows, the ControlMessage for ReadFrom and WriteTo // methods of PacketConn is not implemented. @@ -12,7 +16,8 @@ import "net" // A payloadHandler represents the IPv4 datagram payload handler. type payloadHandler struct { net.PacketConn + *socket.Conn rawOpt } -func (c *payloadHandler) ok() bool { return c != nil && c.PacketConn != nil } +func (c *payloadHandler) ok() bool { return c != nil && c.PacketConn != nil && c.Conn != nil } diff --git a/ipv4/payload_cmsg.go b/ipv4/payload_cmsg.go index 9a155d25..3f06d760 100644 --- a/ipv4/payload_cmsg.go +++ b/ipv4/payload_cmsg.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !plan9,!windows +// +build !nacl,!plan9,!windows package ipv4 @@ -19,37 +19,7 @@ func (c *payloadHandler) ReadFrom(b []byte) (n int, cm *ControlMessage, src net. if !c.ok() { return 0, nil, nil, syscall.EINVAL } - oob := newControlMessage(&c.rawOpt) - var oobn int - switch c := c.PacketConn.(type) { - case *net.UDPConn: - if n, oobn, _, src, err = c.ReadMsgUDP(b, oob); err != nil { - return 0, nil, nil, err - } - case *net.IPConn: - if _, ok := sockOpts[ssoStripHeader]; ok { - if n, oobn, _, src, err = c.ReadMsgIP(b, oob); err != nil { - return 0, nil, nil, err - } - } else { - nb := make([]byte, maxHeaderLen+len(b)) - if n, oobn, _, src, err = c.ReadMsgIP(nb, oob); err != nil { - return 0, nil, nil, err - } - hdrlen := int(nb[0]&0x0f) << 2 - copy(b, nb[hdrlen:]) - n -= hdrlen - } - default: - return 0, nil, nil, errInvalidConnType - } - if cm, err = parseControlMessage(oob[:oobn]); err != nil { - return 0, nil, nil, err - } - if cm != nil { - cm.Src = netAddrToIP4(src) - } - return + return c.readFrom(b) } // WriteTo writes a payload of the IPv4 datagram, to the destination @@ -62,20 +32,5 @@ func (c *payloadHandler) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n if !c.ok() { return 0, syscall.EINVAL } - oob := marshalControlMessage(cm) - if dst == nil { - return 0, errMissingAddress - } - switch c := c.PacketConn.(type) { - case *net.UDPConn: - n, _, err = c.WriteMsgUDP(b, oob, dst.(*net.UDPAddr)) - case *net.IPConn: - n, _, err = c.WriteMsgIP(b, oob, dst.(*net.IPAddr)) - default: - return 0, errInvalidConnType - } - if err != nil { - return 0, err - } - return + return c.writeTo(b, cm, dst) } diff --git a/ipv4/payload_cmsg_go1_8.go b/ipv4/payload_cmsg_go1_8.go new file mode 100644 index 00000000..0a9c33a1 --- /dev/null +++ b/ipv4/payload_cmsg_go1_8.go @@ -0,0 +1,59 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.9 +// +build !nacl,!plan9,!windows + +package ipv4 + +import "net" + +func (c *payloadHandler) readFrom(b []byte) (n int, cm *ControlMessage, src net.Addr, err error) { + c.rawOpt.RLock() + oob := NewControlMessage(c.rawOpt.cflags) + c.rawOpt.RUnlock() + var nn int + switch c := c.PacketConn.(type) { + case *net.UDPConn: + if n, nn, _, src, err = c.ReadMsgUDP(b, oob); err != nil { + return 0, nil, nil, err + } + case *net.IPConn: + nb := make([]byte, maxHeaderLen+len(b)) + if n, nn, _, src, err = c.ReadMsgIP(nb, oob); err != nil { + return 0, nil, nil, err + } + hdrlen := int(nb[0]&0x0f) << 2 + copy(b, nb[hdrlen:]) + n -= hdrlen + default: + return 0, nil, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Err: errInvalidConnType} + } + if nn > 0 { + cm = new(ControlMessage) + if err = cm.Parse(oob[:nn]); err != nil { + return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + } + if cm != nil { + cm.Src = netAddrToIP4(src) + } + return +} + +func (c *payloadHandler) writeTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) { + oob := cm.Marshal() + if dst == nil { + return 0, &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: errMissingAddress} + } + switch c := c.PacketConn.(type) { + case *net.UDPConn: + n, _, err = c.WriteMsgUDP(b, oob, dst.(*net.UDPAddr)) + case *net.IPConn: + n, _, err = c.WriteMsgIP(b, oob, dst.(*net.IPAddr)) + default: + return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Err: errInvalidConnType} + } + return +} diff --git a/ipv4/payload_cmsg_go1_9.go b/ipv4/payload_cmsg_go1_9.go new file mode 100644 index 00000000..e697f35f --- /dev/null +++ b/ipv4/payload_cmsg_go1_9.go @@ -0,0 +1,67 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.9 +// +build !nacl,!plan9,!windows + +package ipv4 + +import ( + "net" + + "golang.org/x/net/internal/socket" +) + +func (c *payloadHandler) readFrom(b []byte) (int, *ControlMessage, net.Addr, error) { + c.rawOpt.RLock() + m := socket.Message{ + OOB: NewControlMessage(c.rawOpt.cflags), + } + c.rawOpt.RUnlock() + switch c.PacketConn.(type) { + case *net.UDPConn: + m.Buffers = [][]byte{b} + if err := c.RecvMsg(&m, 0); err != nil { + return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + case *net.IPConn: + h := make([]byte, HeaderLen) + m.Buffers = [][]byte{h, b} + if err := c.RecvMsg(&m, 0); err != nil { + return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + hdrlen := int(h[0]&0x0f) << 2 + if hdrlen > len(h) { + d := hdrlen - len(h) + copy(b, b[d:]) + m.N -= d + } else { + m.N -= hdrlen + } + default: + return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: errInvalidConnType} + } + var cm *ControlMessage + if m.NN > 0 { + cm = new(ControlMessage) + if err := cm.Parse(m.OOB[:m.NN]); err != nil { + return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + cm.Src = netAddrToIP4(m.Addr) + } + return m.N, cm, m.Addr, nil +} + +func (c *payloadHandler) writeTo(b []byte, cm *ControlMessage, dst net.Addr) (int, error) { + m := socket.Message{ + Buffers: [][]byte{b}, + OOB: cm.Marshal(), + Addr: dst, + } + err := c.SendMsg(&m, 0) + if err != nil { + err = &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err} + } + return m.N, err +} diff --git a/ipv4/payload_nocmsg.go b/ipv4/payload_nocmsg.go index 6f9d5b0e..3926de70 100644 --- a/ipv4/payload_nocmsg.go +++ b/ipv4/payload_nocmsg.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build plan9 windows +// +build nacl plan9 windows package ipv4 diff --git a/ipv4/readwrite_go1_8_test.go b/ipv4/readwrite_go1_8_test.go new file mode 100644 index 00000000..1cd926e7 --- /dev/null +++ b/ipv4/readwrite_go1_8_test.go @@ -0,0 +1,248 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.9 + +package ipv4_test + +import ( + "bytes" + "fmt" + "net" + "runtime" + "strings" + "sync" + "testing" + + "golang.org/x/net/internal/iana" + "golang.org/x/net/internal/nettest" + "golang.org/x/net/ipv4" +) + +func BenchmarkPacketConnReadWriteUnicast(b *testing.B) { + switch runtime.GOOS { + case "nacl", "plan9", "windows": + b.Skipf("not supported on %s", runtime.GOOS) + } + + payload := []byte("HELLO-R-U-THERE") + iph, err := (&ipv4.Header{ + Version: ipv4.Version, + Len: ipv4.HeaderLen, + TotalLen: ipv4.HeaderLen + len(payload), + TTL: 1, + Protocol: iana.ProtocolReserved, + Src: net.IPv4(192, 0, 2, 1), + Dst: net.IPv4(192, 0, 2, 254), + }).Marshal() + if err != nil { + b.Fatal(err) + } + greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00} + datagram := append(greh, append(iph, payload...)...) + bb := make([]byte, 128) + cm := ipv4.ControlMessage{ + Src: net.IPv4(127, 0, 0, 1), + } + if ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback); ifi != nil { + cm.IfIndex = ifi.Index + } + + b.Run("UDP", func(b *testing.B) { + c, err := nettest.NewLocalPacketListener("udp4") + if err != nil { + b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + dst := c.LocalAddr() + cf := ipv4.FlagTTL | ipv4.FlagInterface + if err := p.SetControlMessage(cf, true); err != nil { + b.Fatal(err) + } + b.Run("Net", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := c.WriteTo(payload, dst); err != nil { + b.Fatal(err) + } + if _, _, err := c.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + b.Run("ToFrom", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := p.WriteTo(payload, &cm, dst); err != nil { + b.Fatal(err) + } + if _, _, _, err := p.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + }) + b.Run("IP", func(b *testing.B) { + switch runtime.GOOS { + case "netbsd": + b.Skip("need to configure gre on netbsd") + case "openbsd": + b.Skip("net.inet.gre.allow=0 by default on openbsd") + } + + c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1") + if err != nil { + b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + dst := c.LocalAddr() + cf := ipv4.FlagTTL | ipv4.FlagInterface + if err := p.SetControlMessage(cf, true); err != nil { + b.Fatal(err) + } + b.Run("Net", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := c.WriteTo(datagram, dst); err != nil { + b.Fatal(err) + } + if _, _, err := c.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + b.Run("ToFrom", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := p.WriteTo(datagram, &cm, dst); err != nil { + b.Fatal(err) + } + if _, _, _, err := p.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + }) +} + +func TestPacketConnConcurrentReadWriteUnicast(t *testing.T) { + switch runtime.GOOS { + case "nacl", "plan9", "windows": + t.Skipf("not supported on %s", runtime.GOOS) + } + + payload := []byte("HELLO-R-U-THERE") + iph, err := (&ipv4.Header{ + Version: ipv4.Version, + Len: ipv4.HeaderLen, + TotalLen: ipv4.HeaderLen + len(payload), + TTL: 1, + Protocol: iana.ProtocolReserved, + Src: net.IPv4(192, 0, 2, 1), + Dst: net.IPv4(192, 0, 2, 254), + }).Marshal() + if err != nil { + t.Fatal(err) + } + greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00} + datagram := append(greh, append(iph, payload...)...) + + t.Run("UDP", func(t *testing.T) { + c, err := nettest.NewLocalPacketListener("udp4") + if err != nil { + t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + t.Run("ToFrom", func(t *testing.T) { + testPacketConnConcurrentReadWriteUnicast(t, p, payload, c.LocalAddr()) + }) + }) + t.Run("IP", func(t *testing.T) { + switch runtime.GOOS { + case "netbsd": + t.Skip("need to configure gre on netbsd") + case "openbsd": + t.Skip("net.inet.gre.allow=0 by default on openbsd") + } + + c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1") + if err != nil { + t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + t.Run("ToFrom", func(t *testing.T) { + testPacketConnConcurrentReadWriteUnicast(t, p, datagram, c.LocalAddr()) + }) + }) +} + +func testPacketConnConcurrentReadWriteUnicast(t *testing.T, p *ipv4.PacketConn, data []byte, dst net.Addr) { + ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback) + cf := ipv4.FlagTTL | ipv4.FlagSrc | ipv4.FlagDst | ipv4.FlagInterface + + if err := p.SetControlMessage(cf, true); err != nil { // probe before test + if nettest.ProtocolNotSupported(err) { + t.Skipf("not supported on %s", runtime.GOOS) + } + t.Fatal(err) + } + + var wg sync.WaitGroup + reader := func() { + defer wg.Done() + b := make([]byte, 128) + n, cm, _, err := p.ReadFrom(b) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(b[:n], data) { + t.Errorf("got %#v; want %#v", b[:n], data) + return + } + s := cm.String() + if strings.Contains(s, ",") { + t.Errorf("should be space-separated values: %s", s) + return + } + } + writer := func(toggle bool) { + defer wg.Done() + cm := ipv4.ControlMessage{ + Src: net.IPv4(127, 0, 0, 1), + } + if ifi != nil { + cm.IfIndex = ifi.Index + } + if err := p.SetControlMessage(cf, toggle); err != nil { + t.Error(err) + return + } + n, err := p.WriteTo(data, &cm, dst) + if err != nil { + t.Error(err) + return + } + if n != len(data) { + t.Errorf("got %d; want %d", n, len(data)) + return + } + } + + const N = 10 + wg.Add(N) + for i := 0; i < N; i++ { + go reader() + } + wg.Add(2 * N) + for i := 0; i < 2*N; i++ { + go writer(i%2 != 0) + + } + wg.Add(N) + for i := 0; i < N; i++ { + go reader() + } + wg.Wait() +} diff --git a/ipv4/readwrite_go1_9_test.go b/ipv4/readwrite_go1_9_test.go new file mode 100644 index 00000000..365de022 --- /dev/null +++ b/ipv4/readwrite_go1_9_test.go @@ -0,0 +1,388 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.9 + +package ipv4_test + +import ( + "bytes" + "fmt" + "net" + "runtime" + "strings" + "sync" + "testing" + + "golang.org/x/net/internal/iana" + "golang.org/x/net/internal/nettest" + "golang.org/x/net/ipv4" +) + +func BenchmarkPacketConnReadWriteUnicast(b *testing.B) { + switch runtime.GOOS { + case "nacl", "plan9", "windows": + b.Skipf("not supported on %s", runtime.GOOS) + } + + payload := []byte("HELLO-R-U-THERE") + iph, err := (&ipv4.Header{ + Version: ipv4.Version, + Len: ipv4.HeaderLen, + TotalLen: ipv4.HeaderLen + len(payload), + TTL: 1, + Protocol: iana.ProtocolReserved, + Src: net.IPv4(192, 0, 2, 1), + Dst: net.IPv4(192, 0, 2, 254), + }).Marshal() + if err != nil { + b.Fatal(err) + } + greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00} + datagram := append(greh, append(iph, payload...)...) + bb := make([]byte, 128) + cm := ipv4.ControlMessage{ + Src: net.IPv4(127, 0, 0, 1), + } + if ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback); ifi != nil { + cm.IfIndex = ifi.Index + } + + b.Run("UDP", func(b *testing.B) { + c, err := nettest.NewLocalPacketListener("udp4") + if err != nil { + b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + dst := c.LocalAddr() + cf := ipv4.FlagTTL | ipv4.FlagInterface + if err := p.SetControlMessage(cf, true); err != nil { + b.Fatal(err) + } + wms := []ipv4.Message{ + { + Buffers: [][]byte{payload}, + Addr: dst, + OOB: cm.Marshal(), + }, + } + rms := []ipv4.Message{ + { + Buffers: [][]byte{bb}, + OOB: ipv4.NewControlMessage(cf), + }, + } + b.Run("Net", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := c.WriteTo(payload, dst); err != nil { + b.Fatal(err) + } + if _, _, err := c.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + b.Run("ToFrom", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := p.WriteTo(payload, &cm, dst); err != nil { + b.Fatal(err) + } + if _, _, _, err := p.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + b.Run("Batch", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := p.WriteBatch(wms, 0); err != nil { + b.Fatal(err) + } + if _, err := p.ReadBatch(rms, 0); err != nil { + b.Fatal(err) + } + } + }) + }) + b.Run("IP", func(b *testing.B) { + switch runtime.GOOS { + case "netbsd": + b.Skip("need to configure gre on netbsd") + case "openbsd": + b.Skip("net.inet.gre.allow=0 by default on openbsd") + } + + c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1") + if err != nil { + b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + dst := c.LocalAddr() + cf := ipv4.FlagTTL | ipv4.FlagInterface + if err := p.SetControlMessage(cf, true); err != nil { + b.Fatal(err) + } + wms := []ipv4.Message{ + { + Buffers: [][]byte{datagram}, + Addr: dst, + OOB: cm.Marshal(), + }, + } + rms := []ipv4.Message{ + { + Buffers: [][]byte{bb}, + OOB: ipv4.NewControlMessage(cf), + }, + } + b.Run("Net", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := c.WriteTo(datagram, dst); err != nil { + b.Fatal(err) + } + if _, _, err := c.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + b.Run("ToFrom", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := p.WriteTo(datagram, &cm, dst); err != nil { + b.Fatal(err) + } + if _, _, _, err := p.ReadFrom(bb); err != nil { + b.Fatal(err) + } + } + }) + b.Run("Batch", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := p.WriteBatch(wms, 0); err != nil { + b.Fatal(err) + } + if _, err := p.ReadBatch(rms, 0); err != nil { + b.Fatal(err) + } + } + }) + }) +} + +func TestPacketConnConcurrentReadWriteUnicast(t *testing.T) { + switch runtime.GOOS { + case "nacl", "plan9", "windows": + t.Skipf("not supported on %s", runtime.GOOS) + } + + payload := []byte("HELLO-R-U-THERE") + iph, err := (&ipv4.Header{ + Version: ipv4.Version, + Len: ipv4.HeaderLen, + TotalLen: ipv4.HeaderLen + len(payload), + TTL: 1, + Protocol: iana.ProtocolReserved, + Src: net.IPv4(192, 0, 2, 1), + Dst: net.IPv4(192, 0, 2, 254), + }).Marshal() + if err != nil { + t.Fatal(err) + } + greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00} + datagram := append(greh, append(iph, payload...)...) + + t.Run("UDP", func(t *testing.T) { + c, err := nettest.NewLocalPacketListener("udp4") + if err != nil { + t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + t.Run("ToFrom", func(t *testing.T) { + testPacketConnConcurrentReadWriteUnicast(t, p, payload, c.LocalAddr(), false) + }) + t.Run("Batch", func(t *testing.T) { + testPacketConnConcurrentReadWriteUnicast(t, p, payload, c.LocalAddr(), true) + }) + }) + t.Run("IP", func(t *testing.T) { + switch runtime.GOOS { + case "netbsd": + t.Skip("need to configure gre on netbsd") + case "openbsd": + t.Skip("net.inet.gre.allow=0 by default on openbsd") + } + + c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1") + if err != nil { + t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + defer c.Close() + p := ipv4.NewPacketConn(c) + t.Run("ToFrom", func(t *testing.T) { + testPacketConnConcurrentReadWriteUnicast(t, p, datagram, c.LocalAddr(), false) + }) + t.Run("Batch", func(t *testing.T) { + testPacketConnConcurrentReadWriteUnicast(t, p, datagram, c.LocalAddr(), true) + }) + }) +} + +func testPacketConnConcurrentReadWriteUnicast(t *testing.T, p *ipv4.PacketConn, data []byte, dst net.Addr, batch bool) { + ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback) + cf := ipv4.FlagTTL | ipv4.FlagSrc | ipv4.FlagDst | ipv4.FlagInterface + + if err := p.SetControlMessage(cf, true); err != nil { // probe before test + if nettest.ProtocolNotSupported(err) { + t.Skipf("not supported on %s", runtime.GOOS) + } + t.Fatal(err) + } + + var wg sync.WaitGroup + reader := func() { + defer wg.Done() + b := make([]byte, 128) + n, cm, _, err := p.ReadFrom(b) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(b[:n], data) { + t.Errorf("got %#v; want %#v", b[:n], data) + return + } + s := cm.String() + if strings.Contains(s, ",") { + t.Errorf("should be space-separated values: %s", s) + return + } + } + batchReader := func() { + defer wg.Done() + ms := []ipv4.Message{ + { + Buffers: [][]byte{make([]byte, 128)}, + OOB: ipv4.NewControlMessage(cf), + }, + } + n, err := p.ReadBatch(ms, 0) + if err != nil { + t.Error(err) + return + } + if n != len(ms) { + t.Errorf("got %d; want %d", n, len(ms)) + return + } + var cm ipv4.ControlMessage + if err := cm.Parse(ms[0].OOB[:ms[0].NN]); err != nil { + t.Error(err) + return + } + var b []byte + if _, ok := dst.(*net.IPAddr); ok { + var h ipv4.Header + if err := h.Parse(ms[0].Buffers[0][:ms[0].N]); err != nil { + t.Error(err) + return + } + b = ms[0].Buffers[0][h.Len:ms[0].N] + } else { + b = ms[0].Buffers[0][:ms[0].N] + } + if !bytes.Equal(b, data) { + t.Errorf("got %#v; want %#v", b, data) + return + } + s := cm.String() + if strings.Contains(s, ",") { + t.Errorf("should be space-separated values: %s", s) + return + } + } + writer := func(toggle bool) { + defer wg.Done() + cm := ipv4.ControlMessage{ + Src: net.IPv4(127, 0, 0, 1), + } + if ifi != nil { + cm.IfIndex = ifi.Index + } + if err := p.SetControlMessage(cf, toggle); err != nil { + t.Error(err) + return + } + n, err := p.WriteTo(data, &cm, dst) + if err != nil { + t.Error(err) + return + } + if n != len(data) { + t.Errorf("got %d; want %d", n, len(data)) + return + } + } + batchWriter := func(toggle bool) { + defer wg.Done() + cm := ipv4.ControlMessage{ + Src: net.IPv4(127, 0, 0, 1), + } + if ifi != nil { + cm.IfIndex = ifi.Index + } + if err := p.SetControlMessage(cf, toggle); err != nil { + t.Error(err) + return + } + ms := []ipv4.Message{ + { + Buffers: [][]byte{data}, + OOB: cm.Marshal(), + Addr: dst, + }, + } + n, err := p.WriteBatch(ms, 0) + if err != nil { + t.Error(err) + return + } + if n != len(ms) { + t.Errorf("got %d; want %d", n, len(ms)) + return + } + if ms[0].N != len(data) { + t.Errorf("got %d; want %d", ms[0].N, len(data)) + return + } + } + + const N = 10 + wg.Add(N) + for i := 0; i < N; i++ { + if batch { + go batchReader() + } else { + go reader() + } + } + wg.Add(2 * N) + for i := 0; i < 2*N; i++ { + if batch { + go batchWriter(i%2 != 0) + } else { + go writer(i%2 != 0) + } + + } + wg.Add(N) + for i := 0; i < N; i++ { + if batch { + go batchReader() + } else { + go reader() + } + } + wg.Wait() +}