Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社區(qū)發(fā)布了一個(gè)深度學(xué)習(xí)工具包 PyTorchHub, 幫助機(jī)器學(xué)習(xí)工作者更快實(shí)現(xiàn)重要論文的復(fù)現(xiàn)工作。PyTorchHub 由一個(gè)預(yù)訓(xùn)練模型倉庫組成,專門用于提高研究工作的復(fù)現(xiàn)性以及新的研究。同時(shí)它還內(nèi)置了對(duì)Google Colab的支持,并與Papers With Code集成。目前 PyTorchHub 包括了一系列與圖像分類、分割、生成以及轉(zhuǎn)換相關(guān)的模型。
可復(fù)現(xiàn)性是許多研究領(lǐng)域的基本要求,這其中當(dāng)然包括基于機(jī)器學(xué)習(xí)技術(shù)的研究領(lǐng)域。然而, 許多機(jī)器學(xué)習(xí)相關(guān)論文要么無法復(fù)現(xiàn),要么難以重現(xiàn)。隨著論文數(shù)量的持續(xù)增長(zhǎng),包括目前在 arXiv 上預(yù)印刷的數(shù)萬份論文以及提交給會(huì)議的論文,研究工作的可復(fù)現(xiàn)性變得越來越重要。雖然其中許多論文都附有代碼以及訓(xùn)練好的模型,但這種幫助顯然非常有限,復(fù)現(xiàn)過程中仍有大量需要讀者自己摸索的步驟。下面讓我們來看一下如何通過 PyTorch Hub 這一利器完成快速的模型發(fā)布與工作復(fù)現(xiàn)。
如何快速發(fā)布模型
這部分主要介紹了對(duì)于模型發(fā)布者來說如何快速高效的將自己的模型加入 PyTorch Hub 庫。PyTorch Hub 支持通過添加簡(jiǎn)單的 hubconf.py 文件將預(yù)先訓(xùn)練的模型(模型定義和預(yù)先訓(xùn)練重)發(fā)布到 GitHub 存儲(chǔ)庫。這提供了模型列表以及其依賴庫列表。一些示例可以在torchvision,huggingface-bert和gan-model-zoo存儲(chǔ)庫中找到。
Pytoch 社區(qū)給出了 torchvision 的 hubconf.py 文件的示例:
復(fù)制代碼| ? | # Optional list of dependencies required by the package |
| ? | dependencies = ['torch'] |
| ? | ? |
| ? | from torchvision.models.alexnet import alexnet |
| ? | from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 |
| ? | from torchvision.models.inception import inception_v3 |
| ? | from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d |
| ? | from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 |
| ? | from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn |
| ? | from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101 |
| ? | from torchvision.models.googlenet import googlenet |
| ? | from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 |
| ? | from torchvision.models.mobilenet import mobilenet_v2 |
在 torchvision 中,模型有以下特性:
- 每個(gè)模型文件可以被獨(dú)立執(zhí)行或?qū)崿F(xiàn)某個(gè)功能
- 不需要除了 PyTorch 之外的任何軟件包(在 hubconf.py 中編碼為 dependencies[‘torch’])
- 他們不需要單獨(dú)的入口點(diǎn),因?yàn)槟P驮趧?chuàng)建時(shí)可以無縫地開箱即用。
PyTroch 社區(qū)認(rèn)為最小化包依賴性可減少用戶加載模型時(shí)遇到的困難。這里他們給出了一個(gè)更為復(fù)雜的例子——HuggingFace’s BERT 模型,它的 hubconf.py 如下:
復(fù)制代碼| ? | dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex'] |
| ? | ? |
| ? | from hubconfs.bert_hubconf import ( |
| ? | bertTokenizer, |
| ? | bertModel, |
| ? | bertForNextSentencePrediction, |
| ? | bertForPreTraining, |
| ? | bertForMaskedLM, |
| ? | bertForSequenceClassification, |
| ? | bertForMultipleChoice, |
| ? | bertForQuestionAnswering, |
| ? | bertForTokenClassification |
| ? | ) |
此外,對(duì)于每個(gè)模型,PyTorch 官方提到都需要為其創(chuàng)建一個(gè)入口點(diǎn)。下面是一個(gè)用于指定 bertForMaskedLM 模型的入口點(diǎn)的代碼片段,這部分代碼完成的功能是返回加載了預(yù)訓(xùn)練參數(shù)的模型。
復(fù)制代碼| ? | def bertForMaskedLM(*args, **kwargs): |
| ? | """ |
| ? | BertForMaskedLM includes the BertModel Transformer followed by the |
| ? | pre-trained masked language modeling head. |
| ? | Example: |
| ? | ... |
| ? | """ |
| ? | model = BertForMaskedLM.from_pretrained(*args, **kwargs) |
| ? | return model |
這些入口點(diǎn)可以看成是復(fù)雜的模型結(jié)構(gòu)的一種封裝形式。它們可以在提供簡(jiǎn)潔高效的幫助文檔的同時(shí)完成下載預(yù)訓(xùn)練權(quán)重的功能(例如,通過 pretrained = True),也可以集成其他特定功能,例如可視化。
通過 hubconf.py,模型發(fā)布者可以在 Github 上基于template提交他們的合并請(qǐng)求。PyTorch 社區(qū)希望通過 PyTorch Hub 創(chuàng)建一系列高質(zhì)量、易復(fù)現(xiàn)且效果好的模型以提高研究工作的復(fù)現(xiàn)性。因此,PyTorch 會(huì)通過與模型發(fā)布者合作的方式以完善請(qǐng)求,并有可能會(huì)在某些情況下拒絕發(fā)布一些低質(zhì)量的模型。一旦 PyTorch 社區(qū)接受了模型發(fā)布者的請(qǐng)求,這些新的模型將會(huì)很快出現(xiàn)在 PyTorch Hub 的網(wǎng)頁上以供用戶瀏覽。
用戶工作流
對(duì)于想使用 PyTorch Hub 對(duì)別人的工作進(jìn)行復(fù)現(xiàn)的用戶,PyTorch Hub 提供了以下幾個(gè)步驟:1)瀏覽可用的模型;2)加載模型;3)探索已加載的模型。下面讓我們來瀏覽幾個(gè)例子。
瀏覽可用的入口點(diǎn)
用戶可以使用 torch.hub.list() API 列出倉庫中的所有可用入口點(diǎn)。
復(fù)制代碼| ? | >>> torch.hub.list('pytorch/vision') |
| ? | >>> |
| ? | ['alexnet', |
| ? | 'deeplabv3_resnet101', |
| ? | 'densenet121', |
| ? | ... |
| ? | 'vgg16', |
| ? | 'vgg16_bn', |
| ? | 'vgg19', |
| ? | 'vgg19_bn'] |
注意,PyTorch Hub 還允許輔助入口點(diǎn)(除了預(yù)訓(xùn)練模型),例如,用于 BERT 模型預(yù)處理的 bertTokenizer,它可以使用戶工作流程更加順暢。
加載模型
對(duì)于 PyTroch Hub 中可用的模型,用戶可以使用 torch.hub.load() API 加載模型入口點(diǎn)。此外,torch.hub.help() API 可以提供有關(guān)如何實(shí)例化模型的有用信息。
復(fù)制代碼| ? | print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101')) |
| ? | model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True) |
由于倉庫的持有者會(huì)不斷添加錯(cuò)誤修復(fù)以及性能改進(jìn),PyTorch Hub 允許用戶通過調(diào)用以下內(nèi)容簡(jiǎn)單地獲取最新更新:
復(fù)制代碼| ? | model = torch.hub.load(..., force_reload=True) |
這一舉措可以有效地減輕倉庫持有者重復(fù)發(fā)布模型的負(fù)擔(dān),從而使他們能夠更專注于自己的研究工作。同時(shí),也確保了用戶可以獲得最新版本的模型。
此外,對(duì)于用戶來說,穩(wěn)定性也是一個(gè)重要問題。因此,某些模型所有者會(huì)從特征的分支或標(biāo)簽為他們提供服務(wù),以確保代碼的穩(wěn)定性。例如,pytorch_GAN_zoo 會(huì)從 hub 分支為他們提供服務(wù):
復(fù)制代碼| ? | model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False) |
這里,傳遞給 hub.load() 的 * args,** kwargs 用于實(shí)例化模型。在上面的示例中,pretrained = True 和 useGPU = False 被傳遞給模型的入口點(diǎn)。
探索已加載的模型
從 PyTorch Hub 加載模型后,用戶可以使用以下工作流查看已加載模型的可用方法,并更好地了解運(yùn)行它所需的參數(shù)。
其中,dir(model) 可以查看模型中可用的方法。下面是 bertForMaskedLM 的一些方法:
復(fù)制代碼| ? | >>> dir(model) |
| ? | >>> |
| ? | ['forward' |
| ? | ... |
| ? | 'to' |
| ? | 'state_dict', |
| ? | ] |
help(model.forward)則會(huì)提供使已加載的模型運(yùn)行時(shí)所需參數(shù)的視圖:
復(fù)制代碼| ? | >>> help(model.forward) |
| ? | >>> |
| ? | Help on method forward in module pytorch_pretrained_bert.modeling: |
| ? | forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None) |
| ? | ... |
更多細(xì)節(jié)可以查看BERT和DeepLabV3頁面:
- BERT:https://pytorch.org/hub/huggingface_pytorch-pretrained-bert_bert/
- DeepLabV3:https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/
其他探索方式與相關(guān)資源
PyTorch Hub 中提供的模型也支持 Colab,并且會(huì)直接鏈接在 Papers With Code 上,用戶只需單擊鏈接即可開始使用:
PyTorch 提供了一些相關(guān)資源幫助用戶快速上手 PyTorch Hub:
- PyTorch Hub API 手冊(cè)鏈接:https://pytorch.org/docs/stable/hub.html
- 模型提交地址:https://github.com/pytorch/hub
- 瀏覽 PyTorch Hub 網(wǎng)頁以學(xué)習(xí)更多可用模型:https://pytorch.org/hub
- 在 Paper with Code 上瀏覽更多模型:https://paperswithcode.com/
FAQ
問:如果我們想貢獻(xiàn)一個(gè) Hub 中已經(jīng)有了的模型,但也許我的模型具有更高的準(zhǔn)確性,我還應(yīng)該貢獻(xiàn)嗎?
答:是的,請(qǐng)?zhí)峤荒哪P?#xff0c;Hub 的下一步是開發(fā)投票系統(tǒng)以展示最佳模型。
問:誰負(fù)責(zé)保管 PyTorch Hub 的模型權(quán)重?
答:作為貢獻(xiàn)者,您負(fù)責(zé)保管模型權(quán)重。您可以在您喜歡的云存儲(chǔ)中托管您的模型,或者如果它符合限制,則可以在 GitHub 上托管您的模型。 如果您無法保管權(quán)重,請(qǐng)通過 Hub 倉庫中提交問題的方式與我們聯(lián)系。
問:如果我的模型使用了私有化數(shù)據(jù)進(jìn)行訓(xùn)練怎么辦?我還應(yīng)該貢獻(xiàn)這個(gè)模型嗎?
答:請(qǐng)不要提交您的模型!PyTorch Hub 以開源研究為中心,并擴(kuò)展到使用公開數(shù)據(jù)集來訓(xùn)練這些模型。如果提交了私有模型的合并請(qǐng)求,我們將懇請(qǐng)您重新提交使用公開數(shù)據(jù)進(jìn)行訓(xùn)練后的模型。
問:我下載的模型保存在哪里?
答:我們遵循 XDG 基本目錄規(guī)范,并遵循緩存文件和目錄的通用標(biāo)準(zhǔn)。這些位置按以下順序使用:
- 調(diào)用 hub.set_dir(<PATH_TO_HUB_DIR>)
- 如果環(huán)境變量了 TORCH_HOME,則為 $TORCH_HOME/hub。
- 如果設(shè)置了環(huán)境變量 XDG_CACHE_HOME,則為 $ XDG_CACHE_HOME / torch / hub。
- ~/.cache/torch/hub
相關(guān)推薦:
-
- Mxnet 計(jì)算機(jī)視覺深度學(xué)習(xí)工具包 GluonCV
- Tensorflow 復(fù)現(xiàn)工具
總結(jié)
以上是生活随笔為你收集整理的Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Transformer的PyTorch实
- 下一篇: 独家 | TensorFlow 2.0将