strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur
strict=False 但還是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur
問題
我們知道通過
model.load_state_dict(state_dict, strict=False)可以暫且忽略掉模型和參數(shù)文件中不匹配的參數(shù),先將正常匹配的參數(shù)從文件中載入模型。
筆者在使用時(shí)遇到了這樣一個(gè)報(bào)錯(cuò):
RuntimeError: Error(s) in loading state_dict for ViT_Aes:size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for mlp_head.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).一開始筆者很奇怪,我已經(jīng)寫明strict=False了,不匹配參數(shù)的不管就是了,為什么還要給我報(bào)錯(cuò)。
原因及解決方案
經(jīng)過筆者仔細(xì)打印模型的鍵和文件中的鍵進(jìn)行比對,發(fā)現(xiàn)是這樣的:strict=False可以保證模型中的鍵與文件中的鍵不匹配時(shí)暫且跳過不管,但是一旦模型中的鍵和文件中的鍵匹配上了,PyTorch就會嘗試幫我們加載參數(shù),就必須要求參數(shù)的尺寸相同,所以會有上述報(bào)錯(cuò)。
比如在我們需要將某個(gè)預(yù)訓(xùn)練的模型的最后的全連接層的輸出的類別數(shù)替換為我們自己的數(shù)據(jù)集的類別數(shù),再進(jìn)行微調(diào),有時(shí)會遇到上述情況。這時(shí),我們知道全連接層的參數(shù)形狀會是不匹配,比如我們加載 ImageNet 1K 1000分類的預(yù)訓(xùn)練模型,它的最后一層全連接的輸出維度是1000,但如果我們自己的數(shù)據(jù)集是10分類,我們需要將最后一層全鏈接的輸出維度改為10。但是由于鍵名相同,所以PyTorch還是嘗試給我們加載,這時(shí)1000和10維度不匹配,就會導(dǎo)致報(bào)錯(cuò)。
解決方案就是我們將 .pth 模型文件讀入后,將其中我們不需要的層(通常是最后的全連接層)的參數(shù)pop掉即可。
以 ViT 為例子,假設(shè)我們有一個(gè) ViT 模型,并有一個(gè)參數(shù)文件 vit-in1k.pth,它里面存儲著 ViT 模型在 ImageNet-1K 1000分類數(shù)據(jù)集上訓(xùn)練的參數(shù),而我們要在自己的10分類數(shù)據(jù)集上微調(diào)這個(gè)模型。
model = ViT(num_classes=10) ckpt = torch.load('vit-in1k.pth', map_location='cpu') msg = model.load_state_dict(ckpt, strict=False) print(msg)直接這樣加載會出錯(cuò),就是上面的錯(cuò)誤:
size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).我們將最后 pth 文件加載進(jìn)來之后(即 ckpt) 中全連接層的參數(shù)直接pop掉,至于需要pop掉哪些鍵名,就是上面報(bào)錯(cuò)信息中提到了的,在這里就是 head.weight 和 head.bias
ckpt.pop('head.weight') ckpt.pop('head.bias')之后在運(yùn)行,會發(fā)現(xiàn)我們打印的 msg 顯示:
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])即缺失了head.weight 和 head.bias 這兩個(gè)參數(shù),這是正常的,因?yàn)樵谧约旱臄?shù)據(jù)集上微調(diào)時(shí),我們本就不需要這兩個(gè)參數(shù),并且已經(jīng)將它們從模型文件字典 ckpt 中pop掉了?,F(xiàn)在,模型全連接之前的層(通常即所謂的特征提取層)的參數(shù)已經(jīng)正常加載了,接下來可以在自己的數(shù)據(jù)集上進(jìn)行微調(diào)。
因?yàn)榉凑覀円膊挥眠@些參數(shù),就直接把這個(gè)鍵值對從字典中pop掉,以免 PyTorch 在幫我們加載時(shí)試圖加載這些維度不匹配,我們也不需要的參數(shù)。
總結(jié)
以上是生活随笔為你收集整理的strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 天窗是什么意思?
- 下一篇: [分布式训练] 单机多卡的正确打开方式: