Go中看似簡單的WaitGroup源碼設(shè)計,竟然暗含這么多知識?
Go語言提供的協(xié)程goroutine可以讓我們很容易地寫出多線程程序,但是,如何讓這些并發(fā)執(zhí)行的goroutine得到有效地控制,這是我們需要探討的問題。正如小菜刀在《Golang并發(fā)控制簡述》中所述,Go標準庫為我們提供的同步原語中,鎖與原子操作注重控制goroutine之間的數(shù)據(jù)安全,WaitGroup、channel與Context控制的是它們的并發(fā)行為。關(guān)于鎖、原子操作、channel 的實現(xiàn)原理小菜刀均有詳細地解析過。因此本文,我們將重點放在WaitGroup上。
初識WaitGroup
WaitGroup是sync包下的內(nèi)容,用于控制協(xié)程間的同步。WaitGroup使用場景同名字的含義一樣,當我們需要等待一組協(xié)程都執(zhí)行完成以后,才能做后續(xù)的處理時,就可以考慮使用。
1func main() {
2 var wg sync.WaitGroup
3
4 wg.Add(2) //worker number 2
5
6 go func() {
7 // worker 1 do something
8 fmt.Println("goroutine 1 done!")
9 wg.Done()
10 }()
11
12 go func() {
13 // worker 2 do something
14 fmt.Println("goroutine 2 done!")
15 wg.Done()
16 }()
17
18 wg.Wait() // wait all waiter done
19 fmt.Println("all work done!")
20}
21
22// output
23goroutine 2 done!
24goroutine 1 done!
25all work done!
可以看到WaitGroup的使用非常簡單,它提供了三個方法。雖然goroutine之間并不存在類似于父子關(guān)系,但是為了方便理解,本文會將調(diào)用Wait函數(shù)的goroutine稱為主goroutine,調(diào)用Done函數(shù)的goroutine稱呼為子goroutine。
1func (wg *WaitGroup) Add(delta int) // 增加WaitGroup中的子goroutine計數(shù)值
2func (wg *WaitGroup) Done() // 當子goroutine任務(wù)完成,將計數(shù)值減1
3func (wg *WaitGroup) Wait() // 阻塞調(diào)用此方法的goroutine,直到計數(shù)值為0
4
那么它是如何實現(xiàn)的呢?在源碼src/sync/waitgroup.go中,我們可以看到它的核心源碼只有100行不到,十分地精練,非常值得學(xué)習(xí)。
前置知識
代碼少,不代表就實現(xiàn)簡單,易于理解。相反,如果讀者沒有下述中的前置知識,想要真正理解WaitGroup的實現(xiàn)是會比較費力的。在解析源碼之前,我們先過一遍這些知識(如果你都已經(jīng)掌握,那就可以直接跳到后文的源碼解析部分)。
信號量
在學(xué)習(xí)操作系統(tǒng)時,我們知道信號量是一種保護共享資源的機制,用于解決多線程同步問題。信號量s是具有非負整數(shù)值的全局變量,只能由兩種特殊的操作來處理,這兩種操作稱為P和V。
P(s):如果s是非零的,那么P將s減1,并且立即返回。如果s為零,那么就掛起這個線程,直到s變?yōu)榉橇?,等到另一個執(zhí)行V(s)操作的線程喚醒該線程。在喚醒之后,P操作將s減1,并將控制返回給調(diào)用者。V(s):V操作將s加1。如果有任何線程阻塞在P操作等待s變?yōu)榉橇?,那?/span>V操作會喚醒這些線程中的一個,然后該線程將s減1,完成它的P操作。
在Go的底層信號量函數(shù)中
runtime_Semacquire(s *uint32)函數(shù)會阻塞goroutine直到信號量s的值大于0,然后原子性地減這個值,即P操作。runtime_Semrelease(s *uint32, lifo bool, skipframes int)函數(shù)原子性增加信號量的值,然后通知被runtime_Semacquire阻塞的goroutine,即V操作。
這兩個信號量函數(shù)不止在WaitGroup中會用上,在《Go精妙的互斥鎖設(shè)計》一文中,我們發(fā)現(xiàn)Go在設(shè)計互斥鎖的時候也少不了信號量的參與。
內(nèi)存對齊
對于以下的結(jié)構(gòu)體,你能回答出它占用的內(nèi)存是多少嗎
1type Ins struct {
2 x bool // 1個字節(jié)
3 y int32 // 4個字節(jié)
4 z byte // 1個字節(jié)
5}
6
7func main() {
8 ins := Ins{}
9 fmt.Printf("ins size: %d, align: %d\n", unsafe.Sizeof(ins), unsafe.Alignof(ins))
10}
11
12//output
13ins size: 12, align: 4
按照結(jié)構(gòu)體中字段的大小而言,ins對象占用內(nèi)存應(yīng)該是 1+4+1=6 個字節(jié),但是實際上確實12個字節(jié),這就是內(nèi)存對齊所致。從《CPU緩存體系對Go程序的影響》一文中,我們知道CPU的內(nèi)存讀取并不是一個字節(jié)一個字節(jié)地讀取的,而是一塊一塊的。因此,在類型的值在內(nèi)存中對齊的情況下,計算機的加載或者寫入會很高效。
在聚合類型(結(jié)構(gòu)體或數(shù)組)的內(nèi)存所占長度或許會比它元素所占內(nèi)存之和更大。編譯器會添加未使用的內(nèi)存地址用于填充內(nèi)存空隙,以確保連續(xù)的成員或元素相當于結(jié)構(gòu)體或數(shù)組的起始地址是對齊的。

