【深度学习】基于PyTorch深度学习框架的序列图像数据装载器
作者 | Harsh Maheshwari
編譯 | VK
來源 | Towards Data Science
如今,深度學習和機器學習算法正在統(tǒng)治世界。PyTorch是最常用的深度學習框架之一,用于實現(xiàn)各種深度學習算法。另一方面,基于學習的方法本質上需要一些帶注釋的訓練數據集,這些數據集可以被模型用來提取輸入數據和標簽之間的關系。為了給神經網絡提供數據,我們定義了一個數據加載器。
在這個博客中,我們將看到如何在PyTorch框架中為不同的數據集編寫一個數據加載器。
圖像數據集的數據加載器
我們將致力于狗與貓的圖像分類問題。我們需要對給定的圖像進行分類,數據集可以從這里下載:https://www.kaggle.com/c/dogs-vs-cats。訓練數據集總共包含25000個圖像。因為這是一個分類問題,所以dog的標簽是“0”,cat的標簽是“1”。
讓我們從導入所有必需的庫開始。
import?os from?PIL?import?Image import?torch from?torch.utils.data?import?DataLoader,?Dataset import?torchvision.transforms?as?transforms import?torch.nn?as?nnPyTorch框架的dataset類被定義為一個類,其基本結構如下
class?data(Dataset):def?__init__(self,?param1,?param2):#?函數在此處初始化def?__len__(self):#?函數返回數據的長度def?__getitem__(self,?index):#?一次提供一個項目這個類的最終目的是使用函數 __getitem__每次提供一個數據點。這是通過使用內部傳遞給函數的索引完成的,使用Dataloader中定義的sampler函數(將在接下來的博客中討論)。
初始化數據集的對象時,會調用函數 __init__。在這里,你可以傳遞多個參數,這些參數對于編寫 __getitem__非常有用。
函數用于返回數據集的總長度。在此基礎上,將生成索引,然后將其提供給getitem。
dog vs cat數據集的格式如下-:
data/-?dog_1.jpg-?dog_2.jpg.........-?cat_1.jpg-?cat_2.jpg.........現(xiàn)在我們已經了解了編寫數據加載器所需的組件,讓我們深入研究一下我們的用例。
class?data(Dataset):???def?__init__(self,?path,?transform):self.files?=?os.listdir(path)self.transform?=?transformself.path?=?path???def?__len__(self):return?len(self.files)???def?__getitem__(self,?index):filename?=?self.files[index]input?=?Image.open(os.path.join(self.path,?filename))label?=?0?if?filename.find("dog")>=0?else?1img_as_tensor?=?self.transform(input)return?img_as_tensor,?labeltransformations?=?transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])path?=?"./data" train_dataset?=?data(path,?transformations) dataloader?=?DataLoader(train_dataset,?batch_size=Train_Batch_Size,?shuffle=True)首先讓我們了解函數__init__。類數據用兩個參數path和transform初始化,這兩個參數作為參數傳遞給__init__。當我們聲明這個類的一個對象時,它會在內部調用__init__。
由于使用了len來返回整個數據集的長度,所以我使用len(self.files)來返回相同的長度。
函數getitem是最關鍵的,它加載圖像,然后調整其大小,然后將其轉換為張量。這里需要注意的一點是,提供給神經網絡的數據應該總是標準化的。我們使用transforms.ToTensor處理規(guī)范化。最后,getitem返回兩個結果,image作為張量,label作為對應的數據點。
在初始化類數據之后,我們使用DataLoader函數自動將整個數據批處理成一個定義的批大小。因此,如果你的原始數據點大小是(3,224,224)(你從__getitem__獲得),那么dataloader的每個項都將具有大小(batch_size,3,224,224),即它會自動對數據點的batch_size數進行采樣。
這在我們的例子中是可能的,因為圖像的大小是恒定的,所以DataLoader函數能夠自動創(chuàng)建批處理。然而,在自然語言處理這樣的情況下,當大小不是常數時,我們需要編寫自己的批處理函數。
序列數據集的數據加載器
現(xiàn)在讓我們來處理序列數據集,即句子、時間序列、音頻等。這里的__getitem__將不再提供相同大小的數據點。例如,考慮情緒分類的任務(在這里解釋),那么一句話可以是“The flight service was very good”,另一句話可以是“I did not get my baggage on the belt, pathetic service.”在這里,兩句話的長度是不同的。
為了解決這個問題,讓我們先回答三個問題。
什么是batch?-批處理是指將多個數據點的張量合并成一個張量
為什么我們需要分批處理?批處理可以用于加快計算速度,因為批處理可以同時處理多個數據點,而不是一次只處理一個數據點。
如何進行batch化?因為我們在這里合并多個張量,所以張量的每個維度的大小都需要相同。由于輸出的數據點大小不一,我們手中就有一個問題。
我們現(xiàn)在主要要解決batch化問題。
為了便于我們在這里討論,我們將使用IMDB數據集,它是一個評論數據集。因為我們在這里處理的是句子,所以處理數據集的方法會有所不同。
因為神經網絡只懂數字,不懂單詞,所以我們必須把每個單詞轉換成一個數字。為了做到這一點,我們必須構建一個詞匯表,如下代碼所述。
import?os import?gensim from?collections?import?Counter import?jsontrain_path?=?"./aclImdb/train" test_path?=?"./aclImdb/test"#?simple函數從目錄讀取數據并返回數據和標簽 #?你可以為其他數據集制作自己的讀取器。 def?reader(path):pos_path?=?os.path.join(path,?"pos")neg_path?=?os.path.join(path,?"neg")data?=?[]label?=?[]for?file?in?os.listdir(pos_path):f?=?open(os.path.join(pos_path,?file))data.append(f.read())label.append(1)for?file?in?os.listdir(neg_path):f?=?open(os.path.join(neg_path,?file))data.append(f.read())label.append(0)#?print(data[:1])return?data,?labeldef?build_vocab(data,?min_word_count?=?5):counter?=?Counter()for?line?in?data:l?=?gensim.utils.simple_preprocess(line)counter.update(l)#?初始化一個字典或查找表word2id?=?{}word2id['<pad>']?=?0word2id['<unk>']?=?1#?只包括那些在字典中出現(xiàn)超過min次的單詞。words?=?[word?for?word,?count?in?counter.items()?if?count>min_word_count]for?i,?word?in?enumerate(words):word2id[word]?=?i+2with?open("word2id.json",?'w')?as?f:json.dump(word2id,?f)return?word2iddata,?label?=?reader(train_path) word2id?=?build_vocab(data) print("Dictionary?Formed?and?saved.?The?length?of?dictionary?is-:?",?len(word2id))函數讀取器用于讀取整個數據,它返回所有句子的列表,標簽“0”表示消極評論,“1”表示積極評論。
函數build_vocab將數據和最小字數作為輸入,并將每個字的映射(稱為“word2id”)作為輸出,映射到一個唯一的數字。對于每個向前的未知單詞,對應的數字將是1。
繼續(xù)為序列數據集編寫數據集類。我們的目標是在給定索引的情況下,一次輸出一個item。
import?torch from?torch.utils.data?import?Dataset,?DataLoader import?numpy?as?np import?os import?gensimclass?Dataset_seq(Dataset):def?__init__(self,?word2id,?train_path):self.word2id?=?word2idself.train_path?=?train_path#?讀取數據和標簽self.data,?self.label?=?reader(train_path)def?__getitem__(self,?index):#?返回seq和標簽seq?=?self.preprocess(self.data[index])label?=?self.label[index]return?seq,?labeldef?__len__(self):return(len(self.data))def?preprocess(self,?text):#?用于將line轉換為token,然后使用word2id將其轉換為相應的數字值line?=?gensim.utils.simple_preprocess(text)seq?=?[]for?word?in?line:if?word?in?self.word2id:seq.append(self.word2id[word])else:seq.append(self.word2id['<unk>'])#?將list轉換成張量seq?=?torch.from_numpy(np.array(seq))return?seq由于上面已經討論了不同函數的功能,我將簡要地回顧一下。
函數__init__采用word2id映射和train路徑。然后,init調用reader獲取與句子對應的數據和標簽。
函數__len__ 返回整個數據集的長度,即self.data。
函數preprocess將輸入句子轉換成數字張量,其中每個數字對應于句子中的單詞。
函數getitem用于在索引的幫助下輸出一個經過處理的數據點。
下面的代碼定義了collate_fn。
train_dataset?=?Dataset_seq(word2id,?train_path) train_dataloader?=?DataLoader(dataset=train_dataset,?batch_size=batch_size,?shuffle=True,collate_fn=collate_fn)def?collate_fn(data):'''??我們應該構建一個自定義的collate_fn,而不是使用默認的collate_fn,因為每個句子的大小不同,并且默認不支持合并序列。Args:data:?元組列表?(training?sequence,?label)Return:padded_seq?-?填充序列,形狀?(batch_size,?padded_length)length?-?每個序列的原始長度(沒有填充),?形狀(batch_size)label?-?張量形狀?(batch_size)'''data.sort(key=lambda?x:?len(x[0]),?reverse=True)sequences,?label?=?zip(*data)length?=?[len(seq)?for?seq?in?sequences]padded_seq?=?torch.zeros(len(sequences),?max(length)).long()for?i,?seq?in?enumerate(sequences):end?=?length[i]padded_seq[i,:end]?=?seqreturn?padded_seq,?torch.from_numpy(np.array(length)),?torch.from_numpy(np.array(label))這里需要注意的一點是,在一個元組列表中,每個元組可以有不同的大小,但在張量中,所有維度的大小都必須相同才能合并它們。
collate_fn自動獲得一個名為data的輸入,這是一個長度等于batch size的元組列表。每個元組包含數字張量及其相應的標簽。
為了簡單起見,我們將它們分別稱為sequence和label。所以最終我們必須以這樣一種方式轉換每個序列,使它們的大小保持不變。
為了實現(xiàn)這一點,我們執(zhí)行零填充,如上面的代碼所示。由于對整個數據集統(tǒng)一使用零填充,因此模型了解到它沒有多大用處,它只是表示浪費值。
我們肯定已經找到了解決辦法,但問題是,這是一個最佳的解決辦法嗎?如果所有序列的原始大小都有很大的差異,或者換言之有很大的差異,那么我們最終會浪費大量的GPU內存,而這些內存是零填充的,這最終是沒有用的。必須有一個更好的方法來最小化零填充的要求!
這個問題的解決請關注后續(xù)文章!
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統(tǒng)計學習方法》的代碼復現(xiàn)專輯 AI基礎下載機器學習的數學基礎專輯黃海廣老師《機器學習課程》課件合集 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【深度学习】基于PyTorch深度学习框架的序列图像数据装载器的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 码云提交时报错git 报错 fatal:
- 下一篇: 【深度学习】深度学习手写代码汇总(建议收