10條PyTorch避坑指南
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達

本文轉(zhuǎn)載自:機器之心 | 作者:Eugene Khvedchenya
高性能 PyTorch 的訓練管道是什么樣的?是產(chǎn)生最高準確率的模型?是最快的運行速度?是易于理解和擴展?還是容易并行化?答案是,包括以上提到的所有。

建議 0:了解你代碼中的瓶頸在哪里
建議 1:如果可能的話,將數(shù)據(jù)的全部或部分移至 RAM。
class RAMDataset(Dataset):def __init__(image_fnames, targets):self.targets = targetsself.images = []for fname in tqdm(image_fnames, desc="Loading files in RAM"):with open(fname, "rb") as f:self.images.append(f.read())def __len__(self):return len(self.targets)def __getitem__(self, index):target = self.targets[index]image, retval = cv2.imdecode(self.images[index], cv2.IMREAD_COLOR)return image, target
建議 2:解析、度量、比較。每次你在管道中提出任何改變,要深入地評估它全面的影響。
# Profile CPU bottleneckspython -m cProfile training_script.py --profiling# Profile GPU bottlenecksnvprof --print-gpu-trace python train_mnist.py# Profile system calls bottlenecksstrace -fcT python training_script.py -e trace=open,close,readAdvice 3: *Preprocess everything offline*
建議 3:離線預處理所有內(nèi)容
建議 4:調(diào)整 DataLoader 的工作程序
假設我們?yōu)?Cityscapes 訓練圖像分割模型,其批處理大小為 32,RGB 圖像大小是 512x512x3(高、寬、通道)。我們在 CPU 端進行圖像標準化(稍后我將會解釋為什么這一點比較重要)。在這種情況下,我們最終的圖像 tensor 將會是 512 * 512 * 3 * sizeof(float32) = 3,145,728 字節(jié)。與批處理大小相乘,結(jié)果是 100,663,296 字節(jié),大約 100Mb;
除了圖像之外,我們還需要提供 ground-truth 掩膜。它們各自的大小為(默認情況下,掩膜的類型是 long,8 個字節(jié))——512 * 512 * 1 * 8 * 32 = 67,108,864 或者大約 67Mb;
因此一批數(shù)據(jù)所需要的總內(nèi)存是 167Mb。假設有 8 個工作程序,內(nèi)存的總需求量將是 167 Mb * 8 = 1,336 Mb。
將 RGB 圖像保持在每個通道深度 8 位。可以輕松地在 GPU 上將圖像轉(zhuǎn)換為浮點形式或者標準化。
在數(shù)據(jù)集中用 uint8 或 uint16 數(shù)據(jù)類型代替 long。
class MySegmentationDataset(Dataset):...def __getitem__(self, index):image = cv2.imread(self.images[index])target = cv2.imread(self.masks[index])# No data normalization and type casting herereturn torch.from_numpy(image).permute(2,0,1).contiguous(),torch.from_numpy(target).permute(2,0,1).contiguous()class Normalize(nn.Module):# https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/modules/normalize.pydef __init__(self, mean, std):super().__init__()self.register_buffer("mean", torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous())self.register_buffer("std", torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous())def forward(self, input: torch.Tensor) -> torch.Tensor:return (input.to(self.mean.type) - self.mean) * self.stdclass MySegmentationModel(nn.Module):def __init__(self):self.normalize = Normalize([0.221 * 255], [0.242 * 255])self.loss = nn.CrossEntropyLoss()def forward(self, image, target):image = self.normalize(image)output = self.backbone(image)if target is not None:loss = self.loss(output, target.long())return lossreturn output

model = nn.DataParallel(model) # Runs model on all available GPUsGPU 負載不平衡;
在主 GPU 上聚合需要額外的視頻內(nèi)存
在訓練期間繼續(xù)在前向推導內(nèi)使用 nn.DataParallel 計算損耗。在這種情況下。za 不會將密集的預測掩碼返回給主 GPU,而只會返回單個標量損失;
使用分布式訓練,也稱為 nn.DistributedDataParallel。借助分布式訓練的另一個好處是可以看到 GPU 實現(xiàn) 100% 負載。
https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
https://medium.com/@theaccelerators/learn-pytorch-multi-gpu-properly-3eb976c030ee
https://towardsdatascience.com/how-to-scale-training-on-multiple-gpus-dae1041f49d2
建議 5: 如果你擁有兩個及以上的 GPU
def test_loss_profiling():loss = nn.BCEWithLogitsLoss()with torch.autograd.profiler.profile(use_cuda=True) as prof:input = torch.randn((8, 1, 128, 128)).cuda()input.requires_grad = Truetarget = torch.randint(1, (8, 1, 128, 128)).cuda().float()for i in range(10):l = loss(input, target)l.backward()print(prof.key_averages().table(sort_by="self_cpu_time_total"))
建議 9: 如果設計自定義模塊和損失——配置并測試他們
通過硬件升級可以更輕松地解決某些瓶頸。
下載1:OpenCV-Contrib擴展模塊中文版教程
在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。
下載2:Python視覺實戰(zhàn)項目52講 在「小白學視覺」公眾號后臺回復:Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。
下載3:OpenCV實戰(zhàn)項目20講 在「小白學視覺」公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~