因此,在我們設(shè)計結(jié)構(gòu)體時,當結(jié)構(gòu)體成員的類型不同時,將相同類型的成員定義在相鄰位置可以更節(jié)省內(nèi)存空間。
原子操作CAS
CAS是原子操作的一種,可用于在多線程編程中實現(xiàn)不被打斷的數(shù)據(jù)交換操作,從而避免多線程同時改寫某一數(shù)據(jù)時由于執(zhí)行順序不確定性以及中斷的不可預(yù)知性產(chǎn)生的數(shù)據(jù)不一致問題。該操作通過將內(nèi)存中的值與指定數(shù)據(jù)進行比較,當數(shù)值一樣時將內(nèi)存中的數(shù)據(jù)替換為新的值。關(guān)于Go中原子操作的底層實現(xiàn),小菜刀在《同步原語的基石》一文中有詳細介紹。
移位運算 >> 與 <<
在之前關(guān)于鎖的文章《Go精妙的互斥鎖設(shè)計》與《Go更細粒度的讀寫鎖設(shè)計中》,我們能看到大量的位運算操作。靈活的位運算,能讓一個普通的數(shù)字變化出豐富的含義,這里僅介紹下文中會用到的移位運算。
對于左移位運算 <<,按二進制形式將所有的數(shù)字向左移動對應(yīng)的位數(shù),高位舍棄,低位的空位補零。在數(shù)字沒有溢出的前提下,左移一位相當于乘以2的1次方,左移n位就相當于乘以2的n次方。
對于右移位運算 >>,按二進制形式把所有的數(shù)字向右移動對應(yīng)位數(shù),低位移出,高位的空位補符號位。右移一位相當于除2,右移n位相當于除以2的n次方。這里是取商,余數(shù)就不要了。
移位運算也可以有很巧妙的操作,后文中我們會看到移位運算的高級運用。
unsafa.Pointer指針與uintptr
Go中的指針可以分為三類:1.普通類型指針*T,例如*int;2. unsafe.Pointer指針;3. uintptr。
*T:普通的指針類型,用于傳遞對象地址,不能進行指針計算。
unsafe.Pointer指針:通用型指針,任何一個普通類型的指針*T都可以轉(zhuǎn)換為unsafe.Pointer指針,而且unsafe.Pointer類型的指針還可以轉(zhuǎn)換回普通指針,并且它可以不用和原來的指針類型*T相同。但是它不能進行指針計算,不能讀取內(nèi)存中的值(必須通過轉(zhuǎn)換為某一具體類型的普通指針才行)。
uintptr:準確來講,uintptr并不是指針,它是一個大小并不明確的無符號整型。unsafe.Pointer類型可以與uinptr相互轉(zhuǎn)換,由于uinptr類型保存了指針所指向地址的數(shù)值,因此可以通過該數(shù)值進行指針運算。GC時,不會將uintptr當做指針,uintptr類型目標會被回收。

