一、什么是sync.WaitGroup
官方文档对其的描述是:WaitGroup等待一组goroutine的任务完成。主goroutine调用添加以设置要等待的goroutine的数量。然后,每个goroutine都会运行并在完成后调用Done。同时,可以使用Wait来阻塞,直到所有goroutine完成。我们来看官网给的一个例子:
package main
import (
"sync"
)
type httpPkg struct{}
func (httpPkg) Get(url string) {}
var http httpPkg
func main() {
var wg sync.WaitGroup
var urls = []string{
"http://www.golang.org/",
"http://www.google.com/",
"http://www.somestupidname.com/",
}
for _, url := range urls {
// 增加waitGroup计数
wg.Add(1)
// 启动goroutine获取url
go func(url string) {
//等获取url的goroutine完成,将waitGroup计数减1
defer wg.Done()
// 获取url
http.Get(url)
}(url)
}
// 等待所有goroutine完成
wg.Wait()
}
二、源码分析
WaitGroup结构体onCopy机制
Go中没有原生的禁止拷贝的方式,所以如果有的结构体,你希望使用者无法拷贝,只能指针传递保证全局唯一的话,可以这么干,定义一个结构体叫noCopy,要实现sync.Locker 这个接口。
type noCopy struct{}
// nocopy 只有在使用 go vet 检查时才能显示错误,编译正常
func (*noCopy) Lock() {}
func (*noCopy) UnLock() {}
state1字段处理
总共分配了12个字节,在这里被设计成三种状态。其中对齐的8个字节作为状态位(state),高32位为记录计数器的数量,低32位为等待goroutine的数量值。其余的4个字节作为信号量存储(sema)。由于操作系统分为32位和64位,64位的原子操作需要64位对齐,但是32位编译器保证不了,于是这里就采用了动态识别当前我们操作的64位数到底是不是在8字节对齐的位置上面。具体见源码state方法
// 得到state1的状态位和信号量
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
// 如果地址是64bit对齐的,数组前两个元素做state,后一个元素做信号量
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
// 如果地址是32bit对齐的,数组后两个元素用来做state,它可以用来做64bit的原子操作,第一个元素32bit用来做信号量
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
Add方法实现
主要操作的state1字段中计数值部分,计数器部分的逻辑主要是通过state(),在上面有提及。每次调用Add方法就会增加相应数量的计数器。如果计数器为零,则释放等待时阻塞的所有goroutine。如果计数器变为负数,请添加恐慌。如果计数器值大于0,说明此时还有任务没有完成,那么调用者就变成等待者,需要加入wait队列,并且阻塞自己。参数可正可负数。
func (wg *WaitGroup) Add(delta int) {
//获取state1中的状态位和信号量位
statep, semap := wg.state()
//用来goroutine的竞争检测,可忽略。
if race.Enabled {
_ = *statep
if delta < 0 {
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
// uint64(delta)<<32 将delta左移32
// 因为高32位表示计数器,所以delta左移32位,
// 增加到计数位。
state := atomic.AddUint64(statep, uint64(delta)<<32)
// 当前计数器的值
v := int32(state >> 32)
// 阻塞的wait goroutine数量
w := uint32(state)
if race.Enabled && delta > 0 && v == int32(delta) {
race.Read(unsafe.Pointer(semap))
}
// 计数器的值<0,panic
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// 当wait goroutine数量不为0时,累加后的counter值和delta相等,
// 说明Add()和Wait()同时调用了,所以发生panic,
// 因为正确的做法是先Add()后Wait(),
// 也就是已经调用了wait()就不允许再添加任务了
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// add调用结束
if v > 0 || w == 0 {
return
}
// 能走到这里说明当前Goroutine Counter计数器为0,
// Waiter Counter计数器大于0,
// 到这里数据也就是允许发生变动了,如果发生变动了,则出发panic
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 所有的状态位清0
*statep = 0
for ; w != 0; w-- {
// 首先让信号量加一,然后检查是否有正在等待的Goroutine,如果没有,直接返回;
// 如果有,调用goready函数唤醒一个Goroutine。
runtime_Semrelease(semap, false, 0)
}
}
Done方法实现
内部调用了Add(-1)的操作,具体看Add方法实现部分
//Done方法其实就是Add(-1)
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
Wait方法实现
阻塞主goroutine直到WaitGroup计数器变为0。
// 等待并阻塞,直到WaitGroup计数器为0
func (wg *WaitGroup) Wait() {
// 获取waitgroup状态位和信号量
statep, semap := wg.state()
if race.Enabled {
_ = *statep
race.Disable()
}
for {
// 使用原子操作读取state,是为了保证Add中的写入操作已经完成
state := atomic.LoadUint64(statep)
v := int32(state >> 32) //获取计数器(高32位)
w := uint32(state) //获取wait goroutine数量(低32位)
if v == 0 { // 计数器为0,跳出死循环,不用阻塞
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// 使用CAS操作对`waiter Counter`计数器进行+1操作,
// 外面有for循环保证这里可以进行重试操作
if atomic.CompareAndSwapUint64(statep, state, state+1) {
if race.Enabled && w == 0 {
race.Write(unsafe.Pointer(semap))
}
// 在这里获取信号量,使线程进入睡眠状态,
// 与Add方法中runtime_Semrelease增加信号量相对应,
// 也就是当最后一个任务调用Done方法
// 后会调用Add方法对goroutine counter的值减到0,
// 就会走到最后的增加信号量
runtime_Semacquire(semap)
// 在Add方法中增加信号量时已经将statep的值设为0了,
// 如果这里不是0,说明在wait之后又调用了Add方法,
// 使用时机不对,触发panic
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
三、推荐阅读