【深度学习】在PyTorch中构建高效的自定义数据集
文章來源于磐創AI,作者磐創AI
學習Dataset類的來龍去脈,使用干凈的代碼結構,同時最大限度地減少在訓練期間管理大量數據的麻煩。
神經網絡訓練在數據管理上可能很難做到“大規模”。
PyTorch 最近已經出現在我的圈子里,盡管對Keras和TensorFlow感到滿意,但我還是不得不嘗試一下。令人驚訝的是,我發現它非常令人耳目一新,非常討人喜歡,尤其是PyTorch 提供了一個Pythonic API、一個更為固執己見的編程模式和一組很好的內置實用程序函數。我特別喜歡的一項功能是能夠輕松地創建一個自定義的Dataset對象,然后可以與內置的DataLoader一起在訓練模型時提供數據。
在本文中,我將從頭開始研究PyTorchDataset對象,其目的是創建一個用于處理文本文件的數據集,以及探索如何為特定任務優化管道。我們首先通過一個簡單示例來了解Dataset實用程序的基礎知識,然后逐步完成實際任務。具體地說,我們想創建一個管道,從The Elder Scrolls(TES)系列中獲取名稱,這些名稱的種族和性別屬性作為一個one-hot張量。你可以在我的網站(http://syaffers.xyz/#datasets)上找到這個數據集。
Dataset類的基礎知識
Pythorch允許您自由地對“Dataset”類執行任何操作,只要您重寫兩個子類函數:
-返回數據集大小的函數,以及
-函數的函數從給定索引的數據集中返回一個樣本。
數據集的大小有時可能是灰色區域,但它等于整個數據集中的樣本數。因此,如果數據集中有10000個單詞(或數據點、圖像、句子等),則函數“uuLen_uUu”應該返回10000個。
PyTorch使您可以自由地對Dataset類執行任何操作,只要您重寫改類中的兩個函數即可:
__len__ 函數:返回數據集大小
__getitem__ 函數:返回對應索引的數據集中的樣本
數據集的大小有時難以確定,但它等于整個數據集中的樣本數量。因此,如果您的數據集中有10,000個樣本(數據點,圖像,句子等),則__len__函數應返回10,000。
一個小示例
首先,創建一個從1到1000所有數字的Dataset來模擬一個簡單的數據集。我們將其適當地命名為NumbersDataset。
from?torch.utils.data?import?Datasetclass?NumbersDataset(Dataset):def?__init__(self):self.samples?=?list(range(1,?1001))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):return?self.samples[idx]if?__name__?==?'__main__':dataset?=?NumbersDataset()print(len(dataset))print(dataset[100])print(dataset[122:361])
很簡單,對吧?首先,當我們初始化NumbersDataset時,我們立即創建一個名為samples的列表,該列表將存儲1到1000之間的所有數字。列表的名稱是任意的,因此請隨意使用您喜歡的名稱。需要重寫的函數是不用我說明的(我希望!),并且對在構造函數中創建的列表進行操作。如果運行該python文件,將看到1000、101和122到361之間的值,它們分別指的是數據集的長度,數據集中索引為100的數據以及索引為121到361之間的數據集切片。
擴展數據集
讓我們擴展此數據集,以便它可以存儲low和high之間的所有整數。
from?torch.utils.data?import?Datasetclass?NumbersDataset(Dataset):def?__init__(self,?low,?high):self.samples?=?list(range(low,?high))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):return?self.samples[idx]if?__name__?==?'__main__':dataset?=?NumbersDataset(2821,?8295)print(len(dataset))print(dataset[100])print(dataset[122:361])運行上面代碼應在控制臺打印5474、2921和2943到3181之間的數字。通過編寫構造函數,我們現在可以將數據集的low和high設置為我們的想要的內容。這個簡單的更改顯示了我們可以從PyTorch的Dataset類獲得的各種好處。例如,我們可以生成多個不同的數據集并使用這些值,而不必像在NumPy中那樣,考慮編寫新的類或創建許多難以理解的矩陣。
從文件讀取數據
讓我們來進一步擴展Dataset類的功能。PyTorch與Python標準庫的接口設計得非常優美,這意味著您不必擔心集成功能。在這里,我們將
創建一個全新的使用Python I/O和一些靜態文件的Dataset類
收集TES角色名稱(我的網站上(http://syaffers.xyz/#datasets)有可用的數據集),這些角色名稱分為種族文件夾和性別文件,以填充samples列表
通過在samples列表中存儲一個元組而不只是名稱本身來跟蹤每個名稱的種族和性別。
TES名稱數據集具有以下目錄結構:
. |--?Altmer/ |???|--?Female |???`--?Male |--?Argonian/ |???|--?Female |???`--?Male ...?(truncated?for?brevity)(為了簡潔,這里進行省略) `--?Redguard/|--?Female`--?Male每個文件都包含用換行符分隔的TES名稱,因此我們必須逐行讀取每個文件,以捕獲每個種族和性別的所有字符名稱。
import?os from?torch.utils.data?import?Datasetclass?TESNamesDataset(Dataset):def?__init__(self,?data_root):self.samples?=?[]for?race?in?os.listdir(data_root):race_folder?=?os.path.join(data_root,?race)for?gender?in?os.listdir(race_folder):gender_filepath?=?os.path.join(race_folder,?gender)with?open(gender_filepath,?'r')?as?gender_file:for?name?in?gender_file.read().splitlines():self.samples.append((race,?gender,?name))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):return?self.samples[idx]if?__name__?==?'__main__':dataset?=?TESNamesDataset('/home/syafiq/Data/tes-names/')print(len(dataset))print(dataset[420])我們來看一下代碼:首先創建一個空的samples列表,然后遍歷每個種族(race)文件夾和性別文件并讀取每個文件中的名稱來填充該列表。然后將種族,性別和名稱存儲在元組中,并將其添加到samples列表中。運行該文件應打印19491和('Bosmer', 'Female', 'Gluineth')(每臺計算機的輸出可能不太一樣)。讓我們看一下將數據集的一個batch的樣子:
#?將main函數改成下面這樣: if?__name__?==?'__main__':dataset?=?TESNamesDataset('/home/syafiq/Data/tes-names/')print(dataset[10:60])
正如您所想的,它的工作原理與列表完全相同。對本節內容進行總結,我們剛剛將標準的Python I/O 引入了PyTorch數據集中,并且我們不需要任何其他特殊的包裝器或幫助器,只需要單純的Python代碼。實際上,我們還可以包括NumPy或Pandas之類的其他庫,并且通過一些巧妙的操作,使它們在PyTorch中發揮良好的作用。讓我們現在來看看在訓練時如何有效地遍歷數據集。
用DataLoader加載數據
盡管Dataset類是創建數據集的一種不錯的方法,但似乎在訓練時,我們將需要對數據集的samples列表進行索引或切片。這并不比我們對列表或NumPy矩陣進行操作更簡單。PyTorch并沒有沿這條路走,而是提供了另一個實用工具類DataLoader。DataLoader充當Dataset對象的數據饋送器(feeder)。如果您熟悉的話,這個對象跟Keras中的flow數據生成器函數很類似。DataLoader需要一個Dataset對象(它延伸任何子類)和其他一些可選參數(參數都列在PyTorch的DataLoader文檔(https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)中)。在這些參數中,我們可以選擇對數據進行打亂,確定batch的大小和并行加載數據的線程(job)數量。這是TESNamesDataset在循環中進行調用的一個簡單示例。
#?將main函數改成下面這樣: if?__name__?==?'__main__':from?torch.utils.data?import?DataLoaderdataset?=?TESNamesDataset('/home/syafiq/Data/tes-names/')dataloader?=?DataLoader(dataset,?batch_size=50,?shuffle=True,?num_workers=2)for?i,?batch?in?enumerate(dataloader):print(i,?batch)當您看到大量的batch被打印出來時,您可能會注意到每個batch都是三元組的列表:第一個元組包含種族,下一個元組包含性別,最后一個元祖包含名稱。
等等,那不是我們之前對數據集進行切片時的樣子!這里到底發生了什么?好吧,事實證明,DataLoader以系統的方式加載數據,以便我們垂直而非水平來堆疊數據。這對于一個batch的張量(tensor)流動特別有用,因為張量垂直堆疊(即在第一維上)構成batch。此外,DataLoader還會為對數據進行重新排列,因此在發送(feed)數據時無需重新排列矩陣或跟蹤索引。
張量(tensor)和其他類型
為了進一步探索不同類型的數據在DataLoader中是如何加載的,我們將更新我們先前模擬的數字數據集,以產生兩對張量數據:數據集中每個數字的后4個數字的張量,以及加入一些隨機噪音的張量。為了拋出DataLoader的曲線球,我們還希望返回數字本身,而不是張量類型,是作為Python字符串返回。__getitem__函數將在一個元組中返回三個異構數據項。
from?torch.utils.data?import?Dataset import?torchclass?NumbersDataset(Dataset):def?__init__(self,?low,?high):self.samples?=?list(range(low,?high))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):n?=?self.samples[idx]successors?=?torch.arange(4).float()?+?n?+?1noisy?=?torch.randn(4)?+?successorsreturn?n,?successors,?noisyif?__name__?==?'__main__':from?torch.utils.data?import?DataLoaderdataset?=?NumbersDataset(100,?120)dataloader?=?DataLoader(dataset,?batch_size=10,?shuffle=True)print(next(iter(dataloader)))請注意,我們沒有更改數據集的構造函數,而是修改了__getitem__函數。對于PyTorch數據集來說,比較好的做法是,因為該數據集將隨著樣本越來越多而進行縮放,因此我們不想在Dataset對象運行時,在內存中存儲太多張量類型的數據。取而代之的是,當我們遍歷樣本列表時,我們將希望它是張量類型,以犧牲一些速度來節省內存。在以下各節中,我將解釋它的用處。
觀察上面的輸出,盡管我們新的__getitem__函數返回了一個巨大的字符串和張量元組,但是DataLoader能夠識別數據并進行相應的堆疊。字符串化后的數字形成元組,其大小與創建DataLoader時配置的batch大小的相同。對于兩個張量,DataLoader將它們垂直堆疊成一個大小為10x4的張量。這是因為我們將batch大小配置為10,并且在__getitem__函數返回兩個大小為4的張量。
通常來說,DataLoader嘗試將一批一維張量堆疊為二維張量,將一批二維張量堆疊為三維張量,依此類推。在這一點上,我懇請您注意到這對其他機器學習庫中的傳統數據處理產生了翻天覆地的影響,以及這個做法是多么優雅。太不可思議了!如果您不同意我的觀點,那么至少您現在知道有這樣的一種方法。
完成TES數據集的代碼
讓我們回到TES數據集。似乎初始化函數的代碼有點不優雅(至少對于我而言,確實應該有一種使代碼看起來更好的方法。請記住我說過的,PyTorch API是像python的(Pythonic)嗎?數據集中的工具函數,甚至對內部函數進行初始化。為清理TES數據集的代碼,我們將更新TESNamesDataset的代碼來實現以下目的:
更新構造函數以包含字符集
創建一個內部函數來初始化數據集
創建一個將標量轉換為獨熱(one-hot)張量的工具函數
創建一個工具函數,該函數將樣本數據轉換為種族,性別和名稱的三個獨熱(one-hot)張量的集合。
為了使工具函數正常工作,我們將借助scikit-learn庫對數值(即種族,性別和名稱數據)進行編碼。具體來說,我們將需要LabelEncoder類。我們對代碼進行大量的更新,我將在接下來的幾小節中解釋這些修改的代碼。
import?os from?sklearn.preprocessing?import?LabelEncoder from?torch.utils.data?import?Dataset import?torchclass?TESNamesDataset(Dataset):def?__init__(self,?data_root,?charset):self.data_root?=?data_rootself.charset?=?charsetself.samples?=?[]self.race_codec?=?LabelEncoder()self.gender_codec?=?LabelEncoder()self.char_codec?=?LabelEncoder()self._init_dataset()def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):race,?gender,?name?=?self.samples[idx]return?self.one_hot_sample(race,?gender,?name)def?_init_dataset(self):races?=?set()genders?=?set()for?race?in?os.listdir(self.data_root):race_folder?=?os.path.join(self.data_root,?race)races.add(race)for?gender?in?os.listdir(race_folder):gender_filepath?=?os.path.join(race_folder,?gender)genders.add(gender)with?open(gender_filepath,?'r')?as?gender_file:for?name?in?gender_file.read().splitlines():self.samples.append((race,?gender,?name))self.race_codec.fit(list(races))self.gender_codec.fit(list(genders))self.char_codec.fit(list(self.charset))def?to_one_hot(self,?codec,?values):value_idxs?=?codec.transform(values)return?torch.eye(len(codec.classes_))[value_idxs]def?one_hot_sample(self,?race,?gender,?name):t_race?=?self.to_one_hot(self.race_codec,?[race])t_gender?=?self.to_one_hot(self.gender_codec,?[gender])t_name?=?self.to_one_hot(self.char_codec,?list(name))return?t_race,?t_gender,?t_nameif?__name__?==?'__main__':import?stringdata_root?=?'/home/syafiq/Data/tes-names/'charset?=?string.ascii_letters?+?"-'?"dataset?=?TESNamesDataset(data_root,?charset)print(len(dataset))print(dataset[420])修改的構造函數初始化
構造函數這里有很多變化,所以讓我們一點一點地來解釋它。您可能已經注意到構造函數中沒有任何文件處理邏輯。我們已將此邏輯移至_init_dataset函數中,并清理了構造函數。此外,我們添加了一些編碼器,來將原始字符串轉換為整數并返回。samples列表也是一個空列表,我們將在_init_dataset函數中填充該列表。構造函數還接受一個新的參數charset。顧名思義,它只是一個字符串,可以將char_codec轉換為整數。
已增強了文件處理功能,該功能可以在我們遍歷文件夾時捕獲種族和性別的唯一標簽。如果您沒有結構良好的數據集,這將很有用;例如,如果Argonians擁有一個與性別無關的名稱,我們將擁有一個名為“Unknown”的文件,并將其放入性別集合中,而不管其他種族是否存在“Unknown”性別。所有名稱存儲完畢后,我們將在由種族,性別和名稱構成數據集來初始化編碼器。
工具函數
我們添加了兩個工具函數:to_one_hot和one_hot_sample。to_one_hot使用數據集的內部編碼器將數值列表轉換為整數列表,然后再調用看似不適當的torch.eye函數。實際上,這是一種巧妙的技巧,可以將整數列表快速轉換為一個向量。torch.eye函數創建一個任意大小的單位矩陣,其對角線上的值為1。如果對矩陣行進行索引,則將在該索引處獲得值為1的行向量,這是獨熱向量的定義!
因為我們需要將三個數據轉換為張量,所以我們將在對應數據的每個編碼器上調用to_one_hot函數。one_hot_sample將單個樣本數據轉換為張量元組。種族和性別被轉換為二維張量,這實際上是擴展的行向量。該向量也被轉換為二維張量,但該二維向量包含該名稱的每個字符每個獨熱向量。
__getitem__調用
最后,__getitem__函數的代碼已更新為僅在one_hot_sample給定種族,性別和名稱的情況下調用該函數。注意,我們不需要在samples列表中預先準備張量,而是僅在調用__getitem__函數(即DataLoader加載數據流時)時形成張量。當您在訓練期間有成千上萬的樣本要加載時,這使數據集具有很好的可伸縮性。
您可以想象如何在計算機視覺訓練場景中使用該數據集。數據集將具有文件名列表和圖像目錄的路徑,從而讓__getitem__函數僅讀取圖像文件并將它們及時轉換為張量來進行訓練。通過提供適當數量的工作線程,DataLoader可以并行處理多個圖像文件,可以使其運行得更快。PyTorch數據加載教程(https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)有更詳細的圖像數據集,加載器,和互補數據集。這些都是由torchvision庫進行封裝的(它經常隨著PyTorch一起安裝)。torchvision用于計算機視覺,使得圖像處理管道(例如增白,歸一化,隨機移位等)很容易構建。
回到原文。數據集已經構建好了,看來我們已準備好使用它進行訓練……
……但我們還沒有
如果我們嘗試使用DataLoader來加載batch大小大于1的數據,則會遇到錯誤:
您可能已經看到過這種情況,但現實是,文本數據的不同樣本之間很少有相同的長度。結果,DataLoader嘗試批量處理多個不同長度的名稱張量,這在張量格式中是不可能的,因為在NumPy數組中也是如此。為了說明此問題,請考慮以下情況:當我們將“ John”和“ Steven”之類的名稱堆疊在一起形成一個單一的獨熱矩陣時。'John'轉換為大小4xC的二維張量,'Steven'轉換為大小6xC二維張量,其中C是字符集的長度。DataLoader嘗試將這些名稱堆疊為大小2x?xC三維張量(DataLoader認為堆積大小為1x4xC和1x6xC)。由于第二維不匹配,DataLoader拋出錯誤,導致它無法繼續運行。
可能的解決方案
為了解決這個問題,這里有兩種方法,每種方法都各有利弊。
將批處理(batch)大小設置為1,這樣您就永遠不會遇到錯誤。如果批處理大小為1,則單個張量不會與(可能)不同長度的其他任何張量堆疊在一起。但是,這種方法在進行訓練時會受到影響,因為神經網絡在單批次(batch)的梯度下降時收斂將非常慢。另一方面,當批次大小不重要時,這對于快速測試時,數據加載或沙盒測試很有用。
通過使用空字符填充或截斷名稱來獲得固定的長度。截短長的名稱或用空字符來填充短的名稱可以使所有名稱格式正確,并具有相同的輸出張量大小,從而可以進行批處理。不利的一面是,根據任務的不同,空字符可能是有害的,因為它不能代表原始數據。
由于本文的目的,我將選擇第二個方法,您只需對整體數據管道進行很少的更改即可實現此目的。請注意,這也適用于任何長度不同的字符數據(盡管有多種填充數據的方法,請參見NumPy(https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.pad.html)和PyTorch(https://pytorch.org/docs/stable/_modules/torch/nn/modules/padding.html)中的選項部分)。在我的例子中,我選擇用零來填充名稱,因此我更新了構造函數和_init_dataset函數:
...?def?__init__(self,?data_root,?charset,?length):self.data_root?=?data_rootself.charset?=?charset?+?'\0'self.length?=?length...with?open(gender_filepath,?'r')?as?gender_file:for?name?in?gender_file.read().splitlines():if?len(name)?<?self.length:name?+=?'\0'?*?(self.length?-?len(name))else:name?=?name[:self.length-1]?+?'\0'self.samples.append((race,?gender,?name))...首先,我在構造函數引入一個新的參數,該參數將所有傳入名稱字符固定為length值。我還將\0字符添加到字符集中,用于填充短的名稱。接下來,數據集初始化邏輯已更新。缺少長度的名稱僅用\0填充,直到滿足長度的要求為止。超過固定長度的名稱將被截斷,最后一個字符將被替換為\0。替換是可選的,這取決于具體的任務。
而且,如果您現在嘗試加載此數據集,您應該獲得跟您當初所期望的數據:正確的批(batch)大小格式的張量。下圖顯示了批大小為2的張量,但請注意有三個張量:
堆疊種族張量,獨熱編碼形式表示該張量是十個種族中的某一個種族
堆疊性別張量,獨熱編碼形式表示數據集中存在兩種性別中的某一種性別
堆疊名稱張量,最后一個維度應該是charset的長度,第二個維度是名稱長度(固定大小后),第一個維度是批(batch)大小。
數據拆分實用程序
所有這些功能都內置在PyTorch中,真是太棒了。現在可能出現的問題是,如何制作驗證甚至測試集,以及如何在不擾亂代碼庫并盡可能保持DRY的情況下執行驗證或測試。測試集的一種方法是為訓練數據和測試數據提供不同的data_root,并在運行時保留兩個數據集變量(另外還有兩個數據加載器),尤其是在訓練后立即進行測試的情況下。
如果您想從訓練集中創建驗證集,那么可以使用PyTorch數據實用程序中的random_split 函數輕松處理這一問題。random_split 函數接受一個數據集和一個劃分子集大小的列表,該函數隨機拆分數據,以生成更小的Dataset對象,這些對象可立即與DataLoader一起使用。這里有一個例子。
通過使用內置函數輕松拆分自定義PyTorch數據集來創建驗證集。
事實上,您可以在任意間隔進行拆分,這對于折疊交叉驗證集非常有用。我對這個方法唯一的不滿是你不能定義百分比分割,這很煩人。至少子數據集的大小從一開始就明確定義了。另外,請注意,每個數據集都需要單獨的DataLoader,這絕對比在循環中管理兩個隨機排序的數據集和索引更干凈。
結束語
希望本文能使您了解PyTorch中Dataset和DataLoader實用程序的功能。與干凈的Pythonic API結合使用,它可以使編碼變得更加輕松愉快,同時提供一種有效的數據處理方式。我認為PyTorch開發的易用性根深蒂固于他們的開發理念,并且在我的工作中使用PyTorch之后,我從此不再回頭使用Keras和TensorFlow。我不得不說我確實錯過了Keras模型隨附的進度條和fit /predict API,但這是一個小小的挫折,因為最新的帶TensorBoard接口的PyTorch帶回了熟悉的工作環境。盡管如此,目前,PyTorch是我將來的深度學習項目的首選。
我鼓勵以這種方式構建自己的數據集,因為它消除了我以前管理數據時遇到的許多凌亂的編程習慣。在復雜情況下,Dataset 是一個救命稻草。我記得必須管理屬于一個樣本的數據,但該數據必須來自三個不同的MATLAB矩陣文件,并且需要正確切片,規范化和轉置。如果沒有Dataset和DataLoader組合,我不知如何進行管理,特別是因為數據量巨大,而且沒有簡便的方法將所有數據組合成NumPy矩陣且不會導致計算機崩潰。
最后,查看PyTorch數據實用程序文檔頁面(https://pytorch.org/docs/stable/data.html) ,其中包含其他類別和功能,這是一個很小但有價值的實用程序庫。您可以在我的GitHub上找到TES數據集的代碼,在該代碼中,我創建了與數據集同步的PyTorch中的LSTM名稱預測變量(https://github.com/syaffers/tes-names-rnn)。讓我知道這篇文章是有用的還是不清楚的,以及您將來是否希望獲得更多此類內容。
原文鏈接:https://towardsdatascience.com/building-efficient-custom-datasets-in-pytorch-2563b946fd9f
- End -
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯獲取一折本站知識星球優惠券,復制鏈接直接打開:https://t.zsxq.com/yFQV7am本站qq群1003271085。加入微信群請掃碼進群:總結
以上是生活随笔為你收集整理的【深度学习】在PyTorch中构建高效的自定义数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【深度学习】常见优化器的PyTorch实
- 下一篇: 【论文解读】让特征感受野更灵活,腾讯优图