unsafe.Pointer 是橋梁,可以讓任意類型的普通指針實現(xiàn)相互轉(zhuǎn)換,也可以將任意類型的指針轉(zhuǎn)換為 uintptr 進行指針運算。但是,unsafe.Pointer和任意類型指針的轉(zhuǎn)換可以讓我們將任意值寫入內(nèi)存中,這會破壞Go原有的類型系統(tǒng),同時由于不是所有的數(shù)值都是合法的內(nèi)存地址,從uintptr到unsafe.Pointer的轉(zhuǎn)換同樣會破壞類型系統(tǒng)。因此,既然Go將該包定義為unsafe,那就不應(yīng)該隨意使用。
源碼解析
本文基于Go源碼1.15.7版本
結(jié)構(gòu)體
sync.WaitGroup的結(jié)構(gòu)體定義如下,它包括了一個 noCopy 的輔助字段,和一個具有復(fù)合意義的state1字段。
1type WaitGroup struct {
2 noCopy noCopy
3
4 // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
5 // 64-bit atomic operations require 64-bit alignment, but 32-bit
6 // compilers do not ensure it. So we allocate 12 bytes and then use
7 // the aligned 8 bytes in them as state, and the other 4 as storage
8 // for the sema.
9 state1 [3]uint32
10}
11
12// state returns pointers to the state and sema fields stored within wg.state1.
13func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
14 // 64位編譯器地址能被8整除,由此可判斷是否為64位對齊
15 if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
16 return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
17 } else {
18 return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
19 }
20}
其中,noCopy字段是空結(jié)構(gòu)體,它并不會占用內(nèi)存,編譯器也不會對其進行字節(jié)填充。它主要是為了通過go vet工具來做靜態(tài)編譯檢查,防止開發(fā)者在使用WaitGroup過程中對其進行了復(fù)制,從而導(dǎo)致的安全隱患。關(guān)于這部分內(nèi)容,可以閱讀《no copy機制》詳細了解。
state1字段是一個長度為3的uint32數(shù)組。它用于表示三部分內(nèi)容:1. 通過Add()設(shè)置的子goroutine的計數(shù)值counter;2. 通過Wait()陷入阻塞的waiter數(shù);3. 信號量semap。
由于后續(xù)是對 uint64 類型的statep進行操作,而64位整數(shù)的原子操作需要64位對齊,32位的編譯器并不能保證這一點。因此,在64位與32位的環(huán)境下,state1字段的組成含義是不相同的。

需要注意的是,當我們初始化一個WaitGroup對象時,其counter值、waiter值、semap值均為0。
Add函數(shù)
Add()函數(shù)的入?yún)⑹且粋€整型,它可正可負,是對counter數(shù)值的更改。如果counter數(shù)值變?yōu)?,那么所有阻塞在Wait()函數(shù)的waiter將會被喚醒;如果counter數(shù)值為負值,將引起panic。
我們將競態(tài)檢測部分的代碼去掉,Add()函數(shù)的實現(xiàn)源碼如下
1func (wg *WaitGroup) Add(delta int) {
2 // 獲取包含counter與waiter的復(fù)合狀態(tài)statep,表示信號量值的semap
3 statep, semap := wg.state()
4 state := atomic.AddUint64(statep, uint64(delta)<<32)
5 v := int32(state >> 32)
6 w := uint32(state)
7
8 if v < 0 {
9 panic("sync: negative WaitGroup counter")
10 }
11
12 if w != 0 && delta > 0 && v == int32(delta) {
13 panic("sync: WaitGroup misuse: Add called concurrently with Wait")
14 }
15
16 if v > 0 || w == 0 {
17 return
18 }
19
20 if *statep != state {
21 panic("sync: WaitGroup misuse: Add called concurrently with Wait")
22 }
23
24 // 如果執(zhí)行到這,一定是 counter=0,waiter>0
25 // 能執(zhí)行到這,一定是執(zhí)行了Add(-x)的goroutine
26 // 它的執(zhí)行,代表所有子goroutine已經(jīng)完成了任務(wù)
27 // 因此,我們需要將復(fù)合狀態(tài)全部歸0,并釋放掉waiter個數(shù)的信號量
28 *statep = 0
29 for ; w != 0; w-- {
30 // 釋放信號量,執(zhí)行一次就將喚醒一個阻塞的waiter
31 runtime_Semrelease(semap, false, 0)
32 }
33}
代碼非常精簡,我們接下來對關(guān)鍵部分進行剖析。
1 state := atomic.AddUint64(statep, uint64(delta)<<32) // 新增counter數(shù)值delta
2 v := int32(state >> 32) // 獲取counter值
3 w := uint32(state) // 獲取waiter值
此時的statep是一個uint64數(shù)值,如果此時statep中包含的counter數(shù)為2,waiter為1,輸入delta為1,那么這三行代碼的邏輯過程如下圖所示。

