快速掌握TensorFlow中张量运算的广播机制
相信大家在使用numpy和tensorflow的時(shí)候都會(huì)遇到如下的錯(cuò)誤
ValueError: operands could not be broadcast together with shapes (4,3) (4,)這是由于numpy和tensorflow中的張量在進(jìn)行運(yùn)算的時(shí)候形狀不滿足廣播機(jī)制的要求,不理解廣播機(jī)制的同學(xué)可能會(huì)通過各種魔改代碼來(lái)讓代碼正常運(yùn)行起來(lái),但是卻不知道為什么那樣改就可以。
本文將從原理上介紹張量運(yùn)算中經(jīng)常用到的廣播機(jī)制。
廣播(broadcasting)指的是不同形狀的張量之間的算數(shù)運(yùn)算的執(zhí)行方式。
通過兩個(gè)例子直觀了解廣播
數(shù)組與標(biāo)量值的乘法
import numpy as np arr = np.arange(5) arr #-> array([0, 1, 2, 3, 4]) arr * 4 #-> array([ 0, 4, 8, 12, 16])在上面的乘法運(yùn)算中,標(biāo)量值4被廣播到了其他所有元素上
通過減去列平均值的方式對(duì)數(shù)組每一列進(jìn)行距平化處理
arr = np.random.randn(4,3) arr #-> array([[ 1.83518156, 0.86096695, 0.18681254],# [ 1.32276051, 0.97987486, 0.27828887],# [ 0.65269467, 0.91924574, -0.71780692],# [-0.05431312, 0.58711748, -1.21710134]]) arr.mean(axis=0) #-> array([ 0.93908091, 0.83680126, -0.36745171])關(guān)于mean中的axis參數(shù),可以這樣理解:
在numpy中,axis = 0為行軸(豎直方向),axis = 1為列軸(水平方向),指定axis表示該操作沿axis進(jìn)行,得到結(jié)果將是一個(gè)shape為除去該axis的array,對(duì)于多維張量,axis=i是指運(yùn)算操作沿著第i個(gè)張量下標(biāo)變化的方向進(jìn)行。
在上例中,arr.mean(axis=0)表示對(duì)arr沿著軸0(豎直方向)求均值。顯然,第0個(gè)下標(biāo)變化的方向即為豎直方向,以第一列為例,4個(gè)元素的下標(biāo)分別為[(0,0),(1,0),(2,0),(3,0)]。
而arr的shape為(4,3),除去axis=0的shape,結(jié)果為(1,3)或者(3,),這與上面的代碼運(yùn)行結(jié)果相符。
廣播機(jī)制的原理
★如果兩個(gè)數(shù)組的后緣維度(從末尾開始算起的維度)的軸長(zhǎng)度相符或其中一方的長(zhǎng)度為1,則認(rèn)為它們是廣播兼容的。廣播會(huì)在缺失維度和(或)軸長(zhǎng)度為1的維度上進(jìn)行。
”demeaned = arr - arr.mean(axis=0) demeaned > array([[ 0.89610065, 0.02416569, 0.55426426],[ 0.3836796 , 0.1430736 , 0.64574058],[-0.28638623, 0.08244448, -0.35035521],[-0.99339402, -0.24968378, -0.84964963]]) demeaned.mean(axis=0) > array([ -5.55111512e-17, -5.55111512e-17, 0.00000000e+00])在上面的對(duì)arr每一列減去列平均值的例子中,arr的后緣維度為3,arr.mean(0)后緣維度也是3,滿足軸長(zhǎng)度相符的條件,廣播會(huì)在缺失維度進(jìn)行。
這里有點(diǎn)奇怪的是缺失維度不是axis=1,而是axis=0,個(gè)人理解是缺失維度指的是兩個(gè)arr除了軸長(zhǎng)度匹配的維度,在上面的例子中,正好是axis=0。
arr.mean(0)沿著axis=0廣播,可以看作是把a(bǔ)rr.mean(0)沿著豎直方向復(fù)制4份,即廣播的時(shí)候arr.mean(0)相當(dāng)于一個(gè)shape=(4,3)的數(shù)組,數(shù)組的每一行均相同,均為arr.mean(0)
各行減去行均值
row_means = arr.mean(axis=1) row_means.shape > (4,) arr - row_means > ---------------------------------------------------------------------------ValueError Traceback (most recent call last)<ipython-input-10-3d1314c7e700> in <module>()----> 1 arr - row_meansValueError: operands could not be broadcast together with shapes (4,3) (4,)直接相減,報(bào)錯(cuò),無(wú)法進(jìn)行廣播。
回顧上面的原則,要么滿足后緣維度軸長(zhǎng)度相等,要么滿足其中一方長(zhǎng)度為1。在這個(gè)例子中,兩者均不滿足,所以報(bào)錯(cuò)。根據(jù)廣播原則,較小數(shù)組的廣播維必須為1。解決方案是為較小的數(shù)組添加一個(gè)長(zhǎng)度為1的新軸。
numpy提供了一種通過索引機(jī)制插入軸的特殊語(yǔ)法。通過特殊的np.newaxis屬性以及“全”切片來(lái)插入新軸。
下面的例子中,我們通過插入新軸的方式實(shí)現(xiàn)二維數(shù)組各行減去行均值。這里將行均值沿著水平方向進(jìn)行廣播,廣播軸為axis=1,對(duì)row_means添加一個(gè)新軸axis=1
row_means[:,np.newaxis].shape > (4, 1) arr - row_means[:,np.newaxis] > array([[ 0.87419454, -0.10002007, -0.77417447],[ 0.46245243, 0.11956678, -0.58201921],[ 0.36798351, 0.63453458, -1.00251808],[ 0.17378588, 0.81521647, -0.98900235]])另一個(gè)例子
a = np.array([1,2,3]) a.shape # -> (3,) b = np.array([[1,],[2,],[3]]) # -> (3,1) b - a # -> array([[ 0, -1, -2],# [ 1, 0, -1],# [ 2, 1, 0]])上面的例子輸出為什么是一個(gè)3*3的數(shù)組??
我們來(lái)分析一下,根據(jù)廣播原則,b滿足其中一方軸長(zhǎng)度為1,那么廣播會(huì)沿著長(zhǎng)度為1的軸,及axis=1進(jìn)行,對(duì)數(shù)組b沿著axis=1即水平方向進(jìn)行復(fù)制,相當(dāng)于b變成一個(gè)shape為(3,3)且各列均為[1,2,3]的數(shù)組。
一個(gè)維度為(3,3)的數(shù)組減去一個(gè)維度為(3,)的數(shù)組,滿足后緣維度軸長(zhǎng)度相等,數(shù)組a沿著axis=0即豎直方向進(jìn)行廣播,相當(dāng)遠(yuǎn)a變成一個(gè)shape為(3,3)且個(gè)行均為[1,2,3]的數(shù)組。
b-a的時(shí)候,
?b被廣播成為
[[1,1,1],[2,2,2],[3,3,3]]a被廣播成為
[[1,2,3],[1,2,3],[1,2,3]]所以b-a的結(jié)果是
[[0,-1,-2],[1, 0,-1],[2, 1, 0]]三維情況
下面的例子中,構(gòu)造一個(gè)3*4*5的隨機(jī)數(shù)組arr_3d,我們希望實(shí)現(xiàn)對(duì)arr_3d的每個(gè)元素減去其深度(axis=2)方向的均值
#構(gòu)造三維數(shù)組 arr_3d = np.random.randn(3,4,5) #求深度方向的均值,想想結(jié)果的shape是什么?原始shape是(3,4,5) #除去axis=2后還剩(3,4) depth_means = arr_3d.mean(axis=2) depth_means.shape > (3, 4) #arr(3,4,5)和depth_means(3,4)不能直接廣播,后緣維度不相符且不存在軸長(zhǎng)度為1的軸 arr_3d_new = arr_3d - depth_means[:,:,np.newaxis] #所以我們添加廣播軸 arr_3d_new.mean(axis=2) #結(jié)果應(yīng)該為0,這里是接近0的浮點(diǎn)數(shù),符合預(yù)期> array([[ -5.55111512e-17, 4.44089210e-17, 4.44089210e-17, 4.44089210e-17],[ -8.88178420e-17, -1.11022302e-16, -6.66133815e-17,0.00000000e+00],[ 0.00000000e+00, -7.77156117e-17, -2.22044605e-17,-2.22044605e-17]])以上就是關(guān)于張量運(yùn)算中廣播機(jī)制的一點(diǎn)介紹,歡迎關(guān)注公眾號(hào)淺夢(mèng)的學(xué)習(xí)筆記,一起討論交流!
關(guān)于本站
“機(jī)器學(xué)習(xí)初學(xué)者”公眾號(hào)由是黃海廣博士創(chuàng)建,黃博個(gè)人知乎粉絲23000+,github排名全球前110名(32000+)。本公眾號(hào)致力于人工智能方向的科普性文章,為初學(xué)者提供學(xué)習(xí)路線和基礎(chǔ)資料。原創(chuàng)作品有:吳恩達(dá)機(jī)器學(xué)習(xí)個(gè)人筆記、吳恩達(dá)深度學(xué)習(xí)筆記等。
往期精彩回顧
那些年做的學(xué)術(shù)公益-你不是一個(gè)人在戰(zhàn)斗
適合初學(xué)者入門人工智能的路線及資料下載
吳恩達(dá)機(jī)器學(xué)習(xí)課程筆記及資源(github標(biāo)星12000+,提供百度云鏡像)
吳恩達(dá)深度學(xué)習(xí)筆記及視頻等資源(github標(biāo)星8500+,提供百度云鏡像)
《統(tǒng)計(jì)學(xué)習(xí)方法》的python代碼實(shí)現(xiàn)(github標(biāo)星7200+)
精心整理和翻譯的機(jī)器學(xué)習(xí)的相關(guān)數(shù)學(xué)資料
首發(fā):深度學(xué)習(xí)入門寶典-《python深度學(xué)習(xí)》原文代碼中文注釋版及電子書
備注:加入本站微信群或者qq群,請(qǐng)回復(fù)“加群”
加入知識(shí)星球(4300+用戶,ID:92416895),請(qǐng)回復(fù)“知識(shí)星球”
與50位技術(shù)專家面對(duì)面20年技術(shù)見證,附贈(zèng)技術(shù)全景圖總結(jié)
以上是生活随笔為你收集整理的快速掌握TensorFlow中张量运算的广播机制的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 原创:机器学习代码练习(一、回归)
- 下一篇: 员外陪你读论文:DeepWalk: On