【深度学习】干货!小显存如何训练大模型
之前Kaggle有一個(gè)Jigsaw多語(yǔ)言毒舌評(píng)論分類(lèi)[1]比賽,當(dāng)時(shí)我只有一張11G顯存的1080Ti,根本沒(méi)法訓(xùn)練SOTA的Roberta-XLM-large模型,只能遺憾躺平。在這篇文章中,我將分享一些關(guān)于如何減少訓(xùn)練時(shí)顯存使用的技巧,以便你可以用現(xiàn)有的GPU訓(xùn)練更大的網(wǎng)絡(luò)。
混合精度訓(xùn)練
第一個(gè)可能已經(jīng)普及的技巧是使用混合精度(mixed-precision)訓(xùn)練。當(dāng)訓(xùn)練一個(gè)模型時(shí),一般來(lái)說(shuō)所有的參數(shù)都會(huì)存儲(chǔ)在顯存VRAM中。很簡(jiǎn)單,總的VRAM使用量等于存儲(chǔ)的參數(shù)數(shù)量乘以單個(gè)參數(shù)的VRAM使用量。一個(gè)更大的模型不僅意味著更好的性能,而且也會(huì)使用更多的VRAM。由于性能相當(dāng)重要,比如在Kaggle比賽中,我們不希望減小模型的規(guī)模。因此減少顯存使用的唯一方法是減少每個(gè)變量的內(nèi)存使用。默認(rèn)情況下變量是32位浮點(diǎn)格式,這樣一個(gè)變量就會(huì)消耗4個(gè)字節(jié)。幸運(yùn)的是,人們發(fā)現(xiàn)可以在某些變量上使用16位浮點(diǎn),而不會(huì)損失太多的精度。這意味著我們可以減少一半的內(nèi)存消耗! 此外,使用低精度還可以提高訓(xùn)練速度,特別是在支持Tensor Core的GPU上。
在1.5版本之后,pytorch開(kāi)始支持自動(dòng)混合精度(AMP)訓(xùn)練。該框架可以識(shí)別需要全精度的模塊,并對(duì)其使用32位浮點(diǎn)數(shù),對(duì)其他模塊使用16位浮點(diǎn)數(shù)。下面是Pytorch官方文檔[2]中的一個(gè)示例代碼。
#?Creates?model?and?optimizer?in?default?precision model?=?Net().cuda() optimizer?=?optim.SGD(model.parameters(),?...)#?Creates?a?GradScaler?once?at?the?beginning?of?training. scaler?=?GradScaler()for?epoch?in?epochs:for?input,?target?in?data:optimizer.zero_grad()#?Runs?the?forward?pass?with?autocasting.with?autocast():output?=?model(input)loss?=?loss_fn(output,?target)#?Scales?loss.??Calls?backward()?on?scaled?loss?to?create?scaled?gradients.#?Backward?passes?under?autocast?are?not?recommended.#?Backward?ops?run?in?the?same?dtype?autocast?chose?for?corresponding?forward?ops.scaler.scale(loss).backward()#?scaler.step()?first?unscales?the?gradients?of?the?optimizer's?assigned?params.#?If?these?gradients?do?not?contain?infs?or?NaNs,?optimizer.step()?is?then?called,#?otherwise,?optimizer.step()?is?skipped.scaler.step(optimizer)#?Updates?the?scale?for?next?iteration.scaler.update()梯度積累
第二個(gè)技巧是使用梯度積累。梯度累積的想法很簡(jiǎn)單:在優(yōu)化器更新參數(shù)之前,用相同的模型參數(shù)進(jìn)行幾次前后向傳播。在每次反向傳播時(shí)計(jì)算的梯度被累積(加總)。如果你的實(shí)際batch size是N,而你積累了M步的梯度,你的等效批處理量是N*M。然而,訓(xùn)練結(jié)果不會(huì)是嚴(yán)格意義上的相等,因?yàn)橛行﹨?shù),如Batch Normalization,不能完全累積。
關(guān)于梯度累積,有一些事情需要注意:
當(dāng)你在混合精度訓(xùn)練中使用梯度累積時(shí),scale應(yīng)該為有效批次進(jìn)行校準(zhǔn),scale更新應(yīng)該以有效批次的粒度進(jìn)行。
當(dāng)你在分布式數(shù)據(jù)并行(DDP)訓(xùn)練中使用梯度累積時(shí),使用no_sync()上下文管理器來(lái)禁用前M-1步的梯度全還原,這可以增加訓(xùn)練的速度。
具體的實(shí)現(xiàn)方法可以參考文檔[3]。
梯度檢查點(diǎn)
最后一個(gè),也是最重要的技巧是使用梯度檢查點(diǎn)(Gradient Checkpoint)。Gradient Checkpoint的基本思想是只將一些節(jié)點(diǎn)的中間結(jié)果保存為checkpoint,在反向傳播過(guò)程中對(duì)這些節(jié)點(diǎn)之間的其他部分進(jìn)行重新計(jì)算。據(jù)Gradient Checkpoint的作者說(shuō)[4],在這個(gè)技巧的幫助下,他們可以把10倍大的模型放到GPU上,而計(jì)算時(shí)間只增加20%。Pytorch從0.4.0版本開(kāi)始正式支持這一功能,一些非常常用的庫(kù)如Huggingface Transformers也支持這一功能,而且非常簡(jiǎn)單,只需要下面的兩行代碼:
bert?=?AutoModel.from_pretrained(pretrained_model_name) bert.config.gradient_checkpointing=True實(shí)驗(yàn)
在這篇文章的最后,我想分享之前我在惠普Z(yǔ)4工作站上做的一個(gè)簡(jiǎn)單的benchmark。該工作站配備了2個(gè)24G VRAM的RTX6000 GPU(去年底升級(jí)到2個(gè)48G的A6000了),在實(shí)驗(yàn)中我只用了一個(gè)GPU。我用不同的配置在Kaggle Jigsaw多語(yǔ)言毒舌評(píng)論分類(lèi)比賽的訓(xùn)練集上訓(xùn)練了XLM-Roberta Base/Large,并觀察顯存的使用量,結(jié)果如下。
| Batch size/GPU | 8 | 8 | 16 | 8 | 8 | 8 |
| Mixed-precision | off | on | on | off | on | on |
| gradient checkpointing | off | off | off | off | off | on |
| VRAM usage | 12.28G | 10.95G | 16.96 | OOM | 23.5G | 11.8G |
| one epoch | 70min | 50min | 40min | - | 100min | 110min |
我們可以看到,混合精度訓(xùn)練不僅減少了內(nèi)存消耗,而且還帶來(lái)了顯著的速度提升。梯度檢查點(diǎn)的功能也非常強(qiáng)大。它將VRAM的使用量從23.5G減少到11.8G!
以上就是所有內(nèi)容,希望對(duì)大家有幫助🙂
參考資料
[1]
Jigsaw多語(yǔ)言毒舌評(píng)論分類(lèi): https://www.kaggle.com/c/jigsaw-multilingual-toxic-comment-classification
[2]Pytorch官方文檔: https://pytorch.org/docs/1.8.1/notes/amp_examples.html
[3]gradient-accumulation文檔: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
[4]據(jù)Gradient Checkpoint的作者說(shuō): https://github.com/cybertronai/gradient-checkpointing
往期精彩回顧適合初學(xué)者入門(mén)人工智能的路線(xiàn)及資料下載(圖文+視頻)機(jī)器學(xué)習(xí)入門(mén)系列下載中國(guó)大學(xué)慕課《機(jī)器學(xué)習(xí)》(黃海廣主講)機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專(zhuān)輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)交流qq群955171419,加入微信群請(qǐng)掃碼:總結(jié)
以上是生活随笔為你收集整理的【深度学习】干货!小显存如何训练大模型的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: gta5显示nat较为严格_一年内上涨近
- 下一篇: 【深度学习】绝了!分割mask生成动漫人