在得到當前counter數(shù)v與waiter數(shù)w后,會對它們的值進行判斷,分幾種情況。
1 // 情況1:這是很低級的錯誤,counter值不能為負
2 if v < 0 {
3 panic("sync: negative WaitGroup counter")
4 }
5
6 // 情況2:misuse引起panic
7 // 因為wg其實是可以用復(fù)用的,但是下一次復(fù)用的基礎(chǔ)是需要將所有的狀態(tài)重置為0才可以
8 if w != 0 && delta > 0 && v == int32(delta) {
9 panic("sync: WaitGroup misuse: Add called concurrently with Wait")
10 }
11
12 // 情況3:本次Add操作只負責增加counter值,直接返回即可。
13 // 如果此時counter值大于0,喚醒的操作留給之后的Add調(diào)用者(執(zhí)行Add(negative int))
14 // 如果waiter值為0,代表此時還沒有阻塞的waiter
15 if v > 0 || w == 0 {
16 return
17 }
18
19 // 情況4: misuse引起的panic
20 if *statep != state {
21 panic("sync: WaitGroup misuse: Add called concurrently with Wait")
22 }
關(guān)于 misuse 和 reused 引發(fā) panic 的情況,如果沒有示例錯誤代碼,其實是比較難解釋的。值得高興的是,在Go源碼中給出了錯誤使用示范,這些例子位于src/sync/waitgroup_test.go文件下,想深入了解的讀者可以去看以下三個測試函數(shù)中的示例。
1func TestWaitGroupMisuse(t *testing.T)
2func TestWaitGroupMisuse2(t *testing.T)
3func TestWaitGroupMisuse3(t *testing.T)
4
Done函數(shù)
Done()函數(shù)比較簡單,就是調(diào)用Add(-1)。在實際使用時,當子goroutine任務(wù)完成之后,就應(yīng)該調(diào)用Done()函數(shù)。
1func (wg *WaitGroup) Done() {
2 wg.Add(-1)
3}
Wait函數(shù)
如果WaitGroup中的counter值大于0,那么執(zhí)行Wait()函數(shù)的主goroutine會將waiter值加1,并阻塞等待該值為0,才能繼續(xù)執(zhí)行后續(xù)代碼。
我們將競態(tài)檢測部分的代碼去掉,Wait()函數(shù)的實現(xiàn)源碼如下
1func (wg *WaitGroup) Wait() {
2 statep, semap := wg.state()
3 for {
4 state := atomic.LoadUint64(statep) // 原子讀取復(fù)合狀態(tài)statep
5 v := int32(state >> 32) // 獲取counter值
6 w := uint32(state) // 獲取waiter值
7 // 如果此時v==0,證明已經(jīng)沒有待執(zhí)行任務(wù)的子goroutine,直接退出即可。
8 if v == 0 {
9 return
10 }
11 // 如果在執(zhí)行CAS原子操作和讀取復(fù)合狀態(tài)之間,沒有其他goroutine更改了復(fù)合狀態(tài)
12 // 那么就將waiter值+1,否則:進入下一輪循環(huán),重新讀取復(fù)合狀態(tài)
13 if atomic.CompareAndSwapUint64(statep, state, state+1) {
14 // 對waiter值累加成功后
15 // 等待Add函數(shù)中調(diào)用 runtime_Semrelease 喚醒自己
16 runtime_Semacquire(semap)
17 // reused 引發(fā)panic
18 // 在當前goroutine被喚醒時,由于喚醒自己的goroutine通過調(diào)用Add方法時
19 // 已經(jīng)通過 *statep = 0 語句做了重置操作
20 // 此時的復(fù)合狀態(tài)位不為0,就是因為還未等Waiter執(zhí)行完Wait,WaitGroup就已經(jīng)發(fā)生了復(fù)用
21 if *statep != 0 {
22 panic("sync: WaitGroup is reused before previous Wait has returned")
23 }
24 return
25 }
26 }
27}
總結(jié)
要看懂WaitGroup的源碼實現(xiàn),我們需要有一些前置知識,例如信號量、內(nèi)存對齊、原子操作、移位運算和指針轉(zhuǎn)換等。
但其實WaitGroup的實現(xiàn)思路還是蠻簡單的,通過結(jié)構(gòu)體字段state1維護了兩個計數(shù)器和一個信號量,計數(shù)器分別是通過Add()添加的子goroutine的計數(shù)值counter,通過Wait()陷入阻塞的waiter數(shù),信號量用于阻塞與喚醒Waiter。當執(zhí)行Add(positive n)時,counter +=n,表明新增n個子goroutine執(zhí)行任務(wù)。每個子goroutine完成任務(wù)之后,需要調(diào)用Done()函數(shù)將counter值減1,當最后一個子goroutine完成時,counter值會是0,此時就需要喚醒阻塞在Wait()調(diào)用中的Waiter。
但是,在使用WaitGroup時,有幾點需要注意
通過
Add()函數(shù)添加的counter數(shù)一定要與后續(xù)通過Done()減去的數(shù)值一致。如果前者大,那么阻塞在Wait()調(diào)用處的goroutine將永遠得不到喚醒;如果后者大,將會引發(fā)panic。Add()的增量函數(shù)應(yīng)該最先得到執(zhí)行。不要對WaitGroup對象進行復(fù)制使用。
如果要復(fù)用WaitGroup,則必須在所有先前的
Wait()調(diào)用返回之后再進行新的Add()調(diào)用。
推薦閱讀
