WaitGroup源码剖析

  1. 结构定义
  2. 增加/减少计数
  3. Wait 等待计数归零
  4. 结尾

sync.waitgroup 是一个常用的用来协调等待协程结束的组件。

比如在下面的这段代码上,我们通过 waitgroup 可以让 main 等待协程退出后再退出:

func main() {

    wg := sync.WaitGroup{}

    wg.Add(1)
    go func() {
        time.Sleep(2 * time.Second)
        fmt.Println("goroutine exit.")
        wg.Done()
    }()

    wg.Wait()
    fmt.Println("main exit.")
}

那么,waitgroup 是怎么实现的呢,我们来详解一下。

结构定义

waitgroup 的结构定义非常简单,但是涉及到了几个重要的知识点。

type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}
  • noCopy 表示了可以做静态检查,不允许拷贝实例使用;
  • state1 这里面包含了 3种状态:
    • 64位值:高 32 位 计数的数量;低 32位 等待 goroutine 的数量
    • 32位值:还有 Semaphore;

根据 atomic 官方文档 最后一段中,对于 64位数在32位平台上的操作时,强制要求使用 8 字节对齐,否则就会出现问题。而如果保证 waitgroup 在32位平台上使用的话,就必须保证在任何时候,64位的操作不会出错。

所以,并不能直接在这里将变量申明成下面的样子,原因是因为我们并不能确定 counter 是不是在 8 字节对齐的位置上(即便互换了 sema 和 counter 也不行)。

type WaitGroup struct {
    noCopy  noCopy
    counter uint64
    sema    uint32
}

这就需要有一个办法,来动态的识别当前我们操作的64位数,到底是不是在 8 字节对齐的位置上面。WaitGroup 通过申明一个 12 字节的数组,并实现了一个内部方法 state() 来保证这一点:

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    // 当数组的首地址是在一个 8 字节对齐的位置上时
    // 那么就将数组中的前 8 个字节作为64位值使用
    // 后 4 个字节作为 semaphore
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        // 如果首地址没有在 8 字节对齐的位置上时
        // 那么,就将前 4 个字节作为 semaphore
        // 后 8 个字节作为 64位计数值
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

waitgroup

增加/减少计数

WaitGroup 中,使用 Add() 增加一个计数,当需要减掉一个计数时,使用 Done() 。但实际上 Done() 调用的还是 Add(),只不过增加的是 -1:

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

因为,增减的逻辑都放在了 Add() 中,而调用者可以随意传入正负数值到函数中,所以需要考虑两种异常情况:

  1. 计数为负数;
  2. 当有等待者等待的时候,并发调用 Add();

下面的代码详解,去掉了 race 检查的部分:

func (wg *WaitGroup) Add(delta int) {
    // 获取到计数和 semaphore
    statep, semap := wg.state()

    // 给高32位增加 delta 的计数
    state := atomic.AddUint64(statep, uint64(delta)<<32)

    // 获取到计数的值
    v := int32(state >> 32)
    
    // 获取 semaphore 的值
    w := uint32(state)

    // 计数小于 0 ,异常 panic
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    
    // 有等待者的时候,并发调用了 Add, 异常 panic
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    
    // 计数大于 0 正常返回
    // 没有等待者,也不需要后续操作
    if v > 0 || w == 0 {
        return
    }

    // ------- 最后的情况 计数 == 0  --------
    
    // 有等待者的时候,并发调用了 Add, 异常 panic
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // Reset waiters count to 0.
    // 重置 waiter 等于 0
    *statep = 0
    for ; w != 0; w-- {
        // 按顺序通知等待的 goroutine
        runtime_Semrelease(semap, false, 0)
    }
}

Wait 等待计数归零

func (wg *WaitGroup) Wait() {
    // 获取状态
    statep, semap := wg.state()

    for {  // 因为有 CAS,所以要放到循环中,保证成功
        // load 状态值
        state := atomic.LoadUint64(statep)
        
        // 获取计数
        v := int32(state >> 32)
        
        // 获取等待者数量
        w := uint32(state)
        
        // 计数为0 直接返回
        if v == 0 {
            return
        }
        // 增加等待者的数量
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            // 增加成功,等待信号量
            runtime_Semacquire(semap)
            
            // 通知计数归零了,如果状态值不为零,那么认为是有问题的,详见 Add()
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            
            return
        }
    }
}

结尾

整个 WaitGroup 中,实现相对其他库比较简单。但是对于 8 字节对齐的处理很有意思,值得在开发中借鉴。


署名-非商业性使用-相同方式共享 4.0