【小白學(xué)習(xí)PyTorch教程】十七、 PyTorch 中 數(shù)據(jù)集torchvision和torcht...
「@Author:Runsen」
對于PyTorch加載和處理不同類型數(shù)據(jù),官方提供了torchvision和torchtext。
之前使用 torchDataLoader類直接加載圖像并將其轉(zhuǎn)換為張量?,F(xiàn)在結(jié)合torchvision和torchtext介紹torch中的內(nèi)置數(shù)據(jù)集
Torchvision 中的數(shù)據(jù)集
MNIST
MNIST 是一個由標(biāo)準(zhǔn)化和中心裁剪的手寫圖像組成的數(shù)據(jù)集。它有超過 60,000 張訓(xùn)練圖像和 10,000 張測試圖像。這是用于學(xué)習(xí)和實驗?zāi)康淖畛S玫臄?shù)據(jù)集之一。要加載和使用數(shù)據(jù)集,使用以下語法導(dǎo)入:torchvision.datasets.MNIST()。
Fashion MNIST
Fashion MNIST數(shù)據(jù)集類似于MNIST,但該數(shù)據(jù)集包含T恤、褲子、包包等服裝項目,而不是手寫數(shù)字,訓(xùn)練和測試樣本數(shù)分別為60,000和10,000。要加載和使用數(shù)據(jù)集,使用以下語法導(dǎo)入:torchvision.datasets.FashionMNIST()
CIFAR
CIFAR數(shù)據(jù)集有兩個版本,CIFAR10和CIFAR100。CIFAR10 由 10 個不同標(biāo)簽的圖像組成,而 CIFAR100 有 100 個不同的類。這些包括常見的圖像,如卡車、青蛙、船、汽車、鹿等。
torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()
COCO
COCO數(shù)據(jù)集包含超過 100,000 個日常對象,如人、瓶子、文具、書籍等。這個圖像數(shù)據(jù)集廣泛用于對象檢測和圖像字幕應(yīng)用。下面是可以加載 COCO 的位置:torchvision.datasets.CocoCaptions()
EMNIST
EMNIST數(shù)據(jù)集是 MNIST 數(shù)據(jù)集的高級版本。它由包括數(shù)字和字母的圖像組成。如果您正在處理基于從圖像中識別文本的問題,EMNIST是一個不錯的選擇。下面是可以加載 EMNIST的位置::torchvision.datasets.EMNIST()
IMAGE-NET
ImageNet 是用于訓(xùn)練高端神經(jīng)網(wǎng)絡(luò)的旗艦數(shù)據(jù)集之一。它由分布在 10,000 個類別中的超過 120 萬張圖像組成。通常,這個數(shù)據(jù)集加載在高端硬件系統(tǒng)上,因為單獨的 CPU 無法處理這么大的數(shù)據(jù)集。下面是加載 ImageNet 數(shù)據(jù)集的類:torchvision.datasets.ImageNet()
Torchtext 中的數(shù)據(jù)集
IMDB
IMDB是一個用于情感分類的數(shù)據(jù)集,其中包含一組 25,000 條高度極端的電影評論用于訓(xùn)練,另外 25,000 條用于測試。使用以下類加載這些數(shù)據(jù)torchtext:torchtext.datasets.IMDB()
WikiText2
WikiText2語言建模數(shù)據(jù)集是一個超過 1 億個標(biāo)記的集合。它是從維基百科中提取的,并保留了標(biāo)點符號和實際的字母大小寫。它廣泛用于涉及長期依賴的應(yīng)用程序??梢詮膖orchtext以下位置加載此數(shù)據(jù):torchtext.datasets.WikiText2()
除了上述兩個流行的數(shù)據(jù)集,torchtext庫中還有更多可用的數(shù)據(jù)集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。
深入查看 MNIST 數(shù)據(jù)集
MNIST 是最受歡迎的數(shù)據(jù)集之一。現(xiàn)在我們將看到 PyTorch 如何從 pytorch/vision 存儲庫加載 MNIST 數(shù)據(jù)集。讓我們首先下載數(shù)據(jù)集并將其加載到名為 的變量中data_train
from?torchvision.datasets?import?MNIST
#?Download?MNIST?
data_train?=?MNIST('~/mnist_data',?train=True,?download=True)
import?matplotlib.pyplot?as?plt
random_image?=?data_train[0][0]
random_image_label?=?data_train[0][1]
#?Print?the?Image?using?Matplotlib
plt.imshow(random_image)
print("The?label?of?the?image?is:",?random_image_label)
DataLoader加載MNIST
下面我們使用DataLoader該類加載數(shù)據(jù)集,如下所示。
import?torch
from?torchvision?import?transforms
data_train?=?torch.utils.data.DataLoader(
????MNIST(
??????????'~/mnist_data',?train=True,?download=True,?
??????????transform?=?transforms.Compose([
??????????????transforms.ToTensor()
??????????])),
??????????batch_size=64,
??????????shuffle=True
??????????)
for?batch_idx,?samples?in?enumerate(data_train):
??????print(batch_idx,?samples)
CUDA加載
我們可以啟用 GPU 來更快地訓(xùn)練我們的模型?,F(xiàn)在讓我們使用CUDA加載數(shù)據(jù)時可以使用的(GPU 支持 PyTorch)的配置。
device?=?"cuda"?if?torch.cuda.is_available()?else?"cpu"
kwargs?=?{'num_workers':?1,?'pin_memory':?True}?if?device=='cuda'?else?{}
train_loader?=?torch.utils.data.DataLoader(
??torchvision.datasets.MNIST('/files/',?train=True,?download=True),
??batch_size=batch_size_train,?**kwargs)
test_loader?=?torch.utils.data.DataLoader(
??torchvision.datasets.MNIST('files/',?train=False,?download=True),
??batch_size=batch_size,?**kwargs)
ImageFolder
ImageFolder是一個通用數(shù)據(jù)加載器類torchvision,可幫助加載自己的圖像數(shù)據(jù)集。處理一個分類問題并構(gòu)建一個神經(jīng)網(wǎng)絡(luò)來識別給定的圖像是apple還是orange。要在 PyTorch 中執(zhí)行此操作,第一步是在默認(rèn)文件夾結(jié)構(gòu)中排列圖像,如下所示:
root
├──?orange
│???├──?orange_image1.png
│???└──?orange_image1.png
├──?apple
│???└──?apple_image1.png
│???└──?apple_image2.png
│???└──?apple_image3.png
可以使用ImageLoader該類加載所有這些圖像。
torchvision.datasets.ImageFolder(root,?transform)
transforms
PyTorch 轉(zhuǎn)換定義了簡單的圖像轉(zhuǎn)換技術(shù),可將整個數(shù)據(jù)集轉(zhuǎn)換為獨特的格式。
如果是一個包含不同分辨率的不同汽車圖片的數(shù)據(jù)集,在訓(xùn)練時,我們訓(xùn)練數(shù)據(jù)集中的所有圖像都應(yīng)該具有相同的分辨率大小。如果我們手動將所有圖像轉(zhuǎn)換為所需的輸入大小,則很耗時,因此我們可以使用transforms;使用幾行 PyTorch 代碼,我們數(shù)據(jù)集中的所有圖像都可以轉(zhuǎn)換為所需的輸入大小和分辨率。
現(xiàn)在讓我們加載 CIFAR10torchvision.datasets并應(yīng)用以下轉(zhuǎn)換:
- 將所有圖像調(diào)整為 32×32
- 對圖像應(yīng)用中心裁剪變換
- 將裁剪后的圖像轉(zhuǎn)換為張量
- 標(biāo)準(zhǔn)化圖像
import?torch
import?torchvision
import?torchvision.transforms?as?transforms
import?matplotlib.pyplot?as?plt
import?numpy?as?np
transform?=?transforms.Compose([
????#?resize?32×32
????transforms.Resize(32),
????#?center-crop裁剪變換
????transforms.CenterCrop(32),
????#?to-tensor
????transforms.ToTensor(),
????#?normalize?標(biāo)準(zhǔn)化
????transforms.Normalize([0.5,?0.5,?0.5],?[0.5,?0.5,?0.5])
])
trainset?=?torchvision.datasets.CIFAR10(root='./data',?train=True,
????????????????????????????????????????download=True,?transform=transform)
trainloader?=?torch.utils.data.DataLoader(trainset,?batch_size=4,
??????????????????????????????????????????shuffle=False)
在 PyTorch 中創(chuàng)建自定義數(shù)據(jù)集
下面將創(chuàng)建一個由數(shù)字和文本組成的簡單自定義數(shù)據(jù)集。需要封裝Dataset 類中的__getitem__()和__len__()方法。
- _
_getitem__()方法通過索引返回數(shù)據(jù)集中的選定樣本。 __len__()方法返回數(shù)據(jù)集的總大小。
下面是曾經(jīng)封裝FruitImagesDataset數(shù)據(jù)集的代碼,基本是比較好的 PyTorch 中創(chuàng)建自定義數(shù)據(jù)集的模板。
import?os
import?numpy?as?np
import?cv2
import?torch
import?matplotlib.patches?as?patches
import?albumentations?as?A
from?albumentations.pytorch.transforms?import?ToTensorV2
from?matplotlib?import?pyplot?as?plt
from?torch.utils.data?import?Dataset
from?xml.etree?import?ElementTree?as?et
from?torchvision?import?transforms?as?torchtrans
class?FruitImagesDataset(torch.utils.data.Dataset):
????def?__init__(self,?files_dir,?width,?height,?transforms=None):
????????self.transforms?=?transforms
????????self.files_dir?=?files_dir
????????self.height?=?height
????????self.width?=?width
????????self.imgs?=?[image?for?image?in?sorted(os.listdir(files_dir))
?????????????????????if?image[-4:]?==?'.jpg']
????????self.classes?=?['_','apple',?'banana',?'orange']
????def?__getitem__(self,?idx):
????????img_name?=?self.imgs[idx]
????????image_path?=?os.path.join(self.files_dir,?img_name)
????????#?reading?the?images?and?converting?them?to?correct?size?and?color
????????img?=?cv2.imread(image_path)
????????img_rgb?=?cv2.cvtColor(img,?cv2.COLOR_BGR2RGB).astype(np.float32)
????????img_res?=?cv2.resize(img_rgb,?(self.width,?self.height),?cv2.INTER_AREA)
????????#?diving?by?255
????????img_res?/=?255.0
????????#?annotation?file
????????annot_filename?=?img_name[:-4]?+?'.xml'
????????annot_file_path?=?os.path.join(self.files_dir,?annot_filename)
????????boxes?=?[]
????????labels?=?[]
????????tree?=?et.parse(annot_file_path)
????????root?=?tree.getroot()
????????#?cv2?image?gives?size?as?height?x?width
????????wt?=?img.shape[1]
????????ht?=?img.shape[0]
????????#?box?coordinates?for?xml?files?are?extracted?and?corrected?for?image?size?given
????????for?member?in?root.findall('object'):
????????????labels.append(self.classes.index(member.find('name').text))
????????????#?bounding?box
????????????xmin?=?int(member.find('bndbox').find('xmin').text)
????????????xmax?=?int(member.find('bndbox').find('xmax').text)
????????????ymin?=?int(member.find('bndbox').find('ymin').text)
????????????ymax?=?int(member.find('bndbox').find('ymax').text)
????????????xmin_corr?=?(xmin?/?wt)?*?self.width
????????????xmax_corr?=?(xmax?/?wt)?*?self.width
????????????ymin_corr?=?(ymin?/?ht)?*?self.height
????????????ymax_corr?=?(ymax?/?ht)?*?self.height
????????????boxes.append([xmin_corr,?ymin_corr,?xmax_corr,?ymax_corr])
????????#?convert?boxes?into?a?torch.Tensor
????????boxes?=?torch.as_tensor(boxes,?dtype=torch.float32)
????????#?getting?the?areas?of?the?boxes
????????area?=?(boxes[:,?3]?-?boxes[:,?1])?*?(boxes[:,?2]?-?boxes[:,?0])
????????#?suppose?all?instances?are?not?crowd
????????iscrowd?=?torch.zeros((boxes.shape[0],),?dtype=torch.int64)
????????labels?=?torch.as_tensor(labels,?dtype=torch.int64)
????????target?=?{}
????????target["boxes"]?=?boxes
????????target["labels"]?=?labels
????????target["area"]?=?area
????????target["iscrowd"]?=?iscrowd
????????#?image_id
????????image_id?=?torch.tensor([idx])
????????target["image_id"]?=?image_id
????????if?self.transforms:
????????????sample?=?self.transforms(image=img_res,
?????????????????????????????????????bboxes=target['boxes'],
?????????????????????????????????????labels=labels)
????????????img_res?=?sample['image']
????????????target['boxes']?=?torch.Tensor(sample['bboxes'])
????????return?img_res,?target
????def?__len__(self):
????????return?len(self.imgs)
def?get_transform(train):
????if?train:
????????return?A.Compose([
????????????A.HorizontalFlip(0.5),
????????????ToTensorV2(p=1.0)
????????],?bbox_params={'format':?'pascal_voc',?'label_fields':?['labels']})
????else:
????????return?A.Compose([
????????????ToTensorV2(p=1.0)
????????],?bbox_params={'format':?'pascal_voc',?'label_fields':?['labels']})
files_dir?=?'../input/fruit-images-for-object-detection/train_zip/train'
test_dir?=?'../input/fruit-images-for-object-detection/test_zip/test'
dataset?=?FruitImagesDataset(train_dir,?480,?480)
