【深度學(xué)習(xí)】在PyTorch中使用Datasets和DataLoader來(lái)定制文本數(shù)據(jù)
作者 | Jake Wherlock
作者 | Jake Wherlock
編譯 | VK
來(lái)源 | Towards Data Science

創(chuàng)建一個(gè)PyTorch數(shù)據(jù)集并使用Dataloader對(duì)其進(jìn)行管理,并有助于簡(jiǎn)化機(jī)器學(xué)習(xí)流程。Dataset存儲(chǔ)所有數(shù)據(jù),而Dataloader可用于迭代數(shù)據(jù)、管理批處理、轉(zhuǎn)換數(shù)據(jù)等等。
導(dǎo)入庫(kù)
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
Pandas對(duì)于創(chuàng)建數(shù)據(jù)集對(duì)象不是必需的。不過(guò),它是管理數(shù)據(jù)的強(qiáng)大工具,所以我將使用它。
torch.utils.data導(dǎo)入創(chuàng)建和使用Dataset和DataLoader所需的函數(shù)。
創(chuàng)建自定義數(shù)據(jù)集類
class CustomTextDataset(Dataset):
def __init__(self, txt, labels):
self.labels = labels
self.text = text
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
text = self.text[idx]
sample = {"Text": text, "Class": label}
return sample
class CustomTextDataset(Dataset):創(chuàng)建一個(gè)名為“CustomTextDataset”的類,可以任意調(diào)用。傳入類的是我們前面導(dǎo)入的數(shù)據(jù)集模塊。
def init(self, text, labels):初始化類時(shí)需要導(dǎo)入兩個(gè)變量。在這種情況下,變量被稱為“Text”和“Class”,以匹配將要添加的數(shù)據(jù)。
self.labels = labels & self.text = text:導(dǎo)入的變量現(xiàn)在可以使用self.text或self.labels在類內(nèi)的函數(shù)中使用。
def len(self):這個(gè)函數(shù)在調(diào)用時(shí)只返回標(biāo)簽的長(zhǎng)度。例如,如果你有一個(gè)帶有5個(gè)標(biāo)簽的數(shù)據(jù)集,那么將返回整數(shù)5。
def getitem(self, idx):這個(gè)函數(shù)被Pytorch的Dataset模塊用來(lái)獲取樣本并構(gòu)建數(shù)據(jù)集。初始化時(shí),它將通過(guò)此函數(shù)循環(huán),從數(shù)據(jù)集中的每個(gè)實(shí)例創(chuàng)建一個(gè)樣本。
傳遞給函數(shù)的“idx”是一個(gè)數(shù)字,這個(gè)數(shù)字是數(shù)據(jù)集將遍歷的數(shù)據(jù)實(shí)例。我們使用self.labels和self.text提到的文本變量與“idx”變量一起傳入,以獲得當(dāng)前的數(shù)據(jù)實(shí)例。這些當(dāng)前實(shí)例被保存在變量' label '和' data '中。
接下來(lái),聲明一個(gè)名為‘sample’的變量,其中包含一個(gè)存儲(chǔ)數(shù)據(jù)的字典。在用數(shù)據(jù)初始化這個(gè)類之后,它將包含許多標(biāo)記為“Text”和“Class”的數(shù)據(jù)實(shí)例。你可以命名“Text”和“Class”任何東西。
初始化CustomTextDataset類
# 定義數(shù)據(jù)和類標(biāo)簽
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
# 創(chuàng)建數(shù)據(jù)幀
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})
# 定義數(shù)據(jù)集對(duì)象
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])
首先,我們創(chuàng)建兩個(gè)名為“text”和“l(fā)abels”的列作為示例。
text_labels_df = pd.DataFrame({‘Text’: text, ‘Labels’: labels}):不是必需的,但是Pandas是數(shù)據(jù)管理和預(yù)處理的有用工具,可能會(huì)在PyTorch管道中使用。在本節(jié)中,包含數(shù)據(jù)的列表“Text”和“Labels”保存在數(shù)據(jù)框中。
TD = CustomTextDataset(text_labels_df[‘Text’], text_labels_df[‘Labels’]):這將初始化我們前面創(chuàng)建的類,并傳入'text'和'labels'數(shù)據(jù)。此數(shù)據(jù)將在類中變?yōu)椤皊elf.text”和“self.labels”。數(shù)據(jù)集保存在名為TD的變量下。
數(shù)據(jù)集現(xiàn)在已經(jīng)初始化,可以使用了!
一些代碼顯示數(shù)據(jù)集中發(fā)生了什么
這將向你展示數(shù)據(jù)是如何存儲(chǔ)在數(shù)據(jù)集中的。
# 顯示文本和標(biāo)簽。
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')
# 打印數(shù)據(jù)集中的項(xiàng)目數(shù)
print('Length of data set: ', len(TD), '\n')
# 打印整個(gè)數(shù)據(jù)集
print('Entire data set: ', list(DataLoader(TD)), '\n')
輸出:
數(shù)據(jù)集的第一次迭代:{'Text':'Happy','Class':'Positive'}
數(shù)據(jù)集長(zhǎng)度:5
整個(gè)數(shù)據(jù)集:[{‘Text’: [‘Happy’], ‘Class’: [‘Positive’]}, {‘Text’: [‘Amazing’], ‘Class’: [‘Positive’]}, {‘Text’: [‘Sad’], ‘Class’: [‘Negative’]}, {‘Text’: [‘Unhapy’], ‘Class’: [‘Negative’]}, {‘Text’: [‘Glum’], ‘Class’: [‘Negative’]}]
使用“collate_fn”預(yù)處理數(shù)據(jù)
在機(jī)器學(xué)習(xí)或深度學(xué)習(xí)中,在訓(xùn)練之前需要對(duì)文本進(jìn)行清理并將其轉(zhuǎn)化為向量。DataLoader有一個(gè)方便的參數(shù)collate_fn。此參數(shù)允許你創(chuàng)建單獨(dú)的數(shù)據(jù)處理函數(shù),并在輸出數(shù)據(jù)之前將該函數(shù)中的處理應(yīng)用于數(shù)據(jù)。
def collate_batch(batch):
word_tensor = torch.tensor([[1.], [0.], [45.]])
label_tensor = torch.tensor([[1.]])
text_list, classes = [], []
for (_text, _class) in batch:
text_list.append(word_tensor)
classes.append(label_tensor)
text = torch.cat(text_list)
classes = torch.tensor(classes)
return text, classes
DL_DS = DataLoader(TD, batch_size=2, collate_fn=collate_batch)
例如,創(chuàng)建了兩個(gè)表示單詞和類的張量。實(shí)際上,這些可以是通過(guò)另一個(gè)函數(shù)傳入的單詞向量。然后將批處理解包,然后將單詞和標(biāo)簽張量添加到列表中。
然后將單詞張量串聯(lián)起來(lái),并將類張量列表(在本例中為1)組合成單個(gè)張量。該函數(shù)現(xiàn)在將返回已處理的文本數(shù)據(jù),以便進(jìn)行訓(xùn)練。
要激活此函數(shù),只需在初始化DataLoader對(duì)象時(shí)添加參數(shù)collate_fn=Your_Function_name。
訓(xùn)練模型時(shí)如何遍歷數(shù)據(jù)集
我們將在不使用collate_fn的情況下遍歷數(shù)據(jù)集,因?yàn)樗菀卓吹紻ataLoader如何輸出單詞和類。如果上述函數(shù)與collate_fn一起使用,則輸出將是張量。
DL_DS = DataLoader(TD, batch_size=2, shuffle=True)
for (idx, batch) in enumerate(DL_DS):
# 打印batch中的“text”數(shù)據(jù)
print(idx, 'Text data: ', batch['Text'])
# 打印batch中的"Class”數(shù)據(jù)
print(idx, 'Class data: ', batch['Class'], '\n')
DL_DS = DataLoader(TD, batch_size=2, shuffle=True) :這用我們剛剛創(chuàng)建的Dataset對(duì)象“TD”初始化DataLoader。
在本例中,批大小設(shè)置為2。這意味著當(dāng)你遍歷數(shù)據(jù)集時(shí),DataLoader將輸出2個(gè)數(shù)據(jù)實(shí)例,而不是一個(gè)。有關(guān)批處理的更多信息,請(qǐng)參閱本文:https://machinelearningmastery.com/difference-between-a-batch-and-an-epoch/。Shuffle將在每個(gè)epoch對(duì)數(shù)據(jù)進(jìn)行隨機(jī)化,這將阻止模型學(xué)習(xí)訓(xùn)練數(shù)據(jù)的順序。
for (idx, batch) in enumerate(DL_DS): 遍歷我們剛剛創(chuàng)建的DataLoader對(duì)象中的數(shù)據(jù)。enumerate(DL_DS)返回批的索引號(hào)和由兩個(gè)數(shù)據(jù)實(shí)例。
輸出:

如你所見,我們創(chuàng)建的5個(gè)數(shù)據(jù)實(shí)例是以2個(gè)為一個(gè)batch的方式輸出的。由于我們有奇數(shù)個(gè)訓(xùn)練示例,最后一個(gè)batch大小是1。
完整代碼
# 導(dǎo)入庫(kù)
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
# 創(chuàng)建自定義數(shù)據(jù)集類
class CustomTextDataset(Dataset):
def __init__(self, text, labels):
self.labels = labels
self.text = text
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
data = self.text[idx]
sample = {"Text": data, "Class": label}
return sample
# 定義數(shù)據(jù)和類標(biāo)簽
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
# 創(chuàng)建Pandas DataFrame
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})
# 定義數(shù)據(jù)集對(duì)象
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])
# 顯示圖像和標(biāo)簽
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')
# 打印數(shù)據(jù)集中有多少項(xiàng)
print('Length of data set: ', len(TD), '\n')
# 打印整個(gè)數(shù)據(jù)集
print('Entire data set: ', list(DataLoader(TD)), '\n')
# collate_fn
def collate_batch(batch):
word_tensor = torch.tensor([[1.], [0.], [45.]])
label_tensor = torch.tensor([[1.]])
text_list, classes = [], []
for (_text, _class) in batch:
text_list.append(word_tensor)
classes.append(label_tensor)
text = torch.cat(text_list)
classes = torch.tensor(classes)
return text, classes
# 創(chuàng)建數(shù)據(jù)集對(duì)象的DataLoader對(duì)象
bat_size = 2
DL_DS = DataLoader(TD, batch_size=bat_size, shuffle=True)
# 循環(huán)遍歷DataLoader對(duì)象中的每個(gè)batch
for (idx, batch) in enumerate(DL_DS):
# 打印“text”數(shù)據(jù)
print(idx, 'Text data: ', batch, '\n')
# 打印“Class”數(shù)據(jù)
print(idx, 'Class data: ', batch, '\n')
往期精彩回顧 本站qq群851320808,加入微信群請(qǐng)掃碼:
