日韩av黄I国产麻豆传媒I国产91av视频在线观看I日韩一区二区三区在线看I美女国产在线I麻豆视频国产在线观看I成人黄色短片

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 >

tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...

發布時間:2025/3/19 37 豆豆
生活随笔 收集整理的這篇文章主要介紹了 tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint... 小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
作者:bindog

地址:http://bindog.github.io/

01

背景

前幾天看到知乎上的文章FLOPs與模型推理速度[1],文中提到一個比較耗時又占顯存的pointwise操作x * sigmoid(x),這實際上是swish activation[2];暫且不提它背后的爭議,本文主要想從這個結構入手來優化它的顯存占用以及耗時,并討論更廣泛的訓練時顯存優化技術。

02

反向傳播是如何工作的?

要分析清楚swish activation為什么會比較占顯存,我們首先需要搞清楚反向傳播是如何工作的,或者更進一步說,現有的自動求導框架是如何求出梯度的。先明確一點,所謂自動求導框架實際上是“半自動”的:它并非直接求出一個復雜函數導數的解析形式,而是通過構建計算圖和預先寫好的基礎函數的求導規則,結合鏈式求導法則實現的自動求導。以swish acivation為例進行說明,其表達式為f(x) = x * sigmoid(x),通過簡單的數學推導得到其梯度的解析式為f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x));先把這個結果放一邊,看看自動求導框架是如何一步步求出這個結果的,畫出計算圖如下:除了計算圖以外,我們還需要定義幾個基本函數的求導規則,在這個例子里涉及兩個函數,一個是乘法,另一個是sigmoid函數(實際上sigmoid也是由幾個基本函數構成的,這里我們將其視為一個整體)f(x, y) = x * y# gradient for x: y# gradient for y: xg(x) = sigmoid(x) # 1 / (1 + exp(-x))# gradient for x: sigmoid(x) * (1 - sigmoid(x))

03

顯存被誰吃掉了

先說一個結論,在絕大多數神經網絡的訓練過程中,顯存占用的大頭是中間結果,也就是所謂的“特征圖”。那我們為什么要保留中間結果呢?當然是為了方便求導啊!還是以swish acivation為例,把它放入神經網絡來看,x就是前一層輸出的中間結果(特征圖)
  • 在適用乘法的求導規則時,要求我們要事先保留下中間結果x和sigmoid(x),有人可能會說只保留一個x不就可以了嗎?sigmoid(x)可以通過計算得出,注意框架定義的乘法及其求導規則是通用規則,乘法的左右兩邊完全可能是不相關的兩個值,所以必須同時保留下來。
  • 在對sigmoid函數適用求導規則時,需要存下中間結果x。
在不考慮框架自身優化的情況下,顯存占用就包括了兩個x和一個sigmoid(x),注意x可不是一個單獨的數值,而是類似32x32x128這樣大小的特征圖,考慮到swish acivation在網絡中數量龐大,每出現一次就意味著巨大的顯存浪費。

04

手動合并OP

那么有沒有辦法優化呢?當然是可以的,既然我們能用數學公式提前算出swish acivation的梯度,那么直接將其視為一個整體不就好了?無非就是定義一個新的函數和新的求導規則

swish(x) = x * sigmoid(x)# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))

這樣一來,計算圖變成了下面這個樣子:

x的梯度可以直接根據新的規則求出,而在新的規則下,我們只需要保留x這一個中間結果即可,sigmoid(x)可以根據x求出。所以說,自動求導框架雖然省事,但是其缺陷也很明顯,由于大部分求導規則是面向通用的函數,很難針對特定的場景進行自動優化而導致顯存浪費。對swish acivation這樣的函數,只能依靠工程師的經驗手動的進行優化。需要指出的是,現有的一些框架如TVM和TensorRT也能自動的對某些算子進行融合,進而大大提高計算效率,降低顯存消耗,但是這些都屬于部署階段了,而本文討論的均為訓練階段。類似的優化案例還有inplace-abn[3],針對的是類似BN-ReLU-Conv這樣的常見結構組合,如下所示

圖中的虛線框是需要保留的中間結果,inplace-abn的優化思路是只保留中間結果z,通過反推得到x,然而眾所周知ReLU是不可逆的運算,因此inplace-abn將其替換為了Leaky ReLU,計算圖變成了如下形式:

接下來的事情就是用數學的方式手動求出導數,然后定義成規則即可。

對II型,更進一步,直接用的反函數,進行替換即可

雖然推導過程有些復雜,但寫出求導公式后,我們只需要將其封裝進手寫的模塊中即可。原論文[4]中的實現表明,采用Inplace-abn后,顯存占用最高可下降50%左右,而且由于Leaky ReLU實際效果其實與ReLU非常接近,省下來的顯存可以用于提高batch_size,模型訓練實際上能從中得到更大收益。

05

還能更進一步嗎?

