【深度学习】基于PyTorch深度学习框架的序列图像数据装载器
作者 | Harsh Maheshwari
編譯 | VK
來(lái)源 | Towards Data Science
如今,深度學(xué)習(xí)和機(jī)器學(xué)習(xí)算法正在統(tǒng)治世界。PyTorch是最常用的深度學(xué)習(xí)框架之一,用于實(shí)現(xiàn)各種深度學(xué)習(xí)算法。另一方面,基于學(xué)習(xí)的方法本質(zhì)上需要一些帶注釋的訓(xùn)練數(shù)據(jù)集,這些數(shù)據(jù)集可以被模型用來(lái)提取輸入數(shù)據(jù)和標(biāo)簽之間的關(guān)系。為了給神經(jīng)網(wǎng)絡(luò)提供數(shù)據(jù),我們定義了一個(gè)數(shù)據(jù)加載器。
在這個(gè)博客中,我們將看到如何在PyTorch框架中為不同的數(shù)據(jù)集編寫(xiě)一個(gè)數(shù)據(jù)加載器。
圖像數(shù)據(jù)集的數(shù)據(jù)加載器
我們將致力于狗與貓的圖像分類(lèi)問(wèn)題。我們需要對(duì)給定的圖像進(jìn)行分類(lèi),數(shù)據(jù)集可以從這里下載:https://www.kaggle.com/c/dogs-vs-cats。訓(xùn)練數(shù)據(jù)集總共包含25000個(gè)圖像。因?yàn)檫@是一個(gè)分類(lèi)問(wèn)題,所以dog的標(biāo)簽是“0”,cat的標(biāo)簽是“1”。
讓我們從導(dǎo)入所有必需的庫(kù)開(kāi)始。
import?os from?PIL?import?Image import?torch from?torch.utils.data?import?DataLoader,?Dataset import?torchvision.transforms?as?transforms import?torch.nn?as?nnPyTorch框架的dataset類(lèi)被定義為一個(gè)類(lèi),其基本結(jié)構(gòu)如下
class?data(Dataset):def?__init__(self,?param1,?param2):#?函數(shù)在此處初始化def?__len__(self):#?函數(shù)返回?cái)?shù)據(jù)的長(zhǎng)度def?__getitem__(self,?index):#?一次提供一個(gè)項(xiàng)目這個(gè)類(lèi)的最終目的是使用函數(shù) __getitem__每次提供一個(gè)數(shù)據(jù)點(diǎn)。這是通過(guò)使用內(nèi)部傳遞給函數(shù)的索引完成的,使用Dataloader中定義的sampler函數(shù)(將在接下來(lái)的博客中討論)。
初始化數(shù)據(jù)集的對(duì)象時(shí),會(huì)調(diào)用函數(shù) __init__。在這里,你可以傳遞多個(gè)參數(shù),這些參數(shù)對(duì)于編寫(xiě) __getitem__非常有用。
函數(shù)用于返回?cái)?shù)據(jù)集的總長(zhǎng)度。在此基礎(chǔ)上,將生成索引,然后將其提供給getitem。
dog vs cat數(shù)據(jù)集的格式如下-:
data/-?dog_1.jpg-?dog_2.jpg.........-?cat_1.jpg-?cat_2.jpg.........現(xiàn)在我們已經(jīng)了解了編寫(xiě)數(shù)據(jù)加載器所需的組件,讓我們深入研究一下我們的用例。
class?data(Dataset):???def?__init__(self,?path,?transform):self.files?=?os.listdir(path)self.transform?=?transformself.path?=?path???def?__len__(self):return?len(self.files)???def?__getitem__(self,?index):filename?=?self.files[index]input?=?Image.open(os.path.join(self.path,?filename))label?=?0?if?filename.find("dog")>=0?else?1img_as_tensor?=?self.transform(input)return?img_as_tensor,?labeltransformations?=?transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])path?=?"./data" train_dataset?=?data(path,?transformations) dataloader?=?DataLoader(train_dataset,?batch_size=Train_Batch_Size,?shuffle=True)首先讓我們了解函數(shù)__init__。類(lèi)數(shù)據(jù)用兩個(gè)參數(shù)path和transform初始化,這兩個(gè)參數(shù)作為參數(shù)傳遞給__init__。當(dāng)我們聲明這個(gè)類(lèi)的一個(gè)對(duì)象時(shí),它會(huì)在內(nèi)部調(diào)用__init__。
由于使用了len來(lái)返回整個(gè)數(shù)據(jù)集的長(zhǎng)度,所以我使用len(self.files)來(lái)返回相同的長(zhǎng)度。
函數(shù)getitem是最關(guān)鍵的,它加載圖像,然后調(diào)整其大小,然后將其轉(zhuǎn)換為張量。這里需要注意的一點(diǎn)是,提供給神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)應(yīng)該總是標(biāo)準(zhǔn)化的。我們使用transforms.ToTensor處理規(guī)范化。最后,getitem返回兩個(gè)結(jié)果,image作為張量,label作為對(duì)應(yīng)的數(shù)據(jù)點(diǎn)。
在初始化類(lèi)數(shù)據(jù)之后,我們使用DataLoader函數(shù)自動(dòng)將整個(gè)數(shù)據(jù)批處理成一個(gè)定義的批大小。因此,如果你的原始數(shù)據(jù)點(diǎn)大小是(3,224,224)(你從__getitem__獲得),那么dataloader的每個(gè)項(xiàng)都將具有大小(batch_size,3,224,224),即它會(huì)自動(dòng)對(duì)數(shù)據(jù)點(diǎn)的batch_size數(shù)進(jìn)行采樣。
這在我們的例子中是可能的,因?yàn)閳D像的大小是恒定的,所以DataLoader函數(shù)能夠自動(dòng)創(chuàng)建批處理。然而,在自然語(yǔ)言處理這樣的情況下,當(dāng)大小不是常數(shù)時(shí),我們需要編寫(xiě)自己的批處理函數(shù)。
序列數(shù)據(jù)集的數(shù)據(jù)加載器
現(xiàn)在讓我們來(lái)處理序列數(shù)據(jù)集,即句子、時(shí)間序列、音頻等。這里的__getitem__將不再提供相同大小的數(shù)據(jù)點(diǎn)。例如,考慮情緒分類(lèi)的任務(wù)(在這里解釋),那么一句話可以是“The flight service was very good”,另一句話可以是“I did not get my baggage on the belt, pathetic service.”在這里,兩句話的長(zhǎng)度是不同的。
為了解決這個(gè)問(wèn)題,讓我們先回答三個(gè)問(wèn)題。
什么是batch?-批處理是指將多個(gè)數(shù)據(jù)點(diǎn)的張量合并成一個(gè)張量
為什么我們需要分批處理?批處理可以用于加快計(jì)算速度,因?yàn)榕幚砜梢酝瑫r(shí)處理多個(gè)數(shù)據(jù)點(diǎn),而不是一次只處理一個(gè)數(shù)據(jù)點(diǎn)。
如何進(jìn)行batch化?因?yàn)槲覀冊(cè)谶@里合并多個(gè)張量,所以張量的每個(gè)維度的大小都需要相同。由于輸出的數(shù)據(jù)點(diǎn)大小不一,我們手中就有一個(gè)問(wèn)題。
我們現(xiàn)在主要要解決batch化問(wèn)題。
為了便于我們?cè)谶@里討論,我們將使用IMDB數(shù)據(jù)集,它是一個(gè)評(píng)論數(shù)據(jù)集。因?yàn)槲覀冊(cè)谶@里處理的是句子,所以處理數(shù)據(jù)集的方法會(huì)有所不同。
因?yàn)樯窠?jīng)網(wǎng)絡(luò)只懂?dāng)?shù)字,不懂單詞,所以我們必須把每個(gè)單詞轉(zhuǎn)換成一個(gè)數(shù)字。為了做到這一點(diǎn),我們必須構(gòu)建一個(gè)詞匯表,如下代碼所述。
import?os import?gensim from?collections?import?Counter import?jsontrain_path?=?"./aclImdb/train" test_path?=?"./aclImdb/test"#?simple函數(shù)從目錄讀取數(shù)據(jù)并返回?cái)?shù)據(jù)和標(biāo)簽 #?你可以為其他數(shù)據(jù)集制作自己的讀取器。 def?reader(path):pos_path?=?os.path.join(path,?"pos")neg_path?=?os.path.join(path,?"neg")data?=?[]label?=?[]for?file?in?os.listdir(pos_path):f?=?open(os.path.join(pos_path,?file))data.append(f.read())label.append(1)for?file?in?os.listdir(neg_path):f?=?open(os.path.join(neg_path,?file))data.append(f.read())label.append(0)#?print(data[:1])return?data,?labeldef?build_vocab(data,?min_word_count?=?5):counter?=?Counter()for?line?in?data:l?=?gensim.utils.simple_preprocess(line)counter.update(l)#?初始化一個(gè)字典或查找表word2id?=?{}word2id['<pad>']?=?0word2id['<unk>']?=?1#?只包括那些在字典中出現(xiàn)超過(guò)min次的單詞。words?=?[word?for?word,?count?in?counter.items()?if?count>min_word_count]for?i,?word?in?enumerate(words):word2id[word]?=?i+2with?open("word2id.json",?'w')?as?f:json.dump(word2id,?f)return?word2iddata,?label?=?reader(train_path) word2id?=?build_vocab(data) print("Dictionary?Formed?and?saved.?The?length?of?dictionary?is-:?",?len(word2id))函數(shù)讀取器用于讀取整個(gè)數(shù)據(jù),它返回所有句子的列表,標(biāo)簽“0”表示消極評(píng)論,“1”表示積極評(píng)論。
函數(shù)build_vocab將數(shù)據(jù)和最小字?jǐn)?shù)作為輸入,并將每個(gè)字的映射(稱(chēng)為“word2id”)作為輸出,映射到一個(gè)唯一的數(shù)字。對(duì)于每個(gè)向前的未知單詞,對(duì)應(yīng)的數(shù)字將是1。
繼續(xù)為序列數(shù)據(jù)集編寫(xiě)數(shù)據(jù)集類(lèi)。我們的目標(biāo)是在給定索引的情況下,一次輸出一個(gè)item。
import?torch from?torch.utils.data?import?Dataset,?DataLoader import?numpy?as?np import?os import?gensimclass?Dataset_seq(Dataset):def?__init__(self,?word2id,?train_path):self.word2id?=?word2idself.train_path?=?train_path#?讀取數(shù)據(jù)和標(biāo)簽self.data,?self.label?=?reader(train_path)def?__getitem__(self,?index):#?返回seq和標(biāo)簽seq?=?self.preprocess(self.data[index])label?=?self.label[index]return?seq,?labeldef?__len__(self):return(len(self.data))def?preprocess(self,?text):#?用于將line轉(zhuǎn)換為token,然后使用word2id將其轉(zhuǎn)換為相應(yīng)的數(shù)字值line?=?gensim.utils.simple_preprocess(text)seq?=?[]for?word?in?line:if?word?in?self.word2id:seq.append(self.word2id[word])else:seq.append(self.word2id['<unk>'])#?將list轉(zhuǎn)換成張量seq?=?torch.from_numpy(np.array(seq))return?seq由于上面已經(jīng)討論了不同函數(shù)的功能,我將簡(jiǎn)要地回顧一下。
函數(shù)__init__采用word2id映射和train路徑。然后,init調(diào)用reader獲取與句子對(duì)應(yīng)的數(shù)據(jù)和標(biāo)簽。
函數(shù)__len__ 返回整個(gè)數(shù)據(jù)集的長(zhǎng)度,即self.data。
函數(shù)preprocess將輸入句子轉(zhuǎn)換成數(shù)字張量,其中每個(gè)數(shù)字對(duì)應(yīng)于句子中的單詞。
函數(shù)getitem用于在索引的幫助下輸出一個(gè)經(jīng)過(guò)處理的數(shù)據(jù)點(diǎn)。
下面的代碼定義了collate_fn。
train_dataset?=?Dataset_seq(word2id,?train_path) train_dataloader?=?DataLoader(dataset=train_dataset,?batch_size=batch_size,?shuffle=True,collate_fn=collate_fn)def?collate_fn(data):'''??我們應(yīng)該構(gòu)建一個(gè)自定義的collate_fn,而不是使用默認(rèn)的collate_fn,因?yàn)槊總€(gè)句子的大小不同,并且默認(rèn)不支持合并序列。Args:data:?元組列表?(training?sequence,?label)Return:padded_seq?-?填充序列,形狀?(batch_size,?padded_length)length?-?每個(gè)序列的原始長(zhǎng)度(沒(méi)有填充),?形狀(batch_size)label?-?張量形狀?(batch_size)'''data.sort(key=lambda?x:?len(x[0]),?reverse=True)sequences,?label?=?zip(*data)length?=?[len(seq)?for?seq?in?sequences]padded_seq?=?torch.zeros(len(sequences),?max(length)).long()for?i,?seq?in?enumerate(sequences):end?=?length[i]padded_seq[i,:end]?=?seqreturn?padded_seq,?torch.from_numpy(np.array(length)),?torch.from_numpy(np.array(label))這里需要注意的一點(diǎn)是,在一個(gè)元組列表中,每個(gè)元組可以有不同的大小,但在張量中,所有維度的大小都必須相同才能合并它們。
collate_fn自動(dòng)獲得一個(gè)名為data的輸入,這是一個(gè)長(zhǎng)度等于batch size的元組列表。每個(gè)元組包含數(shù)字張量及其相應(yīng)的標(biāo)簽。
為了簡(jiǎn)單起見(jiàn),我們將它們分別稱(chēng)為sequence和label。所以最終我們必須以這樣一種方式轉(zhuǎn)換每個(gè)序列,使它們的大小保持不變。
為了實(shí)現(xiàn)這一點(diǎn),我們執(zhí)行零填充,如上面的代碼所示。由于對(duì)整個(gè)數(shù)據(jù)集統(tǒng)一使用零填充,因此模型了解到它沒(méi)有多大用處,它只是表示浪費(fèi)值。
我們肯定已經(jīng)找到了解決辦法,但問(wèn)題是,這是一個(gè)最佳的解決辦法嗎?如果所有序列的原始大小都有很大的差異,或者換言之有很大的差異,那么我們最終會(huì)浪費(fèi)大量的GPU內(nèi)存,而這些內(nèi)存是零填充的,這最終是沒(méi)有用的。必須有一個(gè)更好的方法來(lái)最小化零填充的要求!
這個(gè)問(wèn)題的解決請(qǐng)關(guān)注后續(xù)文章!
往期精彩回顧適合初學(xué)者入門(mén)人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)筆記專(zhuān)輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專(zhuān)輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專(zhuān)輯黃海廣老師《機(jī)器學(xué)習(xí)課程》課件合集 本站qq群851320808,加入微信群請(qǐng)掃碼:總結(jié)
以上是生活随笔為你收集整理的【深度学习】基于PyTorch深度学习框架的序列图像数据装载器的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 码云提交时报错git 报错 fatal:
- 下一篇: 【深度学习】深度学习手写代码汇总(建议收