【TensorFlow-windows】扩展层之STN
前言
讀TensorFlow相關代碼看到了STN的應用,搜索以后發現可替代池化,增強網絡對圖像變換(旋轉、縮放、偏移等)的抗干擾能力,簡單說就是提高卷積神經網絡的空間不變性。
國際慣例,參考博客:
理解Spatial Transformer Networks
github-STN
Deep Learning Paper Implementations: Spatial Transformer Networks - Part I
Deep Learning Paper Implementations: Spatial Transformer Networks - Part II
將STN加入網絡訓練的一個關于圖像隱寫術的案例:StegaStamp
理論
圖像變換
因為圖像的本質就是矩陣,那么圖像變換就是矩陣變換,先復習一下與圖像相關的矩陣變換。假設MMM為變換矩陣,NNN為圖像,為了簡化表達,設MMM的維度是(2,2)(2,2)(2,2),NNN代表像素點坐標,則維度是(2,1)(2,1)(2,1),以下操作均為對像素位置的調整操作,而非對像素值的操作。
-
縮放
M×N=[p00q]×[xy]=[pxqy]M\times N=\begin{bmatrix} p&0\\ 0&q \end{bmatrix}\times \begin{bmatrix} x\\y \end{bmatrix}=\begin{bmatrix} px\\qy \end{bmatrix} M×N=[p0?0q?]×[xy?]=[pxqy?] -
旋轉:繞原點順時針旋轉θ\thetaθ角
M×N=[cos?θ?sin?θsin?θcos?θ]×[xy]=[xcos?θ?ysin?θxsin?θ+ycos?θ]M\times N=\begin{bmatrix} \cos\theta&-\sin\theta\\ \sin\theta&\cos\theta \end{bmatrix}\times \begin{bmatrix} x\\y \end{bmatrix}=\begin{bmatrix} x\cos\theta-y\sin\theta\\x\sin\theta+y\cos\theta \end{bmatrix} M×N=[cosθsinθ??sinθcosθ?]×[xy?]=[xcosθ?ysinθxsinθ+ycosθ?] -
錯切(shear):類似于將字的正體變成斜體
M×N=[1mn1]×[xy]=[x+myy+nx]M\times N=\begin{bmatrix} 1&m\\ n&1 \end{bmatrix}\times \begin{bmatrix} x\\y \end{bmatrix}=\begin{bmatrix} x+my\\y+nx \end{bmatrix} M×N=[1n?m1?]×[xy?]=[x+myy+nx?] -
平移:要轉換為齊次矩陣做平移
M′×N′=[10a01b]×[xy1]=[x+ay+b]M'\times N'=\begin{bmatrix} 1&0&a\\ 0&1&b \end{bmatrix}\times \begin{bmatrix} x\\y\\1 \end{bmatrix}=\begin{bmatrix} x+a\\y+b \end{bmatrix} M′×N′=[10?01?ab?]×???xy1????=[x+ay+b?]
盜用參考博客的圖解就是:
注意,我們進行多次變換的時候有多個變換矩陣,如果每次計算一個變換會比較耗時,參考矩陣的乘法特性,我們可以先將變換矩陣相乘,得到一個完整的矩陣代表所有變換,最后乘以圖像,就可將圖像按照組合變換順序得到變換圖像。這個代表一系列的變換的矩陣通常表示為:
M=[abcdef]M=\begin{bmatrix} a&b&c\\d&e&f \end{bmatrix} M=[ad?be?cf?]
因為直接計算位置的值,很可能得到小數,比如將(3,3)(3,3)(3,3)的圖像放大到(9,9)(9,9)(9,9),也就是放大3倍,那么新圖像(8,8)(8,8)(8,8)位置的像素就是原圖(8/3,8/3)(8/3,8/3)(8/3,8/3)位置的像素,但是像素位置不可能是小數,因而出現了解決方案:雙線性插值
雙線性插值
先復習一下線性插值,直接去看之前寫的這篇博客,知道(x1,y1)(x_1,y_1)(x1?,y1?)與(x2,y2)(x_2,y_2)(x2?,y2?),求(x1,x2)區間內的點(x_1,x_2)區間內的點(x1?,x2?)區間內的點xxx位置的y值,結果是:
y=x?x2x1?x2y1+x?x1x2?x1y2y=\frac{x-x_2}{x_1-x_2}y_1+\frac{x-x_1}{x_2-x_1}y_2 y=x1??x2?x?x2??y1?+x2??x1?x?x1??y2?
可以發現線性插值是針對一維坐標的,即給xxx求yyy,但是雙線性插值是針對二維坐標點的,即給(x,y)(x,y)(x,y)求值QQQ。方法是先在xxx軸方向做兩次線性插值,再在yyy軸上做一次線性插值。
設需要求(x,y)(x,y)(x,y)處的值,我們需要預先知道其附近四個坐標點及其對應的值,如:
- (x,y)(x,y)(x,y)左下角坐標為(x1,y1)(x_1,y_1)(x1?,y1?),值為Q1Q_1Q1?
- (x,y)(x,y)(x,y)右下角坐標為(x2,y1)(x_2,y_1)(x2?,y1?), 值為Q2Q_2Q2?
- (x,y)(x,y)(x,y)左上角坐標為(x1,y2)(x_1,y_2)(x1?,y2?), 值為Q3Q_3Q3?
- (x,y)(x,y)(x,y)右上角坐標為(x2,y2)(x_2,y_2)(x2?,y2?),值為Q4Q_4Q4?
首先對下面的(x1,y1)(x_1,y_1)(x1?,y1?)和(x2,y1)(x_2,y_1)(x2?,y1?)做線性插值,方法是把它兩看做一維坐標(x1,Q1)(x_1,Q_1)(x1?,Q1?)和(x2,Q2)(x_2,Q2)(x2?,Q2),得到:
P1=x?x2x1?x2Q1+x?x1x2?x1Q2P_1=\frac{x-x_2}{x_1-x_2}Q_1+\frac{x-x_1}{x_2-x_1}Q_2 P1?=x1??x2?x?x2??Q1?+x2??x1?x?x1??Q2?
同理得到上面的兩個坐標(x1,y2)(x_1,y_2)(x1?,y2?)與(x2,y2)(x_2,y_2)(x2?,y2?)的插值結果,也就是(x1,Q3)(x_1,Q_3)(x1?,Q3?)和(x2,Q4)(x_2,Q_4)(x2?,Q4?)的線性插值結果:
P2=x?x2x1?x2Q3+x?x1x2?x1Q4P_2=\frac{x-x_2}{x_1-x_2}Q_3+\frac{x-x_1}{x_2-x_1}Q_4 P2?=x1??x2?x?x2??Q3?+x2??x1?x?x1??Q4?
再對(y1,P1)(y_1,P_1)(y1?,P1?)和(y2,P2)(y_2,P_2)(y2?,P2?)做線性插值:
P=x?y2y1?y2P1+y?y1y2?y1P2P=\frac{x-y_2}{y_1-y_2}P_1+\frac{y-y_1}{y_2-y_1}P_2 P=y1??y2?x?y2??P1?+y2??y1?y?y1??P2?
解決上面圖像變換的問題,假設變換后的坐標不是整數,那么就選擇這個坐標四個角的坐標的雙線性插值的結果,比如(8/3,8/3)(8/3,8/3)(8/3,8/3)位置的像素就是(2,2),(3,2),(2,3),(3,3)(2,2),(3,2),(2,3),(3,3)(2,2),(3,2),(2,3),(3,3)位置像素的雙線性插值結果。
總之就是先計算目標圖像像素在源圖像中的位置,然后得到源圖像位置是小數,針對小數位置的四個頂點做雙線性插值。
上面就是STN做的工作,也可以發現STN接受的參數就是6個,接下來看看為什么STN能提高卷積網絡的旋轉、平移、縮放不變性。
總結一下:
圖像處理中的仿射變換通常包含三個步驟:
- 創建由(x,y)(x,y)(x,y)組成的采樣網格,比如(400,400)(400,400)(400,400)的灰度圖對應創建一個同樣大小的網格。
- 將變換矩陣應用到采樣網格上
- 使用插值技術從原圖中計算變換圖的像素值
池化
強行翻譯一波這篇文章關于池化的部分,建議看原文,這里摘取個人認為重要部分:
池化在某種程度上增加了模型的空間不變性,因為池化是一種下采樣技術,減少了每層特征圖的空間大小,極大減少了參數數量,提高了運算速度。
池化提供的不變性確切來說是什么?池化的思路是將一個圖像切分成多個單元,這些復雜單元被池化以后得到了可以描述輸出的簡單的單元。比如有3張不同方向的數字7的圖像,池化是通過圖像上的小網格來檢測7,不受7的位置影響,因為通過聚集的像素值,我們得到的信息大致一樣。個人覺得,作者的本意是單看小網格,是有很多一樣的塊。
池化的缺點在于:
- 丟失了75%的信息(應該是(2,2)(2,2)(2,2)的最大值池化方法),意味著我們一定丟了是精確的位置信息。有人會問,這樣可以增加空間魯棒性哇。然而,對于視覺識別人物,空間信息是非常重要的。比如分類貓的時候,知道貓的胡須的位置相對于鼻子的位置有可能很重要,但是如果使用最大池化,可能丟失了這個信息。
- 池化是局部的且預定義好的。一個小的接受域,池化操作的影響僅僅是針對更深的網絡層(越深感受野越大),也就是中間的特征圖可能受到嚴重的輸入失真的影響。我們不能任意增加接受域,這樣會過度下采樣。
主要結論就是卷積網絡對于相對大的輸入失真不具有不變性。
The pooling operation used in convolutional neural networks is a big mistake and the fact that it works so well is a disaster. (Geoffrey Hinton, Reddit AMA)STN理論
STN的全稱是Spatial Transformer Networks,空間變換網絡。時空變換機制就是通過給CNN提供顯式的空間變換能力,以解決上述池化出現的問題。有三種特性:
- Modular:STN能夠被插入到網絡的任意地方,僅需很小的調整
- differentiable:STN可以通過反向傳播訓練
- dynamic:STN是對每個輸入樣本的一個特征圖做空間變換,而池化是針對所有樣本。
上圖是STN網絡的主要框架。所以到底什么是空間變換?通過結構圖可發現模型包含三部分:localisation network、grid generator、sampler。
Localisation Network
主要是提取被應用到輸入特征圖上的仿射變換的參數θ\thetaθ,網絡結構是:
- 輸入:大小為(H,W,C)(H,W,C)(H,W,C)的特征圖UUU
- 輸出:大小為(6,1)(6,1)(6,1)的變換矩陣θ\thetaθ
- 結構:全連接或者卷積
Parametrised Sampling Grid
輸出參數化的采樣網格,是一系列的點,每個輸入特征圖能夠產生期望的變換輸出。
具體就是:網格生成器首先產生于輸入圖像UUU大小相同的標準網格,然后將仿射變換應用到網格。公式表達即,假設輸入圖的索引是(xt,yt)(x^t,y^t)(xt,yt),將θ\thetaθ代表的變換應用到坐標上得到新的坐標:
[xsys]=[θ1θ2θ3θ4θ5θ6]×[xtyt1]\begin{bmatrix} x^s\\y^s \end{bmatrix}=\begin{bmatrix} \theta_1&\theta_2&\theta_3\\\theta_4&\theta_5&\theta_6 \end{bmatrix}\times\begin{bmatrix} x^t\\y^t\\1 \end{bmatrix} [xsys?]=[θ1?θ4??θ2?θ5??θ3?θ6??]×???xtyt1????
Differentiable Image Sampling
依據輸入特征圖和參數化采樣網格,我們可以利用雙線性插值方法獲得輸出特征圖。注意,這一步我們可以通過制定采樣網格的大小執行上采樣或者下采樣,很像池化。
左圖使用了單位變換,右圖使用了旋轉的仿射變換。
【注】因為雙線性插值是可微的,所以STN可以作為訓練網絡的一部分。
代碼
利用STN前向過程做圖像變換
GitHub上有作者提供了源碼,也可以用pip直接安裝。
代碼直接貼了,稍微改了一點點:
導入包
import tensorflow as tf import cv2 import numpy as npfrom stn import spatial_transformer_network as transformer讀入圖像,轉換為四維矩陣:
img=cv2.imread('test_img.jpg') img=np.array(img) H,W,C=img.shape img=img[np.newaxis,:] print(img.shape)旋轉變換的角度
degree=np.deg2rad(45) theta=np.array([[np.cos(degree),-np.sin(degree),0],[np.sin(degree),np.cos(degree),0] ])構建網絡結構
x=tf.placeholder(tf.float32,shape=[None,H,W,C]) with tf.variable_scope('spatial_transformer'):theta=theta.astype('float32')theta=theta.flatten()loc_in=H*W*C #輸入維度loc_out=6 #輸出維度W_loc=tf.Variable(tf.zeros([loc_in,loc_out]),name='W_loc')b_loc=tf.Variable(initial_value=theta,name='b_loc')#運算fc_loc=tf.matmul(tf.zeros([1,loc_in]),W_loc)+b_loch_trans=transformer(x,fc_loc)把圖像喂進去,并顯示圖像
init=tf.global_variables_initializer() with tf.Session() as sess:sess.run(init)y=sess.run(h_trans,feed_dict={x:img})print(y.shape)y=np.squeeze(np.array(y,dtype=np.uint8)) print(y.shape) cv2.imshow('trasformedimg',y) cv2.waitKey() cv2.destroyAllWindows()重點關注網絡構建:
權重w_loc是全零的大小為(HWC,6)(HWC,6)(HWC,6)的矩陣,偏置b_loc是大小為(1,6)(1,6)(1,6)的向量,這樣經過運算
fc_loc=tf.matmul(tf.zeros([1,loc_in]),W_loc)+b_loc得到的其實就是我們指定的旋轉角度對應的6維變換參數,最后利用變換函數transformer執行此變換就行了。
將STN加入網絡中訓練
主要參考StegaStamp作者的寫法,這里做STN部分加入網絡的方法:
輸入一張圖片到如下網絡結構(Keras網絡結構搭建語法):
得到(1,128)(1,128)(1,128)維的向量,其實用一個網絡替換上面前向計算中的loc_in,目的是為了得到二維圖像對應的一維信息
后面的過程就和前向計算一樣了,定義權重和偏置:
然后利用一維信息得到圖像變換所需的6個值:
x = tf.matmul(stn_params, self.W_fc1) + self.b_fc1最后利用STN庫將變換應用到圖像中,得到下一層網絡結構的輸入
transformed_image = stn_transformer(image, x, [self.height, self.width, 3])可以看出,STN加入到網絡后,訓練參數有:
- 二維圖像到一維特征向量的卷積+全連接網絡的權重和偏置
- 一維向量到6維變換參數的權重和偏置
總結
通篇就是對池化方案的改變,使用STN能夠增加網絡的變換不變性,比池化的效果更好。
代碼:
鏈接:https://pan.baidu.com/s/1kDs9T-Mf1F_mzQyvslcROA
提取碼:crdu
總結
以上是生活随笔為你收集整理的【TensorFlow-windows】扩展层之STN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 《恋与制作人》周棋洛生日快乐 协奏情意绵
- 下一篇: 【TensorFlow-windows】