diff --git a/src/cmd/compile/internal/noder/expr.go b/src/cmd/compile/internal/noder/expr.go index 989ebf236e..1ca5552879 100644 --- a/src/cmd/compile/internal/noder/expr.go +++ b/src/cmd/compile/internal/noder/expr.go @@ -135,7 +135,7 @@ func (g *irgen) expr0(typ types2.Type, expr syntax.Expr) ir.Node { index := g.expr(expr.Index) if index.Op() != ir.OTYPE { // This is just a normal index expression - return Index(pos, g.expr(expr.X), index) + return Index(pos, g.typ(typ), g.expr(expr.X), index) } // This is generic function instantiation with a single type targs = []ir.Node{index} diff --git a/src/cmd/compile/internal/noder/helpers.go b/src/cmd/compile/internal/noder/helpers.go index cf7a3e22b3..1210d4b58c 100644 --- a/src/cmd/compile/internal/noder/helpers.go +++ b/src/cmd/compile/internal/noder/helpers.go @@ -83,11 +83,12 @@ func Binary(pos src.XPos, op ir.Op, x, y ir.Node) ir.Node { } func Call(pos src.XPos, typ *types.Type, fun ir.Node, args []ir.Node, dots bool) ir.Node { - // TODO(mdempsky): This should not be so difficult. + n := ir.NewCallExpr(pos, ir.OCALL, fun, args) + n.IsDDD = dots + if fun.Op() == ir.OTYPE { // Actually a type conversion, not a function call. - n := ir.NewCallExpr(pos, ir.OCALL, fun, args) - if fun.Type().Kind() == types.TTYPEPARAM { + if fun.Type().HasTParam() || args[0].Type().HasTParam() { // For type params, don't typecheck until we actually know // the type. return typed(typ, n) @@ -96,9 +97,34 @@ func Call(pos src.XPos, typ *types.Type, fun ir.Node, args []ir.Node, dots bool) } if fun, ok := fun.(*ir.Name); ok && fun.BuiltinOp != 0 { - // Call to a builtin function. - n := ir.NewCallExpr(pos, ir.OCALL, fun, args) - n.IsDDD = dots + // For Builtin ops, we currently stay with using the old + // typechecker to transform the call to a more specific expression + // and possibly use more specific ops. However, for a bunch of the + // ops, we delay doing the old typechecker if any of the args have + // type params, for a variety of reasons: + // + // OMAKE: hard to choose specific ops OMAKESLICE, etc. until arg type is known + // OREAL/OIMAG: can't determine type float32/float64 until arg type know + // OLEN/OCAP: old typechecker will complain if arg is not obviously a slice/array. + // OAPPEND: old typechecker will complain if arg is not obviously slice, etc. + // + // We will eventually break out the transforming functionality + // needed for builtin's, and call it here or during stenciling, as + // appropriate. + switch fun.BuiltinOp { + case ir.OMAKE, ir.OREAL, ir.OIMAG, ir.OLEN, ir.OCAP, ir.OAPPEND: + hasTParam := false + for _, arg := range args { + if arg.Type().HasTParam() { + hasTParam = true + break + } + } + if hasTParam { + return typed(typ, n) + } + } + switch fun.BuiltinOp { case ir.OCLOSE, ir.ODELETE, ir.OPANIC, ir.OPRINT, ir.OPRINTN: return typecheck.Stmt(n) @@ -124,9 +150,6 @@ func Call(pos src.XPos, typ *types.Type, fun ir.Node, args []ir.Node, dots bool) } } - n := ir.NewCallExpr(pos, ir.OCALL, fun, args) - n.IsDDD = dots - if fun.Op() == ir.OXDOT { if !fun.(*ir.SelectorExpr).X.Type().HasTParam() { base.FatalfAt(pos, "Expecting type param receiver in %v", fun) @@ -230,9 +253,18 @@ func method(typ *types.Type, index int) *types.Field { return types.ReceiverBaseType(typ).Methods().Index(index) } -func Index(pos src.XPos, x, index ir.Node) ir.Node { - // TODO(mdempsky): Avoid typecheck.Expr (which will call tcIndex) - return typecheck.Expr(ir.NewIndexExpr(pos, x, index)) +func Index(pos src.XPos, typ *types.Type, x, index ir.Node) ir.Node { + n := ir.NewIndexExpr(pos, x, index) + // TODO(danscales): Temporary fix. Need to separate out the + // transformations done by the old typechecker (in tcIndex()), to be + // called here or after stenciling. + if x.Type().HasTParam() && x.Type().Kind() != types.TMAP && + x.Type().Kind() != types.TSLICE && x.Type().Kind() != types.TARRAY { + // Old typechecker will complain if arg is not obviously a slice/array/map. + typed(typ, n) + return n + } + return typecheck.Expr(n) } func Slice(pos src.XPos, x, low, high, max ir.Node) ir.Node { diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 64b3a942e2..55aee9b6ff 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -330,8 +330,13 @@ func (subst *subster) node(n ir.Node) ir.Node { m.SetIsClosureVar(true) } t := x.Type() - newt := subst.typ(t) - m.SetType(newt) + if t == nil { + assert(name.BuiltinOp != 0) + } else { + newt := subst.typ(t) + m.SetType(newt) + } + m.BuiltinOp = name.BuiltinOp m.Curfn = subst.newf m.Class = name.Class m.Func = name.Func @@ -396,11 +401,23 @@ func (subst *subster) node(n ir.Node) ir.Node { // that the OXDOT was resolved. call.SetTypecheck(0) typecheck.Call(call) + } else if name := call.X.Name(); name != nil { + switch name.BuiltinOp { + case ir.OMAKE, ir.OREAL, ir.OIMAG, ir.OLEN, ir.OCAP, ir.OAPPEND: + // Call old typechecker (to do any + // transformations) now that we know the + // type of the args. + m.SetTypecheck(0) + m = typecheck.Expr(m) + default: + base.FatalfAt(call.Pos(), "Unexpected builtin op") + } + } else if call.X.Op() != ir.OFUNCINST { // A call with an OFUNCINST will get typechecked // in stencil() once we have created & attached the // instantiation to be called. - base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST with CALL") + base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST or builtin with CALL") } } diff --git a/src/cmd/compile/internal/noder/stmt.go b/src/cmd/compile/internal/noder/stmt.go index 1775116f41..31c6bfe5c8 100644 --- a/src/cmd/compile/internal/noder/stmt.go +++ b/src/cmd/compile/internal/noder/stmt.go @@ -28,7 +28,11 @@ func (g *irgen) stmts(stmts []syntax.Stmt) []ir.Node { func (g *irgen) stmt(stmt syntax.Stmt) ir.Node { // TODO(mdempsky): Remove dependency on typecheck. - return typecheck.Stmt(g.stmt0(stmt)) + n := g.stmt0(stmt) + if n != nil { + n.SetTypecheck(1) + } + return n } func (g *irgen) stmt0(stmt syntax.Stmt) ir.Node { @@ -46,17 +50,20 @@ func (g *irgen) stmt0(stmt syntax.Stmt) ir.Node { } return x case *syntax.SendStmt: - return ir.NewSendStmt(g.pos(stmt), g.expr(stmt.Chan), g.expr(stmt.Value)) + n := ir.NewSendStmt(g.pos(stmt), g.expr(stmt.Chan), g.expr(stmt.Value)) + // Need to do the AssignConv() in tcSend(). + return typecheck.Stmt(n) case *syntax.DeclStmt: return ir.NewBlockStmt(g.pos(stmt), g.decls(stmt.DeclList)) case *syntax.AssignStmt: if stmt.Op != 0 && stmt.Op != syntax.Def { op := g.op(stmt.Op, binOps[:]) + // May need to insert ConvExpr nodes on the args in tcArith if stmt.Rhs == nil { - return IncDec(g.pos(stmt), op, g.expr(stmt.Lhs)) + return typecheck.Stmt(IncDec(g.pos(stmt), op, g.expr(stmt.Lhs))) } - return ir.NewAssignOpStmt(g.pos(stmt), op, g.expr(stmt.Lhs), g.expr(stmt.Rhs)) + return typecheck.Stmt(ir.NewAssignOpStmt(g.pos(stmt), op, g.expr(stmt.Lhs), g.expr(stmt.Rhs))) } names, lhs := g.assignList(stmt.Lhs, stmt.Op == syntax.Def) @@ -65,25 +72,31 @@ func (g *irgen) stmt0(stmt syntax.Stmt) ir.Node { if len(lhs) == 1 && len(rhs) == 1 { n := ir.NewAssignStmt(g.pos(stmt), lhs[0], rhs[0]) n.Def = initDefn(n, names) - return n + // Need to set Assigned in checkassign for maps + return typecheck.Stmt(n) } n := ir.NewAssignListStmt(g.pos(stmt), ir.OAS2, lhs, rhs) n.Def = initDefn(n, names) - return n + // Need to do tcAssignList(). + return typecheck.Stmt(n) case *syntax.BranchStmt: return ir.NewBranchStmt(g.pos(stmt), g.tokOp(int(stmt.Tok), branchOps[:]), g.name(stmt.Label)) case *syntax.CallStmt: return ir.NewGoDeferStmt(g.pos(stmt), g.tokOp(int(stmt.Tok), callOps[:]), g.expr(stmt.Call)) case *syntax.ReturnStmt: - return ir.NewReturnStmt(g.pos(stmt), g.exprList(stmt.Results)) + n := ir.NewReturnStmt(g.pos(stmt), g.exprList(stmt.Results)) + // Need to do typecheckaste() for multiple return values + return typecheck.Stmt(n) case *syntax.IfStmt: return g.ifStmt(stmt) case *syntax.ForStmt: return g.forStmt(stmt) case *syntax.SelectStmt: - return g.selectStmt(stmt) + n := g.selectStmt(stmt) + // Need to convert assignments to OSELRECV2 in tcSelect() + return typecheck.Stmt(n) case *syntax.SwitchStmt: return g.switchStmt(stmt) diff --git a/test/typeparam/absdiff.go b/test/typeparam/absdiff.go index 5dd58f14f7..1381d7c92c 100644 --- a/test/typeparam/absdiff.go +++ b/test/typeparam/absdiff.go @@ -8,7 +8,7 @@ package main import ( "fmt" - //"math" + "math" ) type Numeric interface { @@ -57,14 +57,14 @@ func (a orderedAbs[T]) Abs() orderedAbs[T] { // complexAbs is a helper type that defines an Abs method for // complex types. -// type complexAbs[T Complex] T +type complexAbs[T Complex] T -// func (a complexAbs[T]) Abs() complexAbs[T] { -// r := float64(real(a)) -// i := float64(imag(a)) -// d := math.Sqrt(r * r + i * i) -// return complexAbs[T](complex(d, 0)) -// } +func (a complexAbs[T]) Abs() complexAbs[T] { + r := float64(real(a)) + i := float64(imag(a)) + d := math.Sqrt(r * r + i * i) + return complexAbs[T](complex(d, 0)) +} // OrderedAbsDifference returns the absolute value of the difference // between a and b, where a and b are of an ordered type. @@ -74,9 +74,9 @@ func orderedAbsDifference[T orderedNumeric](a, b T) T { // ComplexAbsDifference returns the absolute value of the difference // between a and b, where a and b are of a complex type. -// func complexAbsDifference[T Complex](a, b T) T { -// return T(absDifference(complexAbs[T](a), complexAbs[T](b))) -// } +func complexAbsDifference[T Complex](a, b T) T { + return T(absDifference(complexAbs[T](a), complexAbs[T](b))) +} func main() { if got, want := orderedAbsDifference(1.0, -2.0), 3.0; got != want { @@ -89,11 +89,10 @@ func main() { panic(fmt.Sprintf("got = %v, want = %v", got, want)) } - // Still have to handle built-ins real/abs to make this work - // if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5; got != want { - // panic(fmt.Sprintf("got = %v, want = %v", got, want) - // } - // if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5; got != want { - // panic(fmt.Sprintf("got = %v, want = %v", got, want) - // } + if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } } diff --git a/test/typeparam/append.go b/test/typeparam/append.go new file mode 100644 index 0000000000..8b9bc2039f --- /dev/null +++ b/test/typeparam/append.go @@ -0,0 +1,31 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +type Recv <-chan int + +type sliceOf[E any] interface { + type []E +} + +func _Append[S sliceOf[T], T any](s S, t ...T) S { + return append(s, t...) +} + +func main() { + recv := make(Recv) + a := _Append([]Recv{recv}, recv) + if len(a) != 2 || a[0] != recv || a[1] != recv { + panic(a) + } + + recv2 := make(chan<- int) + a2 := _Append([]chan<- int{recv2}, recv2) + if len(a2) != 2 || a2[0] != recv2 || a2[1] != recv2 { + panic(a) + } +} diff --git a/test/typeparam/double.go b/test/typeparam/double.go new file mode 100644 index 0000000000..1f7a26c7f4 --- /dev/null +++ b/test/typeparam/double.go @@ -0,0 +1,50 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "reflect" +) + +type Number interface { + type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, float32, float64 +} + +type MySlice []int + +type _SliceOf[E any] interface { + type []E +} + +func _DoubleElems[S _SliceOf[E], E Number](s S) S { + r := make(S, len(s)) + for i, v := range s { + r[i] = v + v + } + return r +} + +func main() { + arg := MySlice{1, 2, 3} + want := MySlice{2, 4, 6} + got := _DoubleElems[MySlice, int](arg) + if !reflect.DeepEqual(got, want) { + panic(fmt.Sprintf("got %s, want %s", got, want)) + } + + // constraint type inference + got = _DoubleElems[MySlice](arg) + if !reflect.DeepEqual(got, want) { + panic(fmt.Sprintf("got %s, want %s", got, want)) + } + + got = _DoubleElems(arg) + if !reflect.DeepEqual(got, want) { + panic(fmt.Sprintf("got %s, want %s", got, want)) + } +} diff --git a/test/typeparam/settable.go b/test/typeparam/settable.go index 29874fb189..588166da85 100644 --- a/test/typeparam/settable.go +++ b/test/typeparam/settable.go @@ -14,44 +14,58 @@ import ( // Various implementations of fromStrings(). type _Setter[B any] interface { - Set(string) + Set(string) type *B } // Takes two type parameters where PT = *T func fromStrings1[T any, PT _Setter[T]](s []string) []T { - result := make([]T, len(s)) - for i, v := range s { - // The type of &result[i] is *T which is in the type list - // of Setter, so we can convert it to PT. - p := PT(&result[i]) - // PT has a Set method. - p.Set(v) - } - return result + result := make([]T, len(s)) + for i, v := range s { + // The type of &result[i] is *T which is in the type list + // of Setter, so we can convert it to PT. + p := PT(&result[i]) + // PT has a Set method. + p.Set(v) + } + return result } +func fromStrings1a[T any, PT _Setter[T]](s []string) []PT { + result := make([]PT, len(s)) + for i, v := range s { + // The type new(T) is *T which is in the type list + // of Setter, so we can convert it to PT. + result[i] = PT(new(T)) + p := result[i] + // PT has a Set method. + p.Set(v) + } + return result +} + + // Takes one type parameter and a set function func fromStrings2[T any](s []string, set func(*T, string)) []T { - results := make([]T, len(s)) - for i, v := range s { - set(&results[i], v) - } - return results + results := make([]T, len(s)) + for i, v := range s { + set(&results[i], v) + } + return results } type _Setter2 interface { - Set(string) + Set(string) } // Takes only one type parameter, but causes a panic (see below) func fromStrings3[T _Setter2](s []string) []T { - results := make([]T, len(s)) - for i, v := range s { + results := make([]T, len(s)) + for i, v := range s { // Panics if T is a pointer type because receiver is T(nil). results[i].Set(v) - } - return results + } + return results } // Two concrete types with the appropriate Set method. @@ -59,11 +73,11 @@ func fromStrings3[T _Setter2](s []string) []T { type SettableInt int func (p *SettableInt) Set(s string) { - i, err := strconv.Atoi(s) - if err != nil { - panic(err) - } - *p = SettableInt(i) + i, err := strconv.Atoi(s) + if err != nil { + panic(err) + } + *p = SettableInt(i) } type SettableString struct { @@ -71,34 +85,40 @@ type SettableString struct { } func (x *SettableString) Set(s string) { - x.s = s + x.s = s } func main() { - s := fromStrings1[SettableInt, *SettableInt]([]string{"1"}) - if len(s) != 1 || s[0] != 1 { - panic(fmt.Sprintf("got %v, want %v", s, []int{1})) - } + s := fromStrings1[SettableInt, *SettableInt]([]string{"1"}) + if len(s) != 1 || s[0] != 1 { + panic(fmt.Sprintf("got %v, want %v", s, []int{1})) + } + + s2 := fromStrings1a[SettableInt, *SettableInt]([]string{"1"}) + if len(s2) != 1 || *s2[0] != 1 { + x := 1 + panic(fmt.Sprintf("got %v, want %v", s2, []*int{&x})) + } // Test out constraint type inference, which should determine that the second // type param is *SettableString. ps := fromStrings1[SettableString]([]string{"x", "y"}) - if len(ps) != 2 || ps[0] != (SettableString{"x"}) || ps[1] != (SettableString{"y"}) { - panic(s) - } + if len(ps) != 2 || ps[0] != (SettableString{"x"}) || ps[1] != (SettableString{"y"}) { + panic(s) + } - s = fromStrings2([]string{"1"}, func(p *SettableInt, s string) { p.Set(s) }) - if len(s) != 1 || s[0] != 1 { - panic(fmt.Sprintf("got %v, want %v", s, []int{1})) - } + s = fromStrings2([]string{"1"}, func(p *SettableInt, s string) { p.Set(s) }) + if len(s) != 1 || s[0] != 1 { + panic(fmt.Sprintf("got %v, want %v", s, []int{1})) + } - defer func() { - if recover() == nil { - panic("did not panic as expected") - } - }() - // This should type check but should panic at run time, - // because it will make a slice of *SettableInt and then call - // Set on a nil value. - fromStrings3[*SettableInt]([]string{"1"}) + defer func() { + if recover() == nil { + panic("did not panic as expected") + } + }() + // This should type check but should panic at run time, + // because it will make a slice of *SettableInt and then call + // Set on a nil value. + fromStrings3[*SettableInt]([]string{"1"}) }