PyTorch深度學(xué)習(xí)訓(xùn)練可視化工具tensorboardX
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時間送達(dá)
之前筆者提到了PyTorch的專屬可視化工具visdom,參看PyTorch深度學(xué)習(xí)訓(xùn)練可視化工具visdom。但在此之前很多TensorFlow用戶更習(xí)慣于使用TensorBoard來進(jìn)行訓(xùn)練的可視化展示。為了能讓PyTorch用戶也能用上TensorBoard,有開發(fā)者提供了PyTorch版本的TensorBoard,也就是tensorboardX。

?
???? 熟悉TensorBoard的用戶可以無縫對接到tensorboardX,安裝方式為:
pip install tensorboardX???? 除了要安裝PyTorch之外,還需要安裝TensorFlow。跟TensorBoard一樣,tensorboardX也支持scalar, image, figure, histogram, audio, text, graph, onnx_graph, embedding, pr_curve,video等不同類型對象的可視化展示方式。
tensorboardX和TensorBoard的啟動方式一樣,直接在終端下運(yùn)行:
tensorboard --logdir runs???? 然后另起一個終端執(zhí)行Python文件即可:
python demo.py???? 打開localhost:6006即可看到tensorboardX可視化界面。
???? tensorboardX本地啟動非常容易,但一般情況下我們訓(xùn)練都是在服務(wù)器上完成的, 所以要在遠(yuǎn)程啟動tensorboardX需要進(jìn)行一些簡單的設(shè)置。以虛擬機(jī)工具xshell為例:依此設(shè)置文件->屬性->ssh->隧道->添加,類型local,源主機(jī)填127.0.0.1(本機(jī)),端口設(shè)置一個,比如12345,目標(biāo)主機(jī)為服務(wù)器地址,目標(biāo)端口一般是6006,如果6006被占了可以改為其他端口。

???? 分別執(zhí)行tensorboard和python腳本后,本地打開127.0.0.1:12345即可進(jìn)入遠(yuǎn)程TensorBoard界面。
?
???? 以scalar為例來看一下tensorboardX的使用方式:
import numpy as npfrom tensorboardX import SummaryWriterwriter = SummaryWriter()for i in range(100):writer.add_scalar('data/scalar1', np.random.rand(), i)writer.add_scalar('data/scalar2', {'xsinx': i*np.sin(i), 'xcosx': i*np.cos(i)}, i)writer.close()
???? scalar可視化如下圖所示。

???? 一個完整tensorboardX 使用demo如下:
import torchimport torchvision.utils as vutilsimport numpy as npimport torchvision.models as modelsfrom torchvision import datasetsfrom tensorboardX import SummaryWriterresnet18 = models.resnet18(False)writer = SummaryWriter()sample_rate = 44100freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]for n_iter in range(100):dummy_s1 = torch.rand(1)dummy_s2 = torch.rand(1)# data grouping by `slash`writer.add_scalar('data/scalar1', dummy_s1[0], n_iter)writer.add_scalar('data/scalar2', dummy_s2[0], n_iter)writer.add_scalars('data/scalar_group', {'xsinx': n_iter * np.sin(n_iter),'xcosx': n_iter * np.cos(n_iter),'arctanx': np.arctan(n_iter)}, n_iter)dummy_img = torch.rand(32, 3, 64, 64) # output from networkif n_iter % 10 == 0:x = vutils.make_grid(dummy_img, normalize=True, scale_each=True)writer.add_image('Image', x, n_iter)dummy_audio = torch.zeros(sample_rate * 2)for i in range(x.size(0)):# amplitude of sound should in [-1, 1]dummy_audio[i] = np.cos(freqs[n_iter // 10] * np.pi * float(i) / float(sample_rate))writer.add_audio('myAudio', dummy_audio, n_iter, sample_rate=sample_rate)writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter)for name, param in resnet18.named_parameters():writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)# needs tensorboard 0.4RC or laterwriter.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand(100), n_iter)dataset = datasets.MNIST('mnist', train=False, download=True)images = dataset.test_data[:100].float()label = dataset.test_labels[:100]features = images.view(100, 784)writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))# export scalar data to JSON for external processingwriter.export_scalars_to_json("./all_scalars.json")writer.close()
可視化效果如下所示:

? 參考資料:
https://github.com/lanpa/tensorboardX
https://www.tensorflow.org/tensorboard
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~
