PyTorch mixed precision overview (mixed precision training)

01
import torchvision
import torch
import torch.cuda.amp
import gc
import time

# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()  # only time measured after synchronization reflects the actual run time
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

num_batches = 50
batch_size = 70
epochs = 3

# Create random training data
data = [torch.randn(batch_size, 3, 224, 224, device="cuda") for _ in range(num_batches)]
targets = [torch.randint(0, 1000, size=(batch_size,), device="cuda") for _ in range(num_batches)]

# Create a model
net = torchvision.models.resnext50_32x4d().cuda()

# Define the loss function
loss_fn = torch.nn.CrossEntropyLoss().cuda()

# Define the optimizer
opt = torch.optim.SGD(net.parameters(), lr=0.001)

# Whether to use mixed precision training
use_amp = True

# Constructs scaler once, at the beginning of the convergence run, using default args.
# If your network fails to converge with default GradScaler args, please file an issue.
# The same GradScaler instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        # Scales the loss. Calls backward() on the scaled loss to create scaled gradients.
        scaler.scale(loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)
        # Updates the scale for next iteration.
        scaler.update()
        opt.zero_grad(set_to_none=True)  # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")
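If gradient clipping is also needed, the gradients must be unscaled before clipping, otherwise the clipping threshold would be compared against scaled values. Below is a minimal sketch of how the inner loop body above changes, assuming the same net, opt, scaler, loss_fn, use_amp, input and target as in the script above; the max_norm value of 1.0 is only an illustrative choice.

# Sketch only; assumes net, opt, scaler, loss_fn, use_amp, input, target from the script above.
with torch.cuda.amp.autocast(enabled=use_amp):
    output = net(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()
# Unscale the gradients in place so clipping operates on their true values.
scaler.unscale_(opt)
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)  # example threshold
# step() notices that unscale_ was already called and does not unscale again;
# it still skips opt.step() if the gradients contain infs or NaNs.
scaler.step(opt)
scaler.update()
opt.zero_grad(set_to_none=True)

Note that unscale_ should be called at most once per optimizer step, between backward() and scaler.step().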
02
Mixed precision test
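To measure what mixed precision actually saves, the same loop can be timed in default (FP32) precision. Below is a minimal sketch, assuming the data, targets, loss_fn, epochs and timing helpers defined in section 01; a fresh model and optimizer are created so both runs start from the same state.

# Sketch only; reuses data, targets, loss_fn, epochs, start_timer, end_timer_and_print from above.
net = torchvision.models.resnext50_32x4d().cuda()   # fresh model for a fair comparison
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)              # runs in the default torch.float32
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
end_timer_and_print("Default precision:")

Alternatively, rerunning the mixed-precision script with use_amp = False gives the same baseline, since autocast and GradScaler become no-ops when enabled=False. Comparing the two printed totals shows the run time and peak tensor memory with and without mixed precision.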







