【深度好文】多任务模型中的DataLoader实现
對于多任務學習multi-task-learning(MTL)問題,經常會要求特定的訓練過程,比如數據處理,模型結構和性能評估函數.本文主要針對數據處理部分進行展開,主要針對多個標注好的數據集如何來訓練一個多任務模型.
本文主要從兩個方面進行展開:
1.將兩個或多個dataset組合成pytorch中的一個Dataset.這個dataset將會作為pytorch中Dataloader的輸入.
2.修改batch產生過程,以確保在第一個batch中產生第一個任務的數據,在第二個batch中產生下一個任務的數據.
為了簡單處理,我們將以兩個dataset作為例子來講述.通常來說,dataset的數目以及data的類型不會對我們的整體方案帶來太大影響.一個pytorch的Dataset需要實現 __getitem__()函數.這個函數的作用為預取數據并且為給定index準備數據.
第一節 定義dataset
首先,我們先來定義兩個dummy dataset,如下所示:
import torch from torch.utils.data.dataset import ConcatDatasetclass MyFirstDataset(torch.utils.data.Dataset):def __init__(self):# dummy datasetself.samples = torch.cat((-torch.ones(5), torch.ones(5)))def __getitem__(self, index):# change this to your samples fetching logicreturn self.samples[index]def __len__(self):# change this to return number of samples in your datasetreturn self.samples.shape[0]class MySecondDataset(torch.utils.data.Dataset):def __init__(self):# dummy datasetself.samples = torch.cat((torch.ones(50) * 5, torch.ones(5) * -5))def __getitem__(self, index):# change this to your samples fetching logicreturn self.samples[index]def __len__(self):# change this to return number of samples in your datasetreturn self.samples.shape[0]first_dataset = MyFirstDataset() second_dataset = MySecondDataset() concat_dataset = ConcatDataset([first_dataset, second_dataset])上述代碼中,我們定義了兩個dataset,其中第一個dataset長度為10,其中前5個sample為-1,后5個sample為1;其中第二個dataset長度為55,其中前50個sample為5,后5個sample為-5.上述數據集僅僅為了說明方便.在實際應用中,我們應該會同時擁有sample和label,當然我們也可能會從一個目錄或者數據庫中讀取數據,但是上面簡單的dataset足夠幫助我們來了解整個實現流程.
第二節 定義dataloader
接著我們來定義Dataloader,這里我們使用pytorch中的concat_data來實現兩個dataset的合并.
代碼如下:
運行結果如下:
tensor([ 5., 5., 5., 5., -1., 5., 5., 5.]) tensor([ 5., 1., -1., -1., 5., 5., 5., -5.]) tensor([5., 5., 5., 5., 5., 5., 5., 5.]) tensor([ 5., -5., -5., 5., 5., 5., 5., 5.]) tensor([-1., 5., -1., 5., 5., 5., 5., 5.]) tensor([ 5., 5., -5., 5., 5., 5., 5., 1.]) tensor([5., 5., 5., 5., 1., 5., 5., 5.]) tensor([ 5., 1., 5., -5., 5., 5., 1., 5.])對于我們的concat_dataset來說,每個batch有8個sample.每個sample的次序是隨機的.
第三節 定義sampler
到現在為止,上述實現都很簡單直接.上述dataset被合并成一個dataset,并且sample都是從原先dataset中隨機挑選組合成batch的.現在讓我們來寫控制每個batch中的sample來源.我們預期達到的目的在每一個batch中,數據僅來自一個task的dataset,在下一個batch中進行切換.此時我們需要自己定義sample,其代碼實現如下:
import math import torch from torch.utils.data.sampler import RandomSamplerclass BatchSchedulerSampler(torch.utils.data.sampler.Sampler):"""iterate over tasks and provide a random batch per task in each mini-batch"""def __init__(self, dataset, batch_size):self.dataset = datasetself.batch_size = batch_sizeself.number_of_datasets = len(dataset.datasets)self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets])def __len__(self):return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)def __iter__(self):samplers_list = []sampler_iterators = []for dataset_idx in range(self.number_of_datasets):cur_dataset = self.dataset.datasets[dataset_idx]sampler = RandomSampler(cur_dataset)samplers_list.append(sampler)cur_sampler_iterator = sampler.__iter__()sampler_iterators.append(cur_sampler_iterator)push_index_val = [0] + self.dataset.cumulative_sizes[:-1]step = self.batch_size * self.number_of_datasetssamples_to_grab = self.batch_size# for this case we want to get all samples in dataset, this force us to resample from the smaller datasetsepoch_samples = self.largest_dataset_size * self.number_of_datasetsfinal_samples_list = [] # this is a list of indexes from the combined datasetfor _ in range(0, epoch_samples, step):for i in range(self.number_of_datasets):cur_batch_sampler = sampler_iterators[i]cur_samples = []for _ in range(samples_to_grab):try:cur_sample_org = cur_batch_sampler.__next__()cur_sample = cur_sample_org + push_index_val[i]cur_samples.append(cur_sample)except StopIteration:# got to the end of iterator - restart the iterator and continue to get samples# until reaching "epoch_samples"sampler_iterators[i] = samplers_list[i].__iter__()cur_batch_sampler = sampler_iterators[i]cur_sample_org = cur_batch_sampler.__next__()cur_sample = cur_sample_org + push_index_val[i]cur_samples.append(cur_sample)final_samples_list.extend(cur_samples)return iter(final_samples_list)上述定義了一個BatchSchedulerSampler類,實現了一個新的sampler iterator.首先,通過為每一個單獨的dataset創建RandomSampler;接著,在每一個dataset iter中獲取對應的sample index;最后,創建新的sample index list.這里我們使用batchsize=8,那么我們將會從每個dataset中預取8個samples.
接著我們來測試上述sampler,代碼如下:
運行結果如下:
tensor([ 1., -1., 1., 1., -1., -1., -1., 1.]) tensor([ 5., 5., 5., 5., 5., -5., 5., -5.]) tensor([ 1., -1., -1., -1., -1., 1., 1., 1.]) tensor([5., 5., 5., 5., 5., 5., 5., 5.]) tensor([ 1., 1., -1., -1., 1., 1., 1., 1.]) tensor([5., 5., 5., 5., 5., 5., 5., 5.]) tensor([-1., 1., -1., -1., -1., -1., 1., -1.]) tensor([-5., 5., 5., 5., 5., 5., 5., 5.]) tensor([ 1., -1., 1., -1., -1., 1., -1., 1.]) tensor([ 5., -5., 5., 5., 5., 5., 5., 5.]) tensor([-1., -1., 1., -1., 1., -1., -1., 1.]) tensor([ 5., 5., 5., -5., 5., 5., 5., 5.]) tensor([ 1., 1., -1., -1., 1., 1., 1., 1.]) tensor([5., 5., 5., 5., 5., 5., 5., 5.])Wow,綜上,我們實現了每一個minibatch僅從一個dataset中取數據的功能,并且下一個minibatch從不同任務的dataset中取batch.
參考:鏈接
關注公眾號《AI算法之道》,獲取更多AI算法資訊.
總結
以上是生活随笔為你收集整理的【深度好文】多任务模型中的DataLoader实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 空间数据挖掘与空间大数据的探索与思考(三
- 下一篇: 顺丰快递业务接入API总览-快递鸟