internal/socket: tell race detector about syscall reads and writes

The syscalls that send and receive messages write to buffers provided
by the user. The race detector can't see those reads and writes by
default (they are done by the kernel), so we need to tell the race
detector explicitly about them.

Fixes golang/go#35329

Change-Id: Ibf4ef1b937535c4834aa9eeb744722d91f669a27
Reviewed-on: https://go-review.googlesource.com/c/net/+/205461
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
This commit is contained in:
Keith Randall
2019-11-05 12:36:25 -08:00
committed by Keith Randall
parent daa7c04131
commit 2180aed223
5 changed files with 126 additions and 0 deletions

12
internal/socket/norace.go Normal file
View File

@@ -0,0 +1,12 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !race
package socket
func (m *Message) raceRead() {
}
func (m *Message) raceWrite() {
}

37
internal/socket/race.go Normal file
View File

@@ -0,0 +1,37 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build race
package socket
import (
"runtime"
"unsafe"
)
// This package reads and writes the Message buffers using a
// direct system call, which the race detector can't see.
// These functions tell the race detector what is going on during the syscall.
func (m *Message) raceRead() {
for _, b := range m.Buffers {
if len(b) > 0 {
runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
}
}
if b := m.OOB; len(b) > 0 {
runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
}
}
func (m *Message) raceWrite() {
for _, b := range m.Buffers {
if len(b) > 0 {
runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
}
}
if b := m.OOB; len(b) > 0 {
runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
}
}

View File

@@ -13,6 +13,9 @@ import (
)
func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
for i := range ms {
ms[i].raceWrite()
}
hs := make(mmsghdrs, len(ms))
var parseFn func([]byte, string) (net.Addr, error)
if c.network != "tcp" {
@@ -43,6 +46,9 @@ func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
}
func (c *Conn) sendMsgs(ms []Message, flags int) (int, error) {
for i := range ms {
ms[i].raceRead()
}
hs := make(mmsghdrs, len(ms))
var marshalFn func(net.Addr) []byte
if c.network != "tcp" {

View File

@@ -12,6 +12,7 @@ import (
)
func (c *Conn) recvMsg(m *Message, flags int) error {
m.raceWrite()
var h msghdr
vs := make([]iovec, len(m.Buffers))
var sa []byte
@@ -48,6 +49,7 @@ func (c *Conn) recvMsg(m *Message, flags int) error {
}
func (c *Conn) sendMsg(m *Message, flags int) error {
m.raceRead()
var h msghdr
vs := make([]iovec, len(m.Buffers))
var sa []byte

View File

@@ -9,8 +9,13 @@ package socket_test
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"testing"
@@ -296,3 +301,67 @@ func BenchmarkUDP(b *testing.B) {
}
}
}
func TestRace(t *testing.T) {
tests := []string{
`
package main
import "net"
import "golang.org/x/net/ipv4"
var g byte
func main() {
c, _ := net.ListenPacket("udp", "127.0.0.1:0")
cc := ipv4.NewPacketConn(c)
sync := make(chan bool)
src := make([]byte, 1)
dst := make([]byte, 1)
go func() { cc.WriteTo(src, nil, c.LocalAddr()) }()
go func() { cc.ReadFrom(dst); sync <- true }()
g = dst[0]
<- sync
}
`,
`
package main
import "net"
import "golang.org/x/net/ipv4"
func main() {
c, _ := net.ListenPacket("udp", "127.0.0.1:0")
cc := ipv4.NewPacketConn(c)
sync := make(chan bool)
src := make([]byte, 1)
dst := make([]byte, 1)
go func() { cc.WriteTo(src, nil, c.LocalAddr()); sync <- true }()
src[0] = 0
go func() { cc.ReadFrom(dst) }()
<- sync
}
`,
}
platforms := map[string]bool{
"linux/amd64": true,
"linux/ppc64le": true,
"linux/arm64": true,
}
if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
t.Skip("skipping test on non-race-enabled host.")
}
dir, err := ioutil.TempDir("", "testrace")
if err != nil {
t.Fatalf("failed to create temp directory: %v", err)
}
defer os.RemoveAll(dir)
goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
for i, test := range tests {
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
t.Fatalf("failed to write file: %v", err)
}
got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
if !strings.Contains(string(got), "WARNING: DATA RACE") {
t.Errorf("race not detected for test %d: err:%v out:%s", i, err, string(got))
}
})
}
}