tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...
地址:http://bindog.github.io/
01
背景
前幾天看到知乎上的文章FLOPs與模型推理速度[1],文中提到一個(gè)比較耗時(shí)又占顯存的pointwise操作x * sigmoid(x),這實(shí)際上是swish activation[2];暫且不提它背后的爭議,本文主要想從這個(gè)結(jié)構(gòu)入手來優(yōu)化它的顯存占用以及耗時(shí),并討論更廣泛的訓(xùn)練時(shí)顯存優(yōu)化技術(shù)。02
反向傳播是如何工作的?
要分析清楚swish activation為什么會(huì)比較占顯存,我們首先需要搞清楚反向傳播是如何工作的,或者更進(jìn)一步說,現(xiàn)有的自動(dòng)求導(dǎo)框架是如何求出梯度的。先明確一點(diǎn),所謂自動(dòng)求導(dǎo)框架實(shí)際上是“半自動(dòng)”的:它并非直接求出一個(gè)復(fù)雜函數(shù)導(dǎo)數(shù)的解析形式,而是通過構(gòu)建計(jì)算圖和預(yù)先寫好的基礎(chǔ)函數(shù)的求導(dǎo)規(guī)則,結(jié)合鏈?zhǔn)角髮?dǎo)法則實(shí)現(xiàn)的自動(dòng)求導(dǎo)。以swish acivation為例進(jìn)行說明,其表達(dá)式為f(x) = x * sigmoid(x),通過簡單的數(shù)學(xué)推導(dǎo)得到其梯度的解析式為f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x));先把這個(gè)結(jié)果放一邊,看看自動(dòng)求導(dǎo)框架是如何一步步求出這個(gè)結(jié)果的,畫出計(jì)算圖如下:除了計(jì)算圖以外,我們還需要定義幾個(gè)基本函數(shù)的求導(dǎo)規(guī)則,在這個(gè)例子里涉及兩個(gè)函數(shù),一個(gè)是乘法,另一個(gè)是sigmoid函數(shù)(實(shí)際上sigmoid也是由幾個(gè)基本函數(shù)構(gòu)成的,這里我們將其視為一個(gè)整體)f(x, y) = x * y# gradient for x: y# gradient for y: xg(x) = sigmoid(x) # 1 / (1 + exp(-x))# gradient for x: sigmoid(x) * (1 - sigmoid(x))03
顯存被誰吃掉了
先說一個(gè)結(jié)論,在絕大多數(shù)神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程中,顯存占用的大頭是中間結(jié)果,也就是所謂的“特征圖”。那我們?yōu)槭裁匆A糁虚g結(jié)果呢?當(dāng)然是為了方便求導(dǎo)啊!還是以swish acivation為例,把它放入神經(jīng)網(wǎng)絡(luò)來看,x就是前一層輸出的中間結(jié)果(特征圖)- 在適用乘法的求導(dǎo)規(guī)則時(shí),要求我們要事先保留下中間結(jié)果x和sigmoid(x),有人可能會(huì)說只保留一個(gè)x不就可以了嗎?sigmoid(x)可以通過計(jì)算得出,注意框架定義的乘法及其求導(dǎo)規(guī)則是通用規(guī)則,乘法的左右兩邊完全可能是不相關(guān)的兩個(gè)值,所以必須同時(shí)保留下來。
- 在對sigmoid函數(shù)適用求導(dǎo)規(guī)則時(shí),需要存下中間結(jié)果x。
04
手動(dòng)合并OP
那么有沒有辦法優(yōu)化呢?當(dāng)然是可以的,既然我們能用數(shù)學(xué)公式提前算出swish acivation的梯度,那么直接將其視為一個(gè)整體不就好了?無非就是定義一個(gè)新的函數(shù)和新的求導(dǎo)規(guī)則
swish(x) = x * sigmoid(x)# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))這樣一來,計(jì)算圖變成了下面這個(gè)樣子:
x的梯度可以直接根據(jù)新的規(guī)則求出,而在新的規(guī)則下,我們只需要保留x這一個(gè)中間結(jié)果即可,sigmoid(x)可以根據(jù)x求出。所以說,自動(dòng)求導(dǎo)框架雖然省事,但是其缺陷也很明顯,由于大部分求導(dǎo)規(guī)則是面向通用的函數(shù),很難針對特定的場景進(jìn)行自動(dòng)優(yōu)化而導(dǎo)致顯存浪費(fèi)。對swish acivation這樣的函數(shù),只能依靠工程師的經(jīng)驗(yàn)手動(dòng)的進(jìn)行優(yōu)化。需要指出的是,現(xiàn)有的一些框架如TVM和TensorRT也能自動(dòng)的對某些算子進(jìn)行融合,進(jìn)而大大提高計(jì)算效率,降低顯存消耗,但是這些都屬于部署階段了,而本文討論的均為訓(xùn)練階段。類似的優(yōu)化案例還有inplace-abn[3],針對的是類似BN-ReLU-Conv這樣的常見結(jié)構(gòu)組合,如下所示圖中的虛線框是需要保留的中間結(jié)果,inplace-abn的優(yōu)化思路是只保留中間結(jié)果z,通過反推得到x,然而眾所周知ReLU是不可逆的運(yùn)算,因此inplace-abn將其替換為了Leaky ReLU,計(jì)算圖變成了如下形式:接下來的事情就是用數(shù)學(xué)的方式手動(dòng)求出導(dǎo)數(shù),然后定義成規(guī)則即可。對II型,更進(jìn)一步,直接用的反函數(shù),進(jìn)行替換即可
雖然推導(dǎo)過程有些復(fù)雜,但寫出求導(dǎo)公式后,我們只需要將其封裝進(jìn)手寫的模塊中即可。原論文[4]中的實(shí)現(xiàn)表明,采用Inplace-abn后,顯存占用最高可下降50%左右,而且由于Leaky ReLU實(shí)際效果其實(shí)與ReLU非常接近,省下來的顯存可以用于提高batch_size,模型訓(xùn)練實(shí)際上能從中得到更大收益。
05
還能更進(jìn)一步嗎?
回想前面的優(yōu)化過程,我們發(fā)現(xiàn)其實(shí)這是一種典型的時(shí)間換空間的做法,雖然模型占用的顯存下降了(舍棄了大量中間結(jié)果),但是我們定義的求導(dǎo)規(guī)則非常復(fù)雜,計(jì)算步驟明顯多于優(yōu)化前,其根本原因并非是不需要中間結(jié)果,而是有辦法在求導(dǎo)過程中實(shí)時(shí)的計(jì)算出之前被舍棄掉的中間結(jié)果。考慮GPU上顯存資源與計(jì)算資源的關(guān)系,只用較少的計(jì)算量和額外的一點(diǎn)計(jì)算時(shí)間換取寶貴的顯存資源,這么做實(shí)際上是劃算的。如果沿著這個(gè)思路更進(jìn)一步,所有的中間結(jié)果都不需要存儲(chǔ)了,只需要存最初的輸入即可,因?yàn)?strong>所有的中間結(jié)果都可以由輸入重新計(jì)算得到,然而這個(gè)方案顯然是不劃算的,因?yàn)榉聪騻鞑サ倪^程是“由深入淺”,而計(jì)算中間結(jié)果的過程是“由淺入深”,二者的方向并不匹配,每當(dāng)我們需要中間結(jié)果時(shí)就需要從頭再來一遍,這樣的計(jì)算和時(shí)間開銷顯然是不劃算的。如果折中一下呢?這就是OpenAI提出的gradient-checkpoint的思路,在神經(jīng)網(wǎng)絡(luò)中間設(shè)置若干個(gè)檢查點(diǎn)(checkpoint),檢查點(diǎn)以外的中間結(jié)果全部舍棄,反向傳播求導(dǎo)數(shù)的時(shí)間,需要某個(gè)中間結(jié)果時(shí),從最近的檢查點(diǎn)開始計(jì)算,這樣既節(jié)省了顯存,又避免了從頭計(jì)算的繁瑣過程;從代碼層面來看,原版實(shí)現(xiàn)[5]用的是tensorflow,由于是靜態(tài)圖的緣故,需要用到grapheditor等一系列騷操作,而且包含了很多“智能”尋找bottleneck選擇為checkpoint的代碼,很容易勸退新人。但是如果看一下pytorch的官方實(shí)現(xiàn)[6],你會(huì)驚訝的發(fā)現(xiàn)gradient-checkpoint的核心部分出奇的簡單,這也算是動(dòng)態(tài)圖以及pytorch的一點(diǎn)小優(yōu)勢吧,當(dāng)然pytorch版本的實(shí)現(xiàn)并不包括智能尋找checkpoint點(diǎn)的功能,需要人為設(shè)定。核心代碼如下所示:class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.) ctx.had_cuda_in_fwd = False if torch.cuda._initialized: ctx.had_cuda_in_fwd = True ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) ctx.save_for_backward(*args) with torch.no_grad(): outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") inputs = ctx.saved_tensors # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. rng_devices = [] if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: rng_devices = ctx.fwd_gpu_devices with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): if ctx.preserve_rng_state: torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_cuda_in_fwd: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) detached_inputs = detach_variable(inputs) with torch.enable_grad(): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) torch.autograd.backward(outputs, args) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None, None) + grads注意到最近曠視開源的MegEngine,在PR的時(shí)候提到一個(gè)亞線性顯存優(yōu)化技術(shù)[7],其實(shí)就是gradient-checkpoint技術(shù),詳情可參考論文Training Deep Nets with Sublinear Memory Cost[8],當(dāng)然MegEngine肯定在細(xì)節(jié)上對其進(jìn)行了一些優(yōu)化,本文就不展開討論了。06
CUDA版的swish activation
回到swish activation的優(yōu)化上來,如果要追求效率的極致提升,下一步考慮的方案應(yīng)該是手寫C++ extension,將計(jì)算從python層面轉(zhuǎn)移到C++與CUDA上。如何基于 pytorch寫C++擴(kuò)展,官方文檔上有非常詳細(xì)的教程[9],寫法和方式也都比較靈活,可以根據(jù)自己的習(xí)慣進(jìn)行選擇,這里我們選擇利用setuptools的方式進(jìn)行構(gòu)建pytorch在用戶自定義擴(kuò)展上也是做了非常多的支持,用戶能非常方便的使用pytorch底層定義好的一些類和函數(shù);在寫CUDA函數(shù)時(shí),pytorch還提供了一個(gè)CUDAApplyUtils.cuh頭文件,專門用于優(yōu)化pointwise操作的情況,以減小拷貝和臨時(shí)存儲(chǔ)的顯存浪費(fèi)(用于lambda函數(shù),函數(shù)名非常直觀,CUDA_tensor_applyN表示操作數(shù)的個(gè)數(shù),N可以為1,2,3,4,用戶還可以指定每個(gè)操作數(shù)的屬性,如只讀/讀寫,針對每對情形都有專門的優(yōu)化實(shí)現(xiàn))對于swish activation來說,由于全是pointwise操作,利用這個(gè)優(yōu)化技巧可以把顯存占用進(jìn)一步壓縮。具體代碼可參考swish_optimize[10]簡單對比一下以上幾種實(shí)現(xiàn)在實(shí)際場景中(單卡RTX 2070,resnet50, bs=32)的顯存占用情況和運(yùn)行時(shí)間(一次forward & 一次backward & 參數(shù)更新)- 無優(yōu)化純Python:GPU memory=6383MB,time=223ms
- 合并算子(Python):GPU memory=5139MB,time=234ms
- 合并算子(CUDA):GPU memory=5143MB,time=188ms
[1] https://zhuanlan.zhihu.com/p/122943688
[2] https://arxiv.org/abs/1710.05941
[3] https://github.com/mapillary/inplace_abn
[4] https://arxiv.org/pdf/1712.02616.pdf
[5] https://github.com/cybertronai/gradient-checkpointing
[6] https://github.com/pytorch/pytorch/blob/176174a68ba2d36b9a5aaef0943421682ecc66d4/torch/utils/checkpoint.py#L55
[7] https://zhuanlan.zhihu.com/p/138730559
[8] https://arxiv.org/abs/1604.06174
[9] https://pytorch.org/tutorials/advanced/cpp_extension.html
[10] https://github.com/bindog/swish_optimize
本文目的在于學(xué)術(shù)交流,并不代表本公眾號贊同其觀點(diǎn)或?qū)ζ鋬?nèi)容真實(shí)性負(fù)責(zé),版權(quán)歸原作者所有,如有侵權(quán)請告知?jiǎng)h除。
直播預(yù)告
歷史文章推薦
【CVPR 2020 Tutorial】如何寫好論文和評審(概述)
如何撰寫高水平的博士論文?超全論文指導(dǎo)!
北大讀博手記:怎樣完成自己的博士生涯?非常具有指導(dǎo)性!
太牛逼了!一位中國博士把整個(gè)CNN都給可視化了,每個(gè)細(xì)節(jié)看的清清楚楚!
Nature發(fā)表牛津博士建議:我希望在讀博士之初時(shí)就能知道的20件事
沈向洋、華剛:讀科研論文的三個(gè)層次、四個(gè)階段與十個(gè)問題
如何看待2021年秋招算法崗灰飛煙滅?
獨(dú)家解讀 | ExprGAN:基于強(qiáng)度可控的表情編輯
獨(dú)家解讀 | 矩陣視角下的BP算法
獨(dú)家解讀 | Capsule Network深度解讀
獨(dú)家解讀 | Fisher信息度量下的對抗攻擊
論文解讀 | 知識圖譜最新研究綜述
你的畢業(yè)論文過了嗎?《如何撰寫畢業(yè)論文?》
卡爾曼濾波系列——經(jīng)典卡爾曼濾波推導(dǎo)
分享、點(diǎn)贊、在看,給個(gè)三連擊唄!
總結(jié)
以上是生活随笔為你收集整理的tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python收入波动告警分析_使用Pyt
- 下一篇: 标题隐藏_头条官方课程没看就想起好标题?