mirror of
https://github.com/golang/net.git
synced 2026-04-01 02:47:08 +09:00
internal/socket: tell race detector about syscall reads and writes
The syscalls that send and receive messages write to buffers provided by the user. The race detector can't see those reads and writes by default (they are done by the kernel), so we need to tell the race detector explicitly about them. Fixes golang/go#35329 Change-Id: Ibf4ef1b937535c4834aa9eeb744722d91f669a27 Reviewed-on: https://go-review.googlesource.com/c/net/+/205461 Run-TryBot: Keith Randall <khr@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
This commit is contained in:
committed by
Keith Randall
parent
daa7c04131
commit
2180aed223
12
internal/socket/norace.go
Normal file
12
internal/socket/norace.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !race
|
||||
|
||||
package socket
|
||||
|
||||
func (m *Message) raceRead() {
|
||||
}
|
||||
func (m *Message) raceWrite() {
|
||||
}
|
||||
37
internal/socket/race.go
Normal file
37
internal/socket/race.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build race
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// This package reads and writes the Message buffers using a
|
||||
// direct system call, which the race detector can't see.
|
||||
// These functions tell the race detector what is going on during the syscall.
|
||||
|
||||
func (m *Message) raceRead() {
|
||||
for _, b := range m.Buffers {
|
||||
if len(b) > 0 {
|
||||
runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
|
||||
}
|
||||
}
|
||||
if b := m.OOB; len(b) > 0 {
|
||||
runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
|
||||
}
|
||||
}
|
||||
func (m *Message) raceWrite() {
|
||||
for _, b := range m.Buffers {
|
||||
if len(b) > 0 {
|
||||
runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
|
||||
}
|
||||
}
|
||||
if b := m.OOB; len(b) > 0 {
|
||||
runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,9 @@ import (
|
||||
)
|
||||
|
||||
func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
|
||||
for i := range ms {
|
||||
ms[i].raceWrite()
|
||||
}
|
||||
hs := make(mmsghdrs, len(ms))
|
||||
var parseFn func([]byte, string) (net.Addr, error)
|
||||
if c.network != "tcp" {
|
||||
@@ -43,6 +46,9 @@ func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
|
||||
}
|
||||
|
||||
func (c *Conn) sendMsgs(ms []Message, flags int) (int, error) {
|
||||
for i := range ms {
|
||||
ms[i].raceRead()
|
||||
}
|
||||
hs := make(mmsghdrs, len(ms))
|
||||
var marshalFn func(net.Addr) []byte
|
||||
if c.network != "tcp" {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func (c *Conn) recvMsg(m *Message, flags int) error {
|
||||
m.raceWrite()
|
||||
var h msghdr
|
||||
vs := make([]iovec, len(m.Buffers))
|
||||
var sa []byte
|
||||
@@ -48,6 +49,7 @@ func (c *Conn) recvMsg(m *Message, flags int) error {
|
||||
}
|
||||
|
||||
func (c *Conn) sendMsg(m *Message, flags int) error {
|
||||
m.raceRead()
|
||||
var h msghdr
|
||||
vs := make([]iovec, len(m.Buffers))
|
||||
var sa []byte
|
||||
|
||||
@@ -9,8 +9,13 @@ package socket_test
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
@@ -296,3 +301,67 @@ func BenchmarkUDP(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRace(t *testing.T) {
|
||||
tests := []string{
|
||||
`
|
||||
package main
|
||||
import "net"
|
||||
import "golang.org/x/net/ipv4"
|
||||
var g byte
|
||||
func main() {
|
||||
c, _ := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
cc := ipv4.NewPacketConn(c)
|
||||
sync := make(chan bool)
|
||||
src := make([]byte, 1)
|
||||
dst := make([]byte, 1)
|
||||
go func() { cc.WriteTo(src, nil, c.LocalAddr()) }()
|
||||
go func() { cc.ReadFrom(dst); sync <- true }()
|
||||
g = dst[0]
|
||||
<- sync
|
||||
}
|
||||
`,
|
||||
`
|
||||
package main
|
||||
import "net"
|
||||
import "golang.org/x/net/ipv4"
|
||||
func main() {
|
||||
c, _ := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
cc := ipv4.NewPacketConn(c)
|
||||
sync := make(chan bool)
|
||||
src := make([]byte, 1)
|
||||
dst := make([]byte, 1)
|
||||
go func() { cc.WriteTo(src, nil, c.LocalAddr()); sync <- true }()
|
||||
src[0] = 0
|
||||
go func() { cc.ReadFrom(dst) }()
|
||||
<- sync
|
||||
}
|
||||
`,
|
||||
}
|
||||
platforms := map[string]bool{
|
||||
"linux/amd64": true,
|
||||
"linux/ppc64le": true,
|
||||
"linux/arm64": true,
|
||||
}
|
||||
if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
|
||||
t.Skip("skipping test on non-race-enabled host.")
|
||||
}
|
||||
dir, err := ioutil.TempDir("", "testrace")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
|
||||
for i, test := range tests {
|
||||
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
|
||||
src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
|
||||
if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
|
||||
t.Fatalf("failed to write file: %v", err)
|
||||
}
|
||||
got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
|
||||
if !strings.Contains(string(got), "WARNING: DATA RACE") {
|
||||
t.Errorf("race not detected for test %d: err:%v out:%s", i, err, string(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user