1
头图

原文链接:源码剖析sync.WaitGroup(文末思考题你能解释一下吗?)

前言

哈喽,大家好,我是asong,这是我并发编程系列的第三篇文章,上一篇我们一起分析了sync.once的使用与实现,今天我们一起来看一看sync.WaitGroup的使用与实现.

快过年了,这是年前最后一篇推文了,待我积累一下,年后加大力度写干货,在这里先预祝大家新春快乐,身体健康,万事如意!

什么是sync.WaitGroup

官方文档对sync.WatiGroup的描述是:一个waitGroup对象可以等待一组协程结束,也就等待一组goroutine返回。有了sync.Waitgroup我们可以将原本顺序执行的代码在多个Goroutine中并发执行,加快程序处理的速度。其实他与java中的CountdownLatch,阻塞等待所有任务完成之后再继续执行。我们来看官网给的一个例子,这个例子使用waitGroup阻塞主进程,并发获取多个URL,直到完成所有获取:

package main

import (
    "sync"
)

type httpPkg struct{}

func (httpPkg) Get(url string) {}

var http httpPkg

func main() {
    var wg sync.WaitGroup
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/",
    }
    for _, url := range urls {
        // Increment the WaitGroup counter.
        wg.Add(1)
        // Launch a goroutine to fetch the URL.
        go func(url string) {
            // Decrement the counter when the goroutine completes.
            defer wg.Done()
            // Fetch the URL.
            http.Get(url)
        }(url)
    }
    // Wait for all HTTP fetches to complete.
    wg.Wait()
}

首先我们需要声明一个sync.WaitGroup对象,在主gorourine调用Add()方法设置要等待的goroutine数量,每一个Goroutine在运行结束时要调用Done()方法,同时使用Wait()方法进行阻塞直到所有的goroutine完成。

为什么要用sync.waitGroup

我们在日常开发中为了提高接口响应时间,有一些场景需要在多个goroutine中做一些互不影响的业务,这样可以节省不少时间,但是需要协调多个goroutine,没有sync.WaitGroup的时候,我们可以使用通道来解决这个问题,我们把主Goroutine当成铜锣扛把子a song,把每一个Goroutine当成一个马仔,asong管理这些马仔,让这些马仔去收保护费,我今天派10个马仔去收保护费,每一个马仔收好了保护费就在账本上打一个✅,当所有马仔都收好了保护费,账本上就被打满了✅,活全被干完了,很出色,然后酒吧走起,浪一浪,全场的消费松公子买单,写成代码可以这样表示:


func exampleImplWaitGroup()  {
    done := make(chan struct{}) // 收10份保护费
    count := 10 // 10个马仔
    for i:=0;i < count;i++{
        go func(i int) {
            defer func() {
                done <- struct {}{}
            }()
            fmt.Printf("马仔%d号收保护费\n",i)
        }(i)
    }
    for i:=0;i< count;i++{
        <- done
        fmt.Printf("马仔%d号已经收完保护费\n",i)
    }
    fmt.Println("所有马仔已经干完活了,开始酒吧消费~")
}

虽然这样可以实现,但是我们每次使用都要保证主Goroutine最后从通道接收的次数需要与之前其他的Goroutine发送元素的次数相同,实现起来不够优雅,在这种场景下我们就可以选用sync.WaitGroup来帮助我们实现同步。

源码剖析

前面我们已经知道sync.waitGroup的基本使用了,接下来我们就一起看看他是怎样实现的~,只有知其所以然,才能写出更健壮的代码。

Go version: 1.15.3

首先我们看一下sync.WaitGroup的结构:

// A WaitGroup must not be copied after first use.
type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 32-bit
    // compilers do not ensure it. So we allocate 12 bytes and then use
    // the aligned 8 bytes in them as state, and the other 4 as storage
    // for the sema.
    state1 [3]uint32
}

总共就有两个字段,nocopy是为了保证该结构不会被进行拷贝,这是一种保护机制,会在后面进行介绍;state1主要是存储着状态和信号量,这里使用的8字节对齐处理的方式很有意思,我先来一起看看这种处理。

state1状态和信号量处理

state1这里总共被分配了12个字节,这里被设计了三种状态:

  • 其中对齐的8个字节作为状态,高32位为计数的数量,低32位为等待的goroutine数量
  • 其中的4个字节作为信号量存储

提供了(wg *WaitGroup) state() (statep *uint64, semap *uint32)帮助我们从state1字段中取出他的状态和信号量,为什么要这样设计呢?

我们在分析atomicGo看源码必会知识之unsafe包有说到过,64位原子操作需要64位对齐,但是32位编译器不能保证这一点,所以为了保证waitGroup32位平台上使用的话,就必须保证在任何时候,64位操作不会报错。所以也就不能分成两个字段来写,考虑到字段顺序不同、平台不同,内存对齐也就不同。因此这里采用动态识别当前我们操作的64位数到底是不是在8字节对齐的位置上面,我们来分析一下state方法:

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

当数组的首地址是处于一个8字节对齐的位置上时,那么就将这个数组的前8个字节作为64位值使用表示状态,后4个字节作为32位值表示信号量(semaphore)。同理如果首地址没有处于8字节对齐的位置上时,那么就将前4个字节作为semaphore,后8个字节作为64位数值。画个图表示一下:

image

Add()Done()方法

sync.WaitGroup提供了Add()方法增加一个计数器,Done()方法减掉一个计数,Done()方法实现比较简单,内部调用的Add()方法实现的计数器减一操作,也就是增减逻辑都在Add()方法中,所以我们重点看一下Add()是如何实现的:

func (wg *WaitGroup) Add(delta int) {
  // 获取状态(Goroutine Counter 和 Waiter Counter)和信号量
    statep, semap := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        if delta < 0 {
            // Synchronize decrements with Wait.
            race.ReleaseMerge(unsafe.Pointer(wg))
        }
        race.Disable()
        defer race.Enable()
    }
  // 原子操作,goroutine counter累加delta
    state := atomic.AddUint64(statep, uint64(delta)<<32)
  // 获取当前goroutine counter的值(高32位)
    v := int32(state >> 32)
  // 获取当前waiter counter的值(低32位)
    w := uint32(state)
    if race.Enabled && delta > 0 && v == int32(delta) {
        // The first increment must be synchronized with Wait.
        // Need to model this as a read, because there can be
        // several concurrent wg.counter transitions from 0.
        race.Read(unsafe.Pointer(semap))
    }
  // Goroutine counter是不允许为负数的,否则会发生panic
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
  // 当wait的Goroutine不为0时,累加后的counter值和delta相等,说明Add()和Wait()同时调用了,所以发生panic,因为正确的做法是先Add()后Wait(),也就是已经调用了wait()就不允许再添加任务了
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
  // 正常`Add()`方法后,`goroutine Counter`计数器大于0或者`waiter Counter`计数器等于0时,不需要释放信号量
    if v > 0 || w == 0 {
        return
    }
    // 能走到这里说明当前Goroutine Counter计数器为0,Waiter Counter计数器大于0, 到这里数据也就是允许发生变动了,如果发生变动了,则出发panic
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 重置状态,并发出信号量告诉wait所有任务已经完成
    *statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

上面的代码有一部分是race静态检测,下面的分析会省略这一部分,因为它并不是本文的重点。

注释我都添加到对应的代码行上了,你是否都看懂了,没看懂不要紧,因为Add()是与Wait()方法一块使用的,所以有些逻辑与wait()里的逻辑是相互照应的,所以当我们看完wait()方法的实现在总结一下你们就明白了。

Wait()方法

sync.Wait()方法会阻塞主Goroutine直到WaitGroup计数器变为0。我们一起来看一下Wait()方法的源码:

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
  // 获取状态(Goroutine Counter 和 Waiter Counter)和信号量
    statep, semap := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        race.Disable()
    }
    for {
    // 使用原子操作读取state,是为了保证Add中的写入操作已经完成
        state := atomic.LoadUint64(statep)
    // 获取当前goroutine counter的值(高32位)
        v := int32(state >> 32)
     // 获取当前waiter counter的值(低32位)
        w := uint32(state)
    // 如果没有任务,或者任务已经在调用`wait`方法前已经执行完成了,就不用阻塞了
        if v == 0 {
            // Counter is 0, no need to wait.
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
        // 使用CAS操作对`waiter Counter`计数器进行+1操作,外面有for循环保证这里可以进行重试操作
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            if race.Enabled && w == 0 {
                // Wait must be synchronized with the first Add.
                // Need to model this is as a write to race with the read in Add.
                // As a consequence, can do the write only for the first waiter,
                // otherwise concurrent Waits will race with each other.
                race.Write(unsafe.Pointer(semap))
            }
      // 在这里获取信号量,使线程进入睡眠状态,与Add方法中最后的增加信号量相对应,也就是当最后一个任务调用Done方法
      // 后会调用Add方法对goroutine counter的值减到0,就会走到最后的增加信号量
            runtime_Semacquire(semap)
      // 在Add方法中增加信号量时已经将statep的值设为0了,如果这里不是0,说明在wait之后又调用了Add方法,使用时机不对,触发panic
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
    }
}

源码总结

