OOM?教你如何在PyTorch更高效地利用顯存

極市導(dǎo)讀
本文介紹了如何在不減少輸入數(shù)據(jù)尺寸以及BatchSize的情況下,進一步榨干GPU的顯存。 >>加入極市CV技術(shù)交流群,走在計算機視覺的最前沿
引言
Out Of Memory, 一個煉丹師們熟悉得不能再熟悉的異常,其解決方法也很簡單,減少輸入圖像的尺寸或者Batch Size就好了。但是,且不說輸入尺寸對模型精度的影響,當(dāng)BatchSize過小的時候網(wǎng)絡(luò)甚至無法收斂的。
下圖來源知乎,深度學(xué)習(xí)中的batch的大小對學(xué)習(xí)效果有何影響?[1]

作者使用LeNet在MNIST數(shù)據(jù)集上進行測試,驗證不同大小的BatchSize對訓(xùn)練結(jié)果的影響。我們可以看到,雖然說BatchSize并不是越大越好,但是過小的BatchSize的結(jié)果往往更差甚至無法收斂。因此本文將會介紹如何在不減少輸入數(shù)據(jù)尺寸以及BatchSize的情況下,進一步榨干GPU的顯存。
什么在占用顯存
顯存主要是被以下三部分內(nèi)容占用:1、網(wǎng)絡(luò)模型,2、模型計算的過程中的中間變量,3、框架自身的顯存開銷。
網(wǎng)絡(luò)模型的占用的顯存主要是來自于所有有參數(shù)的層,包括:卷積、全連接、BN等;而不占用顯存的有:激活函數(shù)、池化層以及Dropout等。 計算過程中產(chǎn)生的顯存主要有:優(yōu)化器、中間過程的特征圖、backward過程產(chǎn)生的參數(shù) 而框架自身的顯存開銷一般不大,并且我們也不好優(yōu)化,所以我們只能考慮從前面兩點對顯存進行優(yōu)化,針對這兩個部分,本文接下來將會介紹常用的顯存占用優(yōu)化策略。
模型顯存優(yōu)化
盡量使用Inplace
在PyTorch中,inplce操作指的是改變一個tensor值的時候,不經(jīng)過復(fù)制操作而是直接在原來的內(nèi)存上修改它的值,也就是原地操作?;旧希刑峁﹊nplace參數(shù)的操作都可以使用inplace,并且官方文檔也說了,如果你使用了inplace operation而沒有報錯的話,那么你可以確定你的梯度計算是正確的。

pytorch中所有inplace操作一般都是以_為后綴,如tensor.add_()、tensor.scatter_()等。除了自帶的一些函數(shù)提供inplace操作外,一些運算法也存在inplace操作。如上圖展示的是兩個向量相加操作使用inplace與否的區(qū)別,但是要注意寫法:
x = x+y屬于case 1x += y屬于case2同理, *=也是inplace操作
盡量少產(chǎn)生中間結(jié)果
下面兩份代碼,效果是一樣的,但是占用顯存卻是不一樣的。
不推薦寫法
def forward(self, x):out_1 = self.conv_1(x)out_2 = self.conv_2(out_1)out_3 = self.conv_3(out_2)return out_3
推薦寫法
def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.conv_3(x)return x
不需要的中間變量盡可能的都是用一個變量來代替,因為這些變量都是會占用顯存的。因此,網(wǎng)絡(luò)中如果存在一些較長的連接(比如第10層的網(wǎng)絡(luò)需要使用來自網(wǎng)絡(luò)第一層的輸出結(jié)果),這部分的特征圖就會一直占用顯存。
不使用過大全連接
相比于卷積的參數(shù),全連接的參數(shù)量可就大多了。因為卷積只是一個局部的連接,而全連接則是一個全局的連接。舉個栗子:卷積的參數(shù)只與輸出的通道數(shù)、卷積核大小相關(guān)。在不考慮偏置的情況下,卷積核大小為3,輸入通道為32,輸出通道數(shù)為64的時候,參數(shù)量大小為
而使用全連接的參數(shù)與輸入的通道數(shù)以及特征圖的尺寸是相關(guān)的。其計算方式如下:因為使用全連接前我們需要將特征圖flatten成一維向量,假如輸入特征圖的大小為512,通道數(shù)為32,在輸出尺寸不變、沒有偏置的前提下,第一層全連接參數(shù)量為:
所以一般來說,特征圖比較大的時候,直接用全連接顯卡會直接冒煙。因此往往只能在深層或者特征進行壓縮之后才能夠使用全連接。比如像SENet中,就是先將特征圖使用GAP(Global Average Pooling)之后,才使用全連接,并且在全連接的中間層還是用了一定的壓縮倍率。亦或者可以像ECA-Net那般,不使用全連接,采用鄰域連接的方式來減少計算量。
計算過程優(yōu)化
使用checkpoint
PyTorch在0.4版本后推出了一個新功能,可以將一個模型的計算過程分為兩半。也就是說,如果一個模型訓(xùn)練需要占用的顯存太大,可以先計算網(wǎng)絡(luò)的一半,保存后半部分所需要的中獎結(jié)果,再計算后半部分。當(dāng)然,這樣的操作顯然是一個犧牲時間換空間的方法,其使用方式如下:
# 常規(guī)寫法def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.conv_3(x)return x# 引入checkpointfrom torch.utils.checkpoint import checkpointdef forward(self, x):x = checkpoint(self.conv_1(x), x)x = checkpoint(self.conv_2(x), x)x = checkpoint(self.conv_3(x),x)return x
梯度累加
大多數(shù)情況下,其實我們降低顯存就是為了獲得更大的Batchsize,因此使用gradient accumulation(梯度累加)也可以達到類似的效果。一般來說,我們使用pytorch寫網(wǎng)絡(luò)的訓(xùn)練過程主要是下面這個流程:
for i in range(epochs):optimizer.zero_grad() # 梯度清零outputs = network(input) # 正向傳播loss = criterion(output, label) # 計算損失loss.backward() # 反向傳播,計算梯度optimizer.step() # 更新參數(shù)
而梯度累加的代碼則只需要多一步:
for i in range(epochs):optimizer.zero_grad() # 梯度清零outputs = network(input) # 正向傳播loss = criterion(output, label)/accumulation_stepsif (i+1) % accumulation_steps == 0:optimizer.step() # 更新參數(shù)optimizer.zero_grad() # 梯度清零
通過這種方法能夠比較簡單的在有限的內(nèi)存下模擬更大batchsize 的效果,并且效果也比較接近。
降低計算精度
PyTorch中,所有Tensor默認(rèn)的精度都是FP32,也就是說每一個浮點型參數(shù)都需要占用32bit的顯存。因此,如果直接把精度降低到FP16,那理論上直接就能減少一半的顯存占用。那么,古爾丹,代價是什么呢?代價就是,在反向傳播的過程中,大多數(shù)更新值都非常小但不為零。反向傳播的舍入誤差可以把這些數(shù)字變成0或者nans,使得梯度更新不準(zhǔn)確,影響網(wǎng)絡(luò)的收斂。ICLR2018論文中Mixed Precision Training[2]發(fā)現(xiàn),使用FP16進行訓(xùn)練的網(wǎng)絡(luò)約有5%的梯度都會被“吞掉”。

在PyTorch1.6之前,降低訓(xùn)練進度普遍使用的都是NVIDIA提供的apex庫。而在1.6版本之后,PyTorch推出了AMP(Automatic mixed precision),自動混合精度訓(xùn)練。這套技術(shù)并不是簡單的將所有的參數(shù)降低精度,而是根據(jù)不同向量的不同操作對于誤差的敏感程度來決定其使用的是FP16還是FP32。其使用起來也十分簡單,下面是一個簡單的例子,代碼參考知乎[3]:
from torch.cuda.amp import autocast,GradScaler# 創(chuàng)建model,默認(rèn)是torch.FloatTensormodel = Net().cuda()optimizer = optim.SGD(model.parameters(), ...)# 在訓(xùn)練最開始之前實例化一個GradScaler對象scaler = GradScaler()for epoch in epochs:for input, target in data:optimizer.zero_grad()# 前向過程(model + loss)開啟 autocastwith autocast():output = model(input)loss = loss_fn(output, target)# Scales loss. 為了梯度放大.scaler.scale(loss).backward()# scaler.step() 首先把梯度的值unscale回來.# 如果梯度的值不是 infs 或者 NaNs, 那么調(diào)用optimizer.step()來更新權(quán)重,# 否則,忽略step調(diào)用,從而保證權(quán)重不更新(不被破壞)scaler.step(optimizer)# 準(zhǔn)備著,看是否要增大scalerscaler.update()
終極解決辦法
加錢
References
[1]https://www.zhihu.com/question/32673260/answer/71137399
[2]https://arxiv.org/abs/1710.03740
[3]https://zhuanlan.zhihu.com/p/165152789
本文亮點總結(jié)
如果覺得有用,就請分享到朋友圈吧!
公眾號后臺回復(fù)“目標(biāo)跟蹤”獲取目標(biāo)跟蹤綜述~

# CV技術(shù)社群邀請函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)
即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實項目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~