回想前面的優化過程,我們發現其實這是一種典型的時間換空間的做法,雖然模型占用的顯存下降了(舍棄了大量中間結果),但是我們定義的求導規則非常復雜,計算步驟明顯多于優化前,其根本原因并非是不需要中間結果,而是有辦法在求導過程中實時的計算出之前被舍棄掉的中間結果。考慮GPU上顯存資源與計算資源的關系,只用較少的計算量和額外的一點計算時間換取寶貴的顯存資源,這么做實際上是劃算的。如果沿著這個思路更進一步,所有的中間結果都不需要存儲了,只需要存最初的輸入即可,因為所有的中間結果都可以由輸入重新計算得到,然而這個方案顯然是不劃算的,因為反向傳播的過程是“由深入淺”,而計算中間結果的過程是“由淺入深”,二者的方向并不匹配,每當我們需要中間結果時就需要從頭再來一遍,這樣的計算和時間開銷顯然是不劃算的。如果折中一下呢?這就是OpenAI提出的gradient-checkpoint的思路,在神經網絡中間設置若干個檢查點(checkpoint),檢查點以外的中間結果全部舍棄,反向傳播求導數的時間,需要某個中間結果時,從最近的檢查點開始計算,這樣既節省了顯存,又避免了從頭計算的繁瑣過程;從代碼層面來看,原版實現[5]用的是tensorflow,由于是靜態圖的緣故,需要用到grapheditor等一系列騷操作,而且包含了很多“智能”尋找bottleneck選擇為checkpoint的代碼,很容易勸退新人。但是如果看一下pytorch的官方實現[6],你會驚訝的發現gradient-checkpoint的核心部分出奇的簡單,這也算是動態圖以及pytorch的一點小優勢吧,當然pytorch版本的實現并不包括智能尋找checkpoint點的功能,需要人為設定。核心代碼如下所示:class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.) ctx.had_cuda_in_fwd = False if torch.cuda._initialized: ctx.had_cuda_in_fwd = True ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) ctx.save_for_backward(*args) with torch.no_grad(): outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") inputs = ctx.saved_tensors # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. rng_devices = [] if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: rng_devices = ctx.fwd_gpu_devices with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): if ctx.preserve_rng_state: torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_cuda_in_fwd: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) detached_inputs = detach_variable(inputs) with torch.enable_grad(): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) torch.autograd.backward(outputs, args) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None, None) + grads注意到最近曠視開源的MegEngine,在PR的時候提到一個亞線性顯存優化技術[7],其實就是gradient-checkpoint技術,詳情可參考論文Training Deep Nets with Sublinear Memory Cost[8],當然MegEngine肯定在細節上對其進行了一些優化,本文就不展開討論了。

06

CUDA版的swish activation

回到swish activation的優化上來,如果要追求效率的極致提升,下一步考慮的方案應該是手寫C++ extension,將計算從python層面轉移到C++與CUDA上。如何基于 pytorch寫C++擴展,官方文檔上有非常詳細的教程[9],寫法和方式也都比較靈活,可以根據自己的習慣進行選擇,這里我們選擇利用setuptools的方式進行構建pytorch在用戶自定義擴展上也是做了非常多的支持,用戶能非常方便的使用pytorch底層定義好的一些類和函數;在寫CUDA函數時,pytorch還提供了一個CUDAApplyUtils.cuh頭文件,專門用于優化pointwise操作的情況,以減小拷貝和臨時存儲的顯存浪費(用于lambda函數,函數名非常直觀,CUDA_tensor_applyN表示操作數的個數,N可以為1,2,3,4,用戶還可以指定每個操作數的屬性,如只讀/讀寫,針對每對情形都有專門的優化實現)對于swish activation來說,由于全是pointwise操作,利用這個優化技巧可以把顯存占用進一步壓縮。具體代碼可參考swish_optimize[10]簡單對比一下以上幾種實現在實際場景中(單卡RTX 2070,resnet50, bs=32)的顯存占用情況和運行時間(一次forward & 一次backward & 參數更新)
  • 無優化純Python:GPU memory=6383MB,time=223ms
  • 合并算子(Python):GPU memory=5139MB,time=234ms
  • 合并算子(CUDA):GPU memory=5143MB,time=188ms
從上述對比結果來看,結果基本符合前文的分析,純Python的實現下,顯存優化后的由于是時間換空間,所以顯存占用降低了,而時間稍有增加;在CUDA版本的優化下,一方面得益于C++的高效,另一方面得益于由于pointwise計算的優化,在顯存占用降低的同時,計算時間也大幅縮短外鏈地址:

[1] https://zhuanlan.zhihu.com/p/122943688

[2] https://arxiv.org/abs/1710.05941

[3] https://github.com/mapillary/inplace_abn

[4] https://arxiv.org/pdf/1712.02616.pdf

[5] https://github.com/cybertronai/gradient-checkpointing

[6] https://github.com/pytorch/pytorch/blob/176174a68ba2d36b9a5aaef0943421682ecc66d4/torch/utils/checkpoint.py#L55

[7] https://zhuanlan.zhihu.com/p/138730559

[8] https://arxiv.org/abs/1604.06174

[9] https://pytorch.org/tutorials/advanced/cpp_extension.html

[10] https://github.com/bindog/swish_optimize

本文目的在于學術交流,并不代表本公眾號贊同其觀點或對其內容真實性負責,版權歸原作者所有,如有侵權請告知刪除。

直播預告

歷史文章推薦

  • 【CVPR 2020 Tutorial】如何寫好論文和評審(概述)

  • 如何撰寫高水平的博士論文?超全論文指導!

  • 北大讀博手記:怎樣完成自己的博士生涯?非常具有指導性!

  • 太牛逼了!一位中國博士把整個CNN都給可視化了,每個細節看的清清楚楚!

  • Nature發表牛津博士建議:我希望在讀博士之初時就能知道的20件事

  • 沈向洋、華剛:讀科研論文的三個層次、四個階段與十個問題

  • 如何看待2021年秋招算法崗灰飛煙滅?

  • 獨家解讀 | ExprGAN:基于強度可控的表情編輯

  • 獨家解讀 | 矩陣視角下的BP算法

  • 獨家解讀 | Capsule Network深度解讀

  • 獨家解讀 | Fisher信息度量下的對抗攻擊

  • 論文解讀 | 知識圖譜最新研究綜述

  • 你的畢業論文過了嗎?《如何撰寫畢業論文?》

  • 卡爾曼濾波系列——經典卡爾曼濾波推導

分享、點贊、在看,給個三連擊唄!

總結

以上是生活随笔為你收集整理的tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...的全部內容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。