GEMM算法及优化流程详解
目錄
前言
im2col+GEMM算法簡介
GEMM算法優化
optimize1
optimize2
optimize3
前言
神經網絡前向耗時主要由卷積的耗時決定,參考賈楊青畢業論文,那么如何對卷積加速便成了重要的一個點,主流的加速方法有
以下幾種:
im2col+GEMM:目前幾乎所有的主流計算框架包括 Caffe, MXNet 等都實現了該方法. 該方法把整個卷積過程轉化成了GEMM過程,而GEMM在各種 BLAS 庫中都是被極致優化的,一般來說,速度較快。
Winograd: Winograd 是存在已久最近被重新發現的方法,在大部分場景中, Winograd方法都顯示和較大的優勢,目前cudnn中計算卷積就使用了該方法。
Strassen:1969年,Volker Strassen提出了第一個時間復雜度低于O(N^3)的算法,其復雜度為O(N^(2^(log2(7)))),但這種方法只在大卷積核情況下優勢才比較明顯,目前還沒有在開源框架中見到這種方法。
FFT:傅里葉變換和快速傅里葉變化是在經典圖像處理里面經常使用的計算方法,但是,在 ConvNet中通常不采用,主要是因為在 ConvNet 中的卷積模板通常都比較小,例如?3×3?等,這種情況下,FFT 的時間開銷反而更大,所以很少在CNN中利用FFT實現卷積。
很高興你看完前言:最近發現這篇文章寫的很好,阿里那邊的,《支付寶如何優化移動端深度學習引擎》推薦給大家~
?
im2col+GEMM算法簡介
GEMM在深度學習中是十分重要的,全連接層以及卷積層基本上都是通過GEMM來實現的,而網絡中大約90%的運算都是在這兩層中。而一個良好的GEMM的實現可以充分利用系統的多級存儲結構和程序執行的局部性來充分加速運算。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?常規的卷積操作為:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ???? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 3維卷積運算執行完畢,得一個2維的平面:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
將卷積操作的3維立體變為二維矩陣乘法,可以調用BLAS中的GEMM庫,按 [kernel_height, kernel_width, kernel_depth] ? 將輸入分成 3 維的 patch,并將其展成一維向量:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
此時的卷積操作就可轉化為矩陣乘法:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
下面我們將以M=K=N=600為例說明GEMM算法的優化過程:
?
直接暴力卷積:
for (int m = 0; m < M; m++) {for (int n = 0; n < N; n++) {for (int k = 0; k < K; k++) {C[m][n]+= A[m][k] * B[k][n];}} }上述公式總計算量為2MNK FLOPs(其中 𝑀、𝑁、𝐾 分別指代三層循環執行的次數,2 指代循環最內層的一次乘法和加法) ,內存訪問操作總數為 4MNK(其中 2MNK 指代對 𝐶 的內存訪問,𝐶 需要先讀取內存、累和再存儲)。GEMM 的優化均以此為基點。
耗時分析:上述暴力gemm代碼耗時約為872ms
?
GEMM算法優化
optimize1
首先能想到的就是減少C矩陣的訪存次數,將C[m][n]放到外面,全部累和之后再賦值即可:
for (int m = 0; m < M; m++) {for (int n = 0; n < N; n++) {float temp = C[m][n];for (int k = 0; k < K; k++) {temp += A[m][k] * B[k][n];}C[m][n] = temp;} }上述公式總計算量依然為2MNK FLOPs,內存訪問操作總數為 2MNK+2MN(其中 2MN?指代對 𝐶 的內存訪問,𝐶 需要先讀取內存、累加完畢在存儲)。
耗時分析:上述代碼耗時約為791ms,耗時變少的原因是減少了部分C的訪存
?
optimize2
將輸出的計算拆分為 1×4 的小塊,即將 𝑁 維度拆分為兩部分。計算該塊輸出時,需要使用 𝐴 矩陣的 1 行,和 𝐵 矩陣的 4 列。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖一:矩陣乘計算?1×4輸出
下面是該計算的偽代碼表示,這里已經將 1×4 中 N 維度的內部拆分進行了展開。這里的計算量仍然是 2𝑀𝑁𝐾 ,這一點在本文中不會有變化。
for (int m = 0; m < M; m++) {for (int n = 0; n < N; n += 4) {float temp_m0n0 = C[m][n + 0];float temp_m0n1 = C[m][n + 1];float temp_m0n2 = C[m][n + 2];float temp_m0n3 = C[m][n + 3];for (int k = 0; k < K; k++) {float temp = A[m][k];temp_m0n0 += temp * B[k][n + 0];temp_m0n1 += temp * B[k][n + 1];temp_m0n2 += temp * B[k][n + 2];temp_m0n3 += temp * B[k][n + 3];}C[m][n + 0] = temp_m0n0;C[m][n + 1] = temp_m0n1;C[m][n + 2] = temp_m0n2;C[m][n + 3] = temp_m0n3;} }簡單的觀察即可發現,上述偽代碼的最內側計算使用的矩陣 𝐴 的元素是一致的。因此可以將 𝐴[𝑚][𝑘] 讀取到寄存器中,從而實現 4 次數據復用(這里不再給出示例)。一般將最內側循環稱作計算核(micro kernel)。進行這樣的優化后,內存訪問操作數量變為 2MN+5/4MNK,訪存約為上面的5/8。
耗時分析:本優化耗時約為473ms,相比暴力耗時減少300ms左右,可能的兩個原因:1、由于B是行優先排列,1x4方法能夠減少數據從內存到cache的加載次數;2、合理利用寄存器,減少對𝐴矩陣訪存次數
?
optimize3
類似地,我們可以繼續拆分輸出的 𝑀 維度,從而在內側循環中計算 4×4 輸出,如圖二。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖二:矩陣乘計算?4×4輸出
同樣地,將計算核心展開,可以得到下面的偽代碼。由于乘數效應,4×4 的拆分可以將對輸入數據的訪存縮減到 MN/16*(16*2+8K)=2MN+1/2*MNK。這相對于最開始的 4MNK 已經得到了 8X 的改進,這些改進都是通過展開循環后利用寄存器存儲數據減少訪存得到的。
for (int m = 0; m < M; m += 4) {for (int n = 0; n < N; n += 4) {float temp_m0n0 = C[m + 0][n + 0];float temp_m0n1 = C[m + 0][n + 1];float temp_m0n2 = C[m + 0][n + 2];float temp_m0n3 = C[m + 0][n + 3];float temp_m1n0 = C[m + 1][n + 0];float temp_m1n1 = C[m + 1][n + 1];float temp_m1n2 = C[m + 1][n + 2];float temp_m1n3 = C[m + 1][n + 3];float temp_m2n0 = C[m + 2][n + 0];float temp_m2n1 = C[m + 2][n + 1];float temp_m2n2 = C[m + 2][n + 2];float temp_m2n3 = C[m + 2][n + 3];float temp_m3n0 = C[m + 3][n + 0];float temp_m3n1 = C[m + 3][n + 1];float temp_m3n2 = C[m + 3][n + 2];float temp_m3n3 = C[m + 3][n + 3];for (int k = 0; k < K; k++) {float temp_m0 = A[m + 0][k];float temp_m1 = A[m + 1][k];float temp_m2 = A[m + 2][k];float temp_m3 = A[m + 3][k];float temp_n0 = B[k][n + 0];float temp_n1 = B[k][n + 1];float temp_n2 = B[k][n + 2];float temp_n3 = B[k][n + 3];temp_m0n0 += temp_m0 * temp_n0;temp_m0n1 += temp_m0 * temp_n1;temp_m0n2 += temp_m0 * temp_n2;temp_m0n3 += temp_m0 * temp_n3;temp_m1n0 += temp_m1 * temp_n0;temp_m1n1 += temp_m1 * temp_n1;temp_m1n2 += temp_m1 * temp_n2;temp_m1n3 += temp_m1 * temp_n3;temp_m2n0 += temp_m2 * temp_n0;temp_m2n1 += temp_m2 * temp_n1;temp_m2n2 += temp_m2 * temp_n2;temp_m2n3 += temp_m2 * temp_n3;temp_m3n0 += temp_m3 * temp_n0;temp_m3n1 += temp_m3 * temp_n1;temp_m3n2 += temp_m3 * temp_n2;temp_m3n3 += temp_m3 * temp_n3;}C[m + 0][n + 0] = temp_m0n0;C[m + 0][n + 1] = temp_m0n1;C[m + 0][n + 2] = temp_m0n2;C[m + 0][n + 3] = temp_m0n3;C[m + 1][n + 0] = temp_m1n0;C[m + 1][n + 1] = temp_m1n1;C[m + 1][n + 2] = temp_m1n2;C[m + 1][n + 3] = temp_m1n3;C[m + 2][n + 0] = temp_m2n0;C[m + 2][n + 1] = temp_m2n1;C[m + 2][n + 2] = temp_m2n2;C[m + 2][n + 3] = temp_m2n3;C[m + 3][n + 0] = temp_m3n0;C[m + 3][n + 1] = temp_m3n1;C[m + 3][n + 2] = temp_m3n2;C[m + 3][n + 3] = temp_m3n3;} }耗時分析:本優化耗時約為354ms,相比1x4耗時減少120ms左右
總結
以上是生活随笔為你收集整理的GEMM算法及优化流程详解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tensorflow量化策略详解
- 下一篇: InsightFace及其mxnet、t