分了源码,我们可以总结如下:

  • Add方法与wait方法不可以并发同时调用,Add方法要在wait方法之前调用.
  • Add()设置的值必须与实际等待的goroutine个数一致,否则会panic.
  • 调用了wait方法后,必须要在wait方法返回以后才能再次重新使用waitGroup,也就是Wait没有返回之前不要在调用Add方法,否则会发生Panic.
  • Done 只是对Add 方法的简单封装,我们可以向 Add方法传入任意负数(需要保证计数器非负)快速将计数器归零以唤醒等待的 Goroutine.
  • waitGroup对象只能有一份,不可以拷贝给其他变量,否则会造成意想不到的Bug.

no copy机制

在前文看waitGroup结构时,有一个nocopy字段,为什么要有nocopy呢?我们先看这样一个例子:

type User struct {
    Name string
    Info *Info
}

type Info struct {
    Age int
    Number int
}


func main()  {
    u := User{
        Name: "asong",
        Info: &Info{
            Age: 10,
            Number: 24,
        },
    }
    u1 := u
    u1.Name = "Golang梦工厂"
    u1.Info.Age = 30
    fmt.Println(u.Info.Age,u.Name)
    fmt.Println(u1.Info.Age,u1.Name)
}
// 运行结果
30 asong
30 Golang梦工厂

结构体User中有两个字段NameInfoNameString类型,Info是指向结构体Info的指针类型,我们首先声明了一个u变量,对他进行复制拷贝得到变量u1,在u1中对两个字段进行改变,可以看到Info字段发生了更改,而Name就没发生更改,这就引发了安全问题,如果结构体对象包含指针字段,当该对象被拷贝时,会使得两个对象中的指针字段变得不再安全。

Go语言中提供了两种copy检查,一种是在运行时进行检查,一种是通过静态检查。不过运行检查是比较影响程序的执行性能的,Go官方目前只提供了strings.Builder和sync.Cond的runtime拷贝检查机制,对于其他需要nocopy对象类型来说,使用go vet工具来做静态编译检查。运行检查的实现可以通过比较所属对象是否发生变更就可以判断,而静态检查是提供了一个nocopy对象,只要是该对象或对象中存在nocopy字段,他就实现了sync.Locker接口, 它拥有Lock()和Unlock()方法,之后,可以通过go vet功能,来检查代码中该对象是否有被copy。

踩坑事项

在文章的最后总结一下使用waitGroup易错的知识点,防止大家再次犯错。

  1. waitGroup中计数器的值是不能小于0的,源码中我们就可以看到,一旦小于0就会引发panic。
  2. 一定要住注意调用Add方法与Wait方法的顺序,不可并发同时调用这两个方法,否则就会引发panic,同时在调用了wait方法在其没有释放前不要再次调用Add方法,这样也会引发panicwaitGroup是可以复用的,但是需要保证其计数周期的完整性。
  3. WaitGroup对象不是一个引用类型,通过函数传值的时候需要使用地址,因为Go语言只有值传递,传递WaitGroup是值的话,就会导致会发生panic,看这样一个例子:
func main()  {
    wg := sync.WaitGroup{}
    wg.Add(1)
    doDeadLock(wg)
    wg.Wait()
}
func doDeadLock(wg sync.WaitGroup)  {
    defer wg.Done()
    fmt.Println("do something")
}
//运行结果:panic: sync: negative WaitGroup counter

发生这个问题的原因就是在doDeadLock()方法中wg是一个新对象,直接调用Done方法,计数器就会出现负数,所以引发panic,为了安全起见,对于这种传结构体的场景一般建议都传指针就好了,基本可以避免一些问题。

  1. Add()设置的值必须与实际等待的goroutine个数一致,否则会panic,很重要的一点,也是很容易出错的地方。

思考题

最后给大家出一个思考题,下面这段代码会不会发生panic

func main() {
    wg := sync.WaitGroup{}
    wg.Add(100)
    for i := 0; i < 100; i++ {
        go func() {
            defer wg.Done()
            fmt.Println(i)
        }()
    }
    wg.Wait()
}

结尾

在最后,祝大家新年快乐,心想事成,万事如意~~~

好啦,这篇文章就到这里啦,素质三连(分享、点赞、在看)都是笔者持续创作更多优质内容的动力!

创建了一个Golang学习交流群,欢迎各位大佬们踊跃入群,我们一起学习交流。入群方式:加我vx拉你入群,或者公众号获取入群二维码

结尾给大家发一个小福利吧,最近我在看[微服务架构设计模式]这一本书,讲的很好,自己也收集了一本PDF,有需要的小伙可以到自行下载。获取方式:关注公众号:[Golang梦工厂],后台回复:[微服务],即可获取。

我翻译了一份GIN中文文档,会定期进行维护,有需要的小伙伴后台回复[gin]即可下载。

翻译了一份Machinery中文文档,会定期进行维护,有需要的小伙伴们后台回复[machinery]即可获取。

我是asong,一名普普通通的程序猿,让我们一起慢慢变强吧。欢迎各位的关注,我们下期见~~~

推荐往期文章:


asong
602 声望904 粉丝