泛型是一些语言的标配,可以极大的便利开发者,或许C++转golang最想要的语法就是泛型功能,这样可以使用类似模板的功能,减少很多类型判断的代码;
泛型已经出来一年多了,但有一些开发者还未用到这一语法特性,因此本文主要介绍一下泛型在golang的使用和样例,需要1.18+的版本支持,所以读者可以使用在线环境:https://go2goplay.golang.org/。

如何写一个泛型

package main

import (
    "fmt"
)

type Ifvar interface {
    type int, int8, int16, int32, int64,
        uint, uint8, uint16, uint32, uint64, uintptr,
        float32, float64, complex64, complex128,
        string
}

func If[T Ifvar](b bool, a1, a2 T) T {
    if b {
        return a1
    }
    
    return a2
}


func main() {
    fmt.Println(If(1 < 2, 1, 2))
    fmt.Println(If("foo" > "bar", "foo", "bar"))
    fmt.Println(If(If("aoo" > "bar1111111111", "aoo", "bar1111111111") > "bb", "hello1", "hello2"))
}

以前三元表达式需要interface{},然后再强制转换类型,现在泛型就能搞定。

定义约束

(1)类型约束

package main

import (
    "fmt"
)

type Addable interface {
    type int, int8, int16, int32, int64,
        uint, uint8, uint16, uint32, uint64, uintptr,
        float32, float64, complex64, complex128,
        string
}

func add[T Addable](a, b T) T {
    return a + b
}

func main() {
    fmt.Println(add(1,2)) // 输出:3
    fmt.Println(add("foo","bar")) // 输出:foobar
}

通过type可以约束add函数支持的参数类型,否则编译不通过;

(2)方法约束

package main

import (
    "fmt"
)

type TA1 struct {}

func (t TA1) String() {
    fmt.Println("TA1 String()")
}

type TA2 struct {}

func (t TA2) String() {
    fmt.Println("TA2 String()")
}

type TAable interface {
    type TA1, TA2
    String()
}

func print[T TAable](a, b T) {
    a.String()
    b.String()
}

func main() {
    var a11, a12 TA1
    var a21, a22 TA2
    print(a11, a12)
    print(a21, a22)
}

通过interface中定义的方法可以约束支持类型,对于无实现对应函数的struct是编译报错;

(3)多参数约束

package main

import (
    "fmt"
)

func print[T1, T2 any](a T1, b T2) {
    fmt.Println("a: ", a, ", b: ", b)
}

func main() {
    print([]string{"xxxx", "nnnn"}, []int{1, 2})
}

泛型可以支持多参数类型,用any可以传入任意类型,但如果使用某些类型方法或者变量,需要实现才行;

泛型函数

package main

import (
    "fmt"
)

type comparable interface {
    type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, float32, float64
}

func max[T comparable](a []T) T {
    m := a[0]
    for _, v := range a {
        if m < v {
            m = v
        }
    }
    return m
}

func min[T comparable](a []T) T {
    m := a[0]
    for _, v := range a {
        if m > v {
            m = v
        }
    }
    return m
}

func main() {
    vi := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    result := max(vi)
    fmt.Println(result)

    vi1 := []float64{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}
    result1 := min(vi1)
    fmt.Println(result1)
}

以上是一个泛型函数的示例,实现c++的max_element和min_element功能,可以传入声明的类型;

泛型结构体

package main

import (
    "fmt"
)

type List[T any] struct {
    next *List[T]
    val  T
}

func (l *List[T]) Add(x T) {
    l.val = x 
    // pass
}

func (l *List[T]) Print() {
    fmt.Println("l->val: ", l.val)
}

func main() {
    var l1 List[int]
    l1.Add(1)
    l1.Print()

    var l2 List[string]
    l2.Add("XXXXXXXXXXXXXX")
    l2.Print()
}

以上是一个泛型结构体的示例,存储任何类型的List;

泛型指针

package main

import (
    "fmt"
    "strconv"
)

type Setter interface {
    Set(string)
}

func FromStrings[T Setter](s []string) []T {
    result := make([]T, len(s))
    for i, v := range s {
        result[i].Set(v)
    }
    return result
}

type Settable int

func (p *Settable) Set(s string) {
    i, _ := strconv.Atoi(s) // real code should not ignore the error
    *p = Settable(i)
}

func main() {
    // 编译报错:Settable does not satisfy Setter: wrong method signature
    nums := FromStrings[Settable]([]string{"1", "2"})
    fmt.Println(nums)

    // 编译正常
    nums := FromStrings[*Settable]([]string{"1", "2"})
    fmt.Println(nums)
}
Settable*SettableSettable*Settableresult := make([]T, len(s))
package main

import (
    "fmt"
    "strconv"
)

type Setter[B any] interface {
    Set(string)
    type *B
}

func FromStrings[T any, T1 Setter[T]](s []string) []T {
    result := make([]T, len(s))
    for i, v := range s {
        p := T1(&result[i])
        p.Set(v)
    }
    return result
}

type Settable int

func (p *Settable) Set(s string) {
    i, _ := strconv.Atoi(s) // real code should not ignore the error
    *p = Settable(i)
}

func main() {
    nums := FromStrings[Settable, *Settable]([]string{"1", "2"})
    fmt.Println(nums)
}

Channel泛型

package main

import (
    "fmt"
    "time"
)

func reduceFunc[T any](c <-chan T) {
    for range c {
        fmt.Println("v: ", <-c)
    }
}

func main() {
    c1 := make(chan int, 1)
    go reduceFunc(c1)
    c1 <- 1

    c2 := make(chan string, 1)
    go reduceFunc(c2)
    c2 <- "XXXXXXXXXXXXXx"

    time.Sleep(1000)
}

泛型推导

(1)函数推导

package main

import (
    "fmt"
    "reflect"
)

type Addable interface {
    type int, int8, int16, int32, int64,
        uint, uint8, uint16, uint32, uint64, uintptr,
        float32, float64, complex64, complex128
}

func generator[T Addable](a T, v T) func() T {
    return func() T {
        fmt.Println("function T: ", reflect.TypeOf(a))
        r := a
        a = a + v
        return r
    }
}

func main() {
    f1 := generator(0, 1)
    fmt.Println(f1())
    f2 := generator(2.0, 1.0)
    fmt.Println(f2())
}

(2)slice推导

package main

import (
    "fmt"
)

func filterFunc[T any](a []T, f func(T) bool) []T {
    var n []T
    for _, e := range a {
        if f(e) {
            n = append(n, e)
        }
    }
    return n
}

func main() {
    vi := []int{1,2,3,4,5,6}
    vi = filterFunc(vi, func(v int) bool {
        return v < 4
    })
    fmt.Println(vi)
}

(3)多参数类型推导

package main

import (
    "fmt"
)

func mapFunc[T any, M any](a []T, f func(T) M) []M {
    n := make([]M, len(a), cap(a))
    for i, e := range a {
        n[i] = f(e)
    }
    return n
}

func main() {
    vi := []int{1,2,3,4,5,6}
    vs := mapFunc(vi, func(v int) string {
        return "<" + fmt.Sprint(v) + ">"
    })
    fmt.Println(vs)
}

(4)递归推导

package main

import (
    "fmt"
)

func permInternal[T any](a []T, f func([]T), i int) {
    if i > len(a) {
        f(a)
        return
    }
    permInternal(a, f, i+1)
    for j := i + 1; j < len(a); j++ {
        a[i], a[j] = a[j], a[i]
        permInternal(a, f, i+1)
        a[i], a[j] = a[j], a[i]
    }
}

func perm[T any](a []T, f func([]T)) {
    permInternal(a, f, 0)
}

func main() {
    vi := []int{1, 2, 3}
    perm(vi, func(a []int) {
        fmt.Println(a)
    })
}

类型转换

package main

import (
    "fmt"
    "reflect"
)

type integer interface {
    type int, int8, int16, int32, int64,
        uint, uint8, uint16, uint32, uint64, uintptr, float64
}

func Convert[To, From integer](from From) To {
    to := To(from)
    if From(to) != from {
        panic("conversion out of range")
    }
    return to
}

func main() {
    r := Convert[int, float64](1.0)
    fmt.Println("1.0 -> r: ", r, ", type: ", reflect.TypeOf(r))
}

注意

(1) 不明确的转换非法

package main

import (
    "fmt"
)

// 编译报错:cannot convert x (variable of type parameter type T2) to T1
func Copy1[T1, T2 any](dst []T1, src []T2) int {
    for i, x := range src {
        if i > len(dst) {
            return i
        }
        dst[i] = T1(x) // INVALID
    }
    return len(src)
}

// 编译正常
func Copy2[T1, T2 any](dst []T1, src []T2) int {
    for i, x := range src {
        if i > len(dst) {
            return i
        }
        dst[i] = (interface{})(x).(T1)
    }
    return len(src)
}

func main() {
    fmt.Println("Copy1: ", Copy1([]int{1,2}, []float64{2.0, 3.0}))
    fmt.Println("Copy2: ", Copy2([]int{1,2}, []float64{2.0, 3.0}))
}

(2) 重命名

type Vector[T any] {
    a T
}

type VectorInt = Vector[int]
type any = interface{}