泛型是一些语言的标配,可以极大的便利开发者,或许C++转golang最想要的语法就是泛型功能,这样可以使用类似模板的功能,减少很多类型判断的代码;
泛型已经出来一年多了,但有一些开发者还未用到这一语法特性,因此本文主要介绍一下泛型在golang的使用和样例,需要1.18+的版本支持,所以读者可以使用在线环境:https://go2goplay.golang.org/。
如何写一个泛型
package main
import (
"fmt"
)
type Ifvar interface {
type int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64, uintptr,
float32, float64, complex64, complex128,
string
}
func If[T Ifvar](b bool, a1, a2 T) T {
if b {
return a1
}
return a2
}
func main() {
fmt.Println(If(1 < 2, 1, 2))
fmt.Println(If("foo" > "bar", "foo", "bar"))
fmt.Println(If(If("aoo" > "bar1111111111", "aoo", "bar1111111111") > "bb", "hello1", "hello2"))
}
以前三元表达式需要interface{},然后再强制转换类型,现在泛型就能搞定。
定义约束
(1)类型约束
package main
import (
"fmt"
)
type Addable interface {
type int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64, uintptr,
float32, float64, complex64, complex128,
string
}
func add[T Addable](a, b T) T {
return a + b
}
func main() {
fmt.Println(add(1,2)) // 输出:3
fmt.Println(add("foo","bar")) // 输出:foobar
}
通过type可以约束add函数支持的参数类型,否则编译不通过;
(2)方法约束
package main
import (
"fmt"
)
type TA1 struct {}
func (t TA1) String() {
fmt.Println("TA1 String()")
}
type TA2 struct {}
func (t TA2) String() {
fmt.Println("TA2 String()")
}
type TAable interface {
type TA1, TA2
String()
}
func print[T TAable](a, b T) {
a.String()
b.String()
}
func main() {
var a11, a12 TA1
var a21, a22 TA2
print(a11, a12)
print(a21, a22)
}
通过interface中定义的方法可以约束支持类型,对于无实现对应函数的struct是编译报错;
(3)多参数约束
package main
import (
"fmt"
)
func print[T1, T2 any](a T1, b T2) {
fmt.Println("a: ", a, ", b: ", b)
}
func main() {
print([]string{"xxxx", "nnnn"}, []int{1, 2})
}
泛型可以支持多参数类型,用any可以传入任意类型,但如果使用某些类型方法或者变量,需要实现才行;
泛型函数
package main
import (
"fmt"
)
type comparable interface {
type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, float32, float64
}
func max[T comparable](a []T) T {
m := a[0]
for _, v := range a {
if m < v {
m = v
}
}
return m
}
func min[T comparable](a []T) T {
m := a[0]
for _, v := range a {
if m > v {
m = v
}
}
return m
}
func main() {
vi := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
result := max(vi)
fmt.Println(result)
vi1 := []float64{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}
result1 := min(vi1)
fmt.Println(result1)
}
以上是一个泛型函数的示例,实现c++的max_element和min_element功能,可以传入声明的类型;
泛型结构体
package main
import (
"fmt"
)
type List[T any] struct {
next *List[T]
val T
}
func (l *List[T]) Add(x T) {
l.val = x
// pass
}
func (l *List[T]) Print() {
fmt.Println("l->val: ", l.val)
}
func main() {
var l1 List[int]
l1.Add(1)
l1.Print()
var l2 List[string]
l2.Add("XXXXXXXXXXXXXX")
l2.Print()
}
以上是一个泛型结构体的示例,存储任何类型的List;
泛型指针
package main
import (
"fmt"
"strconv"
)
type Setter interface {
Set(string)
}
func FromStrings[T Setter](s []string) []T {
result := make([]T, len(s))
for i, v := range s {
result[i].Set(v)
}
return result
}
type Settable int
func (p *Settable) Set(s string) {
i, _ := strconv.Atoi(s) // real code should not ignore the error
*p = Settable(i)
}
func main() {
// 编译报错:Settable does not satisfy Setter: wrong method signature
nums := FromStrings[Settable]([]string{"1", "2"})
fmt.Println(nums)
// 编译正常
nums := FromStrings[*Settable]([]string{"1", "2"})
fmt.Println(nums)
}
Settable*SettableSettable*Settableresult := make([]T, len(s))
package main
import (
"fmt"
"strconv"
)
type Setter[B any] interface {
Set(string)
type *B
}
func FromStrings[T any, T1 Setter[T]](s []string) []T {
result := make([]T, len(s))
for i, v := range s {
p := T1(&result[i])
p.Set(v)
}
return result
}
type Settable int
func (p *Settable) Set(s string) {
i, _ := strconv.Atoi(s) // real code should not ignore the error
*p = Settable(i)
}
func main() {
nums := FromStrings[Settable, *Settable]([]string{"1", "2"})
fmt.Println(nums)
}
Channel泛型
package main
import (
"fmt"
"time"
)
func reduceFunc[T any](c <-chan T) {
for range c {
fmt.Println("v: ", <-c)
}
}
func main() {
c1 := make(chan int, 1)
go reduceFunc(c1)
c1 <- 1
c2 := make(chan string, 1)
go reduceFunc(c2)
c2 <- "XXXXXXXXXXXXXx"
time.Sleep(1000)
}
泛型推导
(1)函数推导
package main
import (
"fmt"
"reflect"
)
type Addable interface {
type int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64, uintptr,
float32, float64, complex64, complex128
}
func generator[T Addable](a T, v T) func() T {
return func() T {
fmt.Println("function T: ", reflect.TypeOf(a))
r := a
a = a + v
return r
}
}
func main() {
f1 := generator(0, 1)
fmt.Println(f1())
f2 := generator(2.0, 1.0)
fmt.Println(f2())
}
(2)slice推导
package main
import (
"fmt"
)
func filterFunc[T any](a []T, f func(T) bool) []T {
var n []T
for _, e := range a {
if f(e) {
n = append(n, e)
}
}
return n
}
func main() {
vi := []int{1,2,3,4,5,6}
vi = filterFunc(vi, func(v int) bool {
return v < 4
})
fmt.Println(vi)
}
(3)多参数类型推导
package main
import (
"fmt"
)
func mapFunc[T any, M any](a []T, f func(T) M) []M {
n := make([]M, len(a), cap(a))
for i, e := range a {
n[i] = f(e)
}
return n
}
func main() {
vi := []int{1,2,3,4,5,6}
vs := mapFunc(vi, func(v int) string {
return "<" + fmt.Sprint(v) + ">"
})
fmt.Println(vs)
}
(4)递归推导
package main
import (
"fmt"
)
func permInternal[T any](a []T, f func([]T), i int) {
if i > len(a) {
f(a)
return
}
permInternal(a, f, i+1)
for j := i + 1; j < len(a); j++ {
a[i], a[j] = a[j], a[i]
permInternal(a, f, i+1)
a[i], a[j] = a[j], a[i]
}
}
func perm[T any](a []T, f func([]T)) {
permInternal(a, f, 0)
}
func main() {
vi := []int{1, 2, 3}
perm(vi, func(a []int) {
fmt.Println(a)
})
}
类型转换
package main
import (
"fmt"
"reflect"
)
type integer interface {
type int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64, uintptr, float64
}
func Convert[To, From integer](from From) To {
to := To(from)
if From(to) != from {
panic("conversion out of range")
}
return to
}
func main() {
r := Convert[int, float64](1.0)
fmt.Println("1.0 -> r: ", r, ", type: ", reflect.TypeOf(r))
}
注意
(1) 不明确的转换非法
package main
import (
"fmt"
)
// 编译报错:cannot convert x (variable of type parameter type T2) to T1
func Copy1[T1, T2 any](dst []T1, src []T2) int {
for i, x := range src {
if i > len(dst) {
return i
}
dst[i] = T1(x) // INVALID
}
return len(src)
}
// 编译正常
func Copy2[T1, T2 any](dst []T1, src []T2) int {
for i, x := range src {
if i > len(dst) {
return i
}
dst[i] = (interface{})(x).(T1)
}
return len(src)
}
func main() {
fmt.Println("Copy1: ", Copy1([]int{1,2}, []float64{2.0, 3.0}))
fmt.Println("Copy2: ", Copy2([]int{1,2}, []float64{2.0, 3.0}))
}
(2) 重命名
type Vector[T any] {
a T
}
type VectorInt = Vector[int]
type any = interface{}