LambdaNetworks解读
最近有不少人和我提到 ViT 以及 DETR 以及商湯提出的 Deformable DETR,仿若看到了 Transformer 在計算機視覺中大放異彩的未來,甚至谷歌對其在自注意力機制上進行了調整并提出 Performer。但是,由于 Transformer 的自注意力機制對內存的需求是輸入的平方倍,這在圖像任務上計算效率過低,當輸入序列很長的時候,自注意力對長程交互建模計算量更是龐大無比。而且,Transformer 是出了名的難訓練。所以,想要看到其在視覺任務上有更好的表現,還需要面臨不小的挑戰,不過,LambdaNetworks倒是提出了一種新的長程交互信息捕獲的新范式,而且在視覺任務中效果很不錯。
簡介
文章對于捕獲輸入和結構化上下文之間的長程交互提出了一種新的通用框架,該方法名為Lambda Layer。它通過將可用上下文轉化為名為lambdas的線性函數,并將這些函數分別應用于每個輸入。Lambda層是通用的,它可以建模全局或者局部的內容和位置上的信息交互。并且,由于其避開了使用“昂貴”的注意力圖,使得其可以適用于超長序列或者高分辨率圖像。由Lambda構成的LambdaNetworks在計算上是高效的,并且可以通過主流計算庫實現。實驗證明,LambdaNetworks在圖像分類、目標檢測、實例分割等任務上達到sota水準且計算更加高效。同時,作者也基于ResNet改進設計了LambdaResNet并且獲得和EfficientNet相當的效果,快了4.5倍。
-
論文標題
LambdaNetworks: Modeling long-range Interactions without Attention
-
論文地址
https://openreview.net/forum?id=xTJEN-ggl1b
-
論文源碼
https://github.com/lucidrains/lambda-networks
介紹
建模長程信息交互是機器學習領域很重要的課題,注意力機制是當前最主流的范式,然而,自注意力的二次內存占用不利于處理超長序列或者多維輸入,比如包含數萬像素的圖像。論文中這里舉了個例子,一批256個64x64的圖像使用8head的多頭注意力就需要32G的內存。
考慮到自注意力的局限性,論文提出了Lambda層,該層為捕獲輸入和結構化的上下文之間的長程信息交互提供了一種新的通用框架。Lambda層捕獲信息交互的方式也很簡單,它將可用上下文轉化為線性函數,并將這些線性函數分別應用于每個輸入,這些線性函數就是lambda。Lambda層可以成為注意力機制的替代品,注意力在輸入和上下文之間定義了一個相似性核,而Lambda層將上下文信息總結為一個固定size的線性函數,這樣就避開了很耗內存的注意力圖。他倆的對比,可以通過下面的圖看出來(左圖是一個包含三個query的局部上下文,它們同處一個全局上下文中;中圖是attention機制產生的注意力圖;右圖則是lambda層線性函數作用于query的結果)。
Lambda層用途廣泛,可以實現為在全局、局部或masked上下文中對內容和基于位置的交互進行建模。由此產生的神經網絡結構LambdaNetworks具有高效的計算能力,并且可以以較小的內存開銷建模長程依賴,因此非常適用于超大結構化輸入,如高分辨率圖像。
后文也用實驗證明,在注意力表現很好的任務中,LambdaNetworks表現相當,且計算更為高效且更快。
長程信息交互建模
論文在第二部分主要對一些Lambda的術語進行了定義,引入keys作為捕獲queries和它們的上下文之間信息交互的需求,而且,作者也說明,Lambda layer采用了很多自注意力的術語來減少閱讀差異,這就是為什么很多人覺得兩者在很多名稱定義上差異不大的原因。
queries、contexts和interactions
Q={(qn,n)}\mathcal{Q}=\left\{\left(\boldsymbol{q}_{n}, n\right)\right\}Q={(qn?,n)}和C={(cm,m)}\mathcal{C}=\left\{\left(\boldsymbol{c}_{m}, m\right)\right\}C={(cm?,m)}分別表示queries和contexts,每個(qn,n)\left(\boldsymbol{q}_{n}, n\right)(qn?,n)都包含內容qn∈R∣k∣\boldsymbol{q}_{n} \in \mathbb{R}^{|k|}qn?∈R∣k∣和位置nnn,同樣的,每個上下文元素(cm,m)\left(\boldsymbol{c}_{m}, m\right)(cm?,m)都包含內容cm\boldsymbol{c}_{m}cm?和位置mmm,而(n,m)(n, m)(n,m)指的是任意結構化元素之間的成對關系。舉個例子,這個(n,m)對可以指被固定在二維柵格上的兩個像素的相對距離,也可以指圖(Graph)上倆node之間的關系。
下面,作者介紹了Lambda layer的工作過程。先是考慮給定的上下文C\mathcal{C}C的情況下通過函數F:((qn,n),C)?yn\boldsymbol{F}:\left(\left(\boldsymbol{q}_{n}, n\right), \mathcal{C}\right) \mapsto \boldsymbol{y}_{n}F:((qn?,n),C)?yn?將query映射到輸出向量yn\boldsymbol{y}_{n}yn?。顯然,如果處理的是結構化輸入,那么這個函數可以作為神經網絡中的一個層來看待。將(qn,cm)\left(\boldsymbol{q}_{n}, \boldsymbol{c}_{m}\right)(qn?,cm?)稱為基于內容的交互,(qn,(n,m))\left(\boldsymbol{q}_{n},(n, m)\right)(qn?,(n,m))則為基于位置的交互。此外,若yn\boldsymbol{y}_{n}yn?依賴于所有的(qn,cm)\left(\boldsymbol{q}_{n}, \boldsymbol{c}_{m}\right)(qn?,cm?)或者(qn,(n,m))\left(\boldsymbol{q}_{n},(n, m)\right)(qn?,(n,m)),則稱F\boldsymbol{F}F捕獲了全局信息交互,如果只是圍繞nnn的一個較小的受限上下文用于映射,則稱F\boldsymbol{F}F捕獲了局部信息交互。最后,若這些交互包含了上下文中所有∣m∣|m|∣m∣個元素則稱為密集交互(dense interaction),否則為稀疏交互(sparse interaction)。
引入key來捕獲長程信息交互
在深度學習這種依賴GPU計算的場景下,我們優先考慮快速的線性操作并且通過點積操作來捕獲信息交互。這就促使了引入可以和query通過點擊進行交互的向量,該向量和query同維。特別是基于內容的交互(qn,cm)\left(\boldsymbol{q}_{n}, \boldsymbol{c}_{m}\right)(qn?,cm?)需要一個依賴cm\boldsymbol{c}_{m}cm?的kkk維向量,這個向量就是key(鍵)。相反,基于位置的交互(qn,(n,m))\left(\boldsymbol{q}_{n},(n, m)\right)(qn?,(n,m))則需要位置編碼enm∈R∣k∣\boldsymbol{e}_{n m} \in \mathbb{R}^{|k|}enm?∈R∣k∣,有時也稱為相對key。query和key的深度∣k∣|k|∣k∣以及上下文空間維度∣m∣|m|∣m∣不在輸出yn∈R∣v∣\boldsymbol{y}_{n} \in \mathbb{R}^{|v|}yn?∈R∣v∣,因此需要將這些維度收縮為layer計算的一部分。因此,捕獲長程交互的每一層都可以根據它是收縮查詢深度還是首先收縮上下文位置來表征。
注意力交互
收縮query的深度首先會在query和上下文元素之間創建一個相似性核,這就是attention操作。隨著上下文位置∣m∣|m|∣m∣的增大而輸入輸出維度∣k∣|k|∣k∣和∣v∣|v|∣v∣不變,考慮到層輸出是一個很小維度的向量∣v∣?∣m∣|v| \ll|m|∣v∣?∣m∣,注意力圖(attention map)的計算會變得很浪費資源。
Lambda交互
相反,通過一個線性函數λ(C,n)\boldsymbol{\lambda}(\mathcal{C}, n)λ(C,n)獲得輸出yn=F((qn,n),C)=λ(C,n)(qn)\boldsymbol{y}_{n}=F\left(\left(\boldsymbol{q}_{n}, n\right), \mathcal{C}\right)=\boldsymbol{\lambda}(\mathcal{C}, n)\left(\boldsymbol{q}_{n}\right)yn?=F((qn?,n),C)=λ(C,n)(qn?)會更高效地簡化映射過程(map)。在這個場景中,上下文被聚合為一個固定size的線性函數λn=λ(C,n)\boldsymbol{\lambda}_{n}=\boldsymbol{\lambda}(\mathcal{C}, n)λn?=λ(C,n)。每個λn\boldsymbol{\lambda}_{n}λn?作為一個小的線性函數獨立于上下文并且被用到相關的queryqn\boldsymbol{q}_nqn?后丟棄。這個機制很容易聯想到影響比較大的函數式編程和lambda積分,所以稱為lambda層。
Lambda層
一個lambda層將輸入X∈R∣n∣×din\boldsymbol{X} \in \mathbb{R}^{|n| \times d_{i n}}X∈R∣n∣×din?和上下文C∈R∣m∣×dc\boldsymbol{C} \in \mathbb{R}^{|m| \times d_{c}}C∈R∣m∣×dc?作為輸入并產生線性函數lambdas分別作用于query,返回輸出Y∈R∣n∣×dout\boldsymbol{Y} \in \mathbb{R}^{|n| \times d_{o u t}}Y∈R∣n∣×dout?。顯然,在自注意力中,C=X\boldsymbol{C} = \boldsymbol{X}C=X。為了不失一般性,我們假定din=dc=dout=dd_{i n}=d_{c}=d_{o u t}=ddin?=dc?=dout?=d。在接下來的論文里,作者將重點放在了lambda層的一個具體實例上,并且證明lambda層可以獲得密集的長程內容和位置的信息交互而不需要構建注意力圖。
將上下文轉化為線性函數
首先,假定上下文只有一個query(qn,n)\left(\boldsymbol{q}_{n}, n\right)(qn?,n)。我們希望產生一個線性函數lambdaR∣k∣→R∣v∣\mathbb{R}^{|k|} \rightarrow \mathbb{R}^{|v|}R∣k∣→R∣v∣,我們將R∣k∣×∣v∣\mathbb{R}^{|k| \times|v|}R∣k∣×∣v∣稱為函數。下表所示的就是lambda層的超參、參數以及其他相關的配置。
生成上下文lambda函數:lambda層首先通過線性投影上下文來計算keys和values,并且使用softmax操作跨上下文對keys進行標準化從而得到標準化后的Kˉ\bar{K}Kˉ。它的實現可以看作是一種函數式消息傳遞,每個上下文元素貢獻一個內容functionμmc=K ̄mVmT\boldsymbol{\mu}_{m}^{c}=\overline{\boldsymbol{K}}_{m} \boldsymbol{V}_{\boldsymbol{m}}^{T}μmc?=Km?VmT?和位置functionμnmp=EnmVmT\boldsymbol{\mu}_{n m}^{p}=\boldsymbol{E}_{n m} \boldsymbol{V}_{\boldsymbol{m}}^{T}μnmp?=Enm?VmT?,最終的lambda函數其實是兩者的和,具體如下,式子中的λc\boldsymbol{\lambda}^{c}λc為內容lambda,而λnp\boldsymbol{\lambda}^p_nλnp?為位置lambda。內容λc\boldsymbol{\lambda}^{c}λc對上下文元素的排列是不變的,在所有的query位置nnn之間共享,并僅基于上下文內容對qn\boldsymbol{q}_{n}qn?進行編碼轉換。不同的是,位置λnp\lambda_{n}^{p}λnp?基于內容cm\boldsymbol{c}_{m}cm?和位置(n,m)(n, m)(n,m)對查詢query進行編碼轉換,從而支持對結構化輸入建模如圖像。
λc=∑mμmc=∑mK ̄mVmTλnp=∑mμnmp=∑mEnmVmTλn=λc+λnp∈R∣k∣×∣v∣\begin{aligned} \boldsymbol{\lambda}^{c} &=\sum_{m} \boldsymbol{\mu}_{m}^{c}=\sum_{m} \overline{\boldsymbol{K}}_{m} \boldsymbol{V}_{\boldsymbol{m}}^{T} \\ \boldsymbol{\lambda}_{n}^{p} &=\sum_{m} \boldsymbol{\mu}_{n m}^{p}=\sum_{m} \boldsymbol{E}_{n m} \boldsymbol{V}_{\boldsymbol{m}}^{T} \\ \boldsymbol{\lambda}_{n} &=\boldsymbol{\lambda}^{c}+\boldsymbol{\lambda}_{n}^{p} \in \mathbb{R}^{|k| \times|v|} \end{aligned} λcλnp?λn??=m∑?μmc?=m∑?Km?VmT?=m∑?μnmp?=m∑?Enm?VmT?=λc+λnp?∈R∣k∣×∣v∣?
應用lambda到query:輸入被轉化為queryqn=WQxn\boldsymbol{q}_{n}=\boldsymbol{W}_{Q} \boldsymbol{x}_{n}qn?=WQ?xn?,然后lambda層獲得如下輸出。
yn=λnqn=(λc+λnp)qn∈R∣v∣\boldsymbol{y}_{n}=\boldsymbol{\lambda}_{n} \boldsymbol{q}_{n}=\left(\boldsymbol{\lambda}^{c}+\boldsymbol{\lambda}_{n}^{p}\right) \boldsymbol{q}_{n} \in \mathbb{R}^{|v|} yn?=λn?qn?=(λc+λnp?)qn?∈R∣v∣
Lambda的解釋:λn∈R∣k∣×∣v∣\boldsymbol{\lambda}_{n} \in \mathbb{R}^{|k| \times|v|}λn?∈R∣k∣×∣v∣矩陣的列可以看作∣k∣∣v∣|k| |v|∣k∣∣v∣維上下文特征的固定size的集合。這些上下文特征從上下文內容和結構聚合而來。應用lambda線性函數動態地分布這些上下文特征來產生輸出yn=∑kqnkλnk\boldsymbol{y}_{n}=\sum_{k} q_{n k} \boldsymbol{\lambda}_{n k}yn?=∑k?qnk?λnk?。這個過程捕獲密集地內容和位置的長程信息交互,而不需要產生注意力圖。
標準化: 實驗表明,非線性或者標準化操作對計算是有幫助的,作者在計算的query和value之后應用batch normalization發現是有效的。
對結構化上下文應用Lambda函數
在這一節,作者主要介紹如何將lambda層應用于結構化上下文。
Translation equivariance:在很多機器學習場景中,Translation equivariance是一個很強的歸納偏置。由于基于內容的信息交互是排列等變的,因此本就是translation equivariant。而位置的信息交互獲得translation equivariant則通過對任意的translation ttt確保位置編碼滿足Enm=Et(n)t(m)\boldsymbol{E}_{n m}=\boldsymbol{E}_{t(n) t(m)}Enm?=Et(n)t(m)?來做到。實際中,我們定義一個相對位置編碼的張量R∈R∣k∣×∣r∣×∣u∣\boldsymbol{R} \in \mathbb{R}^{|k| \times|r| \times|u|}R∈R∣k∣×∣r∣×∣u∣,其中rrr索引對所有的(n,m)(n,m)(n,m)對可能的相對位置,并將其重新索引為E∈R∣k∣×∣n∣×∣m∣×∣u∣\boldsymbol{E} \in \mathbb{R}^{|k| \times|n| \times|m| \times|u|}E∈R∣k∣×∣n∣×∣m∣×∣u∣,如Enm=Rr(n,m)\boldsymbol{E}_{n m}=\boldsymbol{R}_{r(n, m)}Enm?=Rr(n,m)?。
Lambda 卷積: 盡管有長程信息交互的好處,局部性在許多任務中仍然是一個強烈的歸納偏置。從計算的角度來看,使用全局上下文可能會產生噪聲或過度。因此,將位置交互的范圍限制到查詢位置nnn周圍的一個局部鄰域,就像局部自注意和卷積的情況一樣,可能是有用的。這可以通過對所需范圍之外的上下文位置mmm的位置嵌入進行歸零來實現。然而,對于較大的∣m∣|m|∣m∣值,這種策略仍然代價高昂,因為計算仍然會發生(它們只是被歸零)。在上下文被安排在多維網格上時,可以通過常規卷積從局部上下文中生成位置lambdas,將V\boldsymbol{V}V中的vvv維視為額外的空間維度。考慮在一維序列上的大小為∣r∣|r|∣r∣的局部域上生成位置lambdas。相對位置編碼張量R∈R∣r∣×∣u∣×∣k∣\boldsymbol{R} \in \mathbb{R}^{|r| \times|u| \times|k|}R∈R∣r∣×∣u∣×∣k∣可以被reshape到R ̄∈R∣r∣×1×∣u∣×∣k∣\overline{\boldsymbol{R}} \in \mathbb{R}^{|r| \times 1 \times|u| \times|k|}R∈R∣r∣×1×∣u∣×∣k∣,并且被用作二維卷積核來計算需要的位置lambda,算式如下。
λbnvk=conv?2d(Vbnvu,R ̄r1uk)\boldsymbol{\lambda}_{b n v k}=\operatorname{conv} 2 \mathrmozvdkddzhkzd\left(\boldsymbol{V}_{b n v u}, \overline{\boldsymbol{R}}_{r 1 u k}\right) λbnvk?=conv2d(Vbnvu?,Rr1uk?)
這個操作稱為lambda卷積,由于計算被限制在一個局部范圍,lambda卷積相對于輸入只需要線性時間和內存復雜度的消耗。lambda卷積很容易和其他功能一起使用,如dilation和striding,并且在硬件計算上享受告訴運算。計算效率和局部自注意力形成了鮮明對比,如下表。
multiquery lambdas減少復雜性
這部分作者主要對計算復雜度進行了分析,設計了多query lambda,計算復雜度對比如下。
提出的multiquery lambdas可以通過einsum高效實現。
λbkvc=einsum?(K ̄bmku,Vbmvu)λbnkvp=einsum?(Eknmu,Vbmvu)Ybnhvc=einsum?(Qbnhk,λbkvc)Ybnhvp=einsum?(Qbnhk,λbnkvp)Ybnhv=Ybnhvc+Ybnhvp\begin{aligned} \boldsymbol{\lambda}_{b k v}^{c}=& \operatorname{einsum}\left(\overline{\boldsymbol{K}}_{b m k u}, \boldsymbol{V}_{b m v u}\right) \\ \boldsymbol{\lambda}_{b n k v}^{p} &=\operatorname{einsum}\left(\boldsymbol{E}_{k n m u}, \boldsymbol{V}_{b m v u}\right) \\ \boldsymbol{Y}_{b n h v}^{c} &=\operatorname{einsum}\left(\boldsymbol{Q}_{b n h k}, \boldsymbol{\lambda}_{b k v}^{c}\right) \\ \boldsymbol{Y}_{b n h v}^{p} &=\operatorname{einsum}\left(\boldsymbol{Q}_{b n h k}, \boldsymbol{\lambda}_{b n k v}^{p}\right) \\ \boldsymbol{Y}_{b n h v} &=\boldsymbol{Y}_{b n h v}^{c}+\boldsymbol{Y}_{b n h v}^{p} \end{aligned} λbkvc?=λbnkvp?Ybnhvc?Ybnhvp?Ybnhv??einsum(Kbmku?,Vbmvu?)=einsum(Eknmu?,Vbmvu?)=einsum(Qbnhk?,λbkvc?)=einsum(Qbnhk?,λbnkvp?)=Ybnhvc?+Ybnhvp??
然后,對比了lambda 層和自注意力在resnet50架構上的imagenet分類任務效果。顯然,lambda層參數量是很少的,且準確率很高。
實驗
在大尺度高分辨率計算機視覺任務上進行了充分的實驗,和SOTA的EfficientNet相比,可以說無論是速度還是精度都有不小的突破。
其長子檢測任務上,LambdaResNet也極具優勢。
總結
作者提出了Lambda Layer代替自注意力機制,獲得了較好的改進。并借此設計了LambdaNetworks,其在各個任務上都超越了SOTA且速度提高了很多。如果實踐證明,Lambda Layer的效果具有足夠的魯棒性,在以后的研究中應該會被廣泛使用。
總結
以上是生活随笔為你收集整理的LambdaNetworks解读的全部內容,希望文章能夠幫你解決所遇到的問題。