从2021年多篇顶会论文看OOD泛化新理论、新方法和新讨论
?PaperWeekly 原創 ·?作者?|?張一帆
學校?|?華南理工大學本科生
研究方向?|?CV,Causality
arXiv 2021
論文標題:
Towards a Theoretical Framework of Out-of-Distribution Generalization
論文鏈接:
https://arxiv.org/abs/2106.04496
這篇文章應該是今年投稿 NeurIPS 的文章,文章貢獻有兩點:
在 OOD 泛化受到極大關注的今天,一個合適的理論框架是非常難得的,就像 DA 的泛化誤差一樣;
本文通過泛化誤差提出了模型選擇策略,不單純使用驗證集的精度,二十同時考慮驗證集的精度和在各個 domain 驗證精度的方差。
1.1 Preliminary
先來看一看 OOD 經典的問題建模,考慮一個多分類問題 。用 表示可見的訓練集,以及所有集合。 表示輸入-標簽組,OOD 泛化問題就是要找一個分類器 來最小化 worst-domain loss:
這里的 是假設空間, 是損失函數。 同樣可以分解為 ,即分類器和特征提取器。 可以寫為:
是一個標量的特征映射, 是預設的特征維度。下文將 簡寫為 。
1.2 Framework of OOD Generalization Problem
對 OOD 問題的分析難點在于如何構建 和 之間的聯系,以及域泛化和二者聯系之間的聯系。接下來我們就一步步的看看這篇文章是如何進行構建的。
作者先介紹了兩個定義:特征的 “variation(變化)”和 “informativeness(信息量)”。前者是一個類似于 divergence 的概念,我們希望對同一個 label,在各個域上的特征變化不大。后者表示了這個特征要有足夠的表示能力,包含了區分各個標簽的能力。
Variation:給定如下定義,如果一個特征滿足 ,那么我們說他是是 -invariant 的:
Informativeness:給定如下定義,如果一個特征滿足 ,那么我們說他是是 -Informative 的:
有了這兩個定義,接下來就進入最難的環節,構建 和 之間的聯系,本文是基于這樣一個假設“如果一個特征包含足夠的信息,而且在 上能夠做到 invariant,那么就能夠泛化到 ”,這個假設還是挺強的,但是暫時也沒有更好的替代方案,從該假設來看,, 的聯系捕獲了 OOD 泛化的可行性和難度。為了定量的測量這個聯系,我們還需要這樣一個函數類。
Expansion Function:這是一個函數 ,如果它滿足:1)單調遞增且?;2),我們稱之為一個擴增函數。
這個函數定義了 , 之間的關系,我們可以想一下,如果可見域只是全部域的一小部分,那么 就是一個非常陡峭的函數,否則如果可見域就是全體域,那么 。
有了這三樣東西,我們來定義最后一個最重要的概念
Learnability:對所有滿足信息容量 的特征提取器而言,如果存在上述的 和一個擴增函數 ,使得 我們稱一個 OOD 問題是可學習的。
原文還提供了一些討論幫助讀者更好的理解這幾個問題。
1.3 Generalization Bound
接下來的推導就是文章最復雜的部分了,對于分類器 ,我們定義泛化誤差為:
這里只有一個假設:損失函數有界 。接下來主要講講對它證明的直觀理解,不涉及具體推導。
首先上述損失可以推導出這樣一個 bound,這一步將 loss 之間的差值轉化為了 分布之差。
接下來就要根據假設,將分布的差轉化為常數項和 Variation 或者 Informativeness 相關的項,這里用到了傅里葉反變換公式以及較多的數學轉化,最終得到了如下這樣一個復雜的結果:
將常數項簡化一下就得到了誤差上界:
類似的,我們也可以推出一個下界:
看到這里可能有人疑惑了,上下界都和 variation 有關,但是和 Informativeness 無關,那我輸出全 0 向量不就可以做到 invariant 了嗎?答案是否定的,在 bound 的證明中總是假設該問題滿足 Learnability,而 Learnability 關鍵的一點就是限制信息容量大于一個定值。
所以這個 bound 對我們的啟發在于,為了追求良好的 OOD 性能,OOD 算法應同時注重提高預測性能和控制 variation 的變化(好像大家一直都是這么做的)。
1.4 Variation as a Factor of Model Selection Criterion
本文中提出了一種新的模型選擇策略,如果我們按照驗證集的總體精確度來選擇最終的模型,其實沒有幾個模型比 ERM 好很多,這一結果并不奇怪,因為傳統的選擇方法主要關注(驗證)準確性,這在 OOD 概化中有偏倚。
相反本文沒有單獨考慮驗證精度,而是將其與 variation 相結合,選擇了高驗證精度和低 variation 的模型。文中也通過實驗驗證了這種選擇策略的有效性。其中 “val” 就是傳統的模型選擇策略。
ICML 2021
論文標題:
Can Subnetwork Structure be the Key to Out-of-Distribution Generalization?
論文鏈接:
https://arxiv.org/abs/2106.02890
本文基于彩票假設,即使我們整體的模型是有偏的(偏向背景,上下文而不是物體本身),這個網絡中也存在一些子網絡他們是無偏的,可以實現更好的 OOD 性能。文中提供了模塊風險最小化 MRM 來尋找這些“彩票”。
MRM 算法理解起來也并不困難:
給定數據,完整的網絡,子網絡的 logits ,logit 是一個用于產生 mask 的隨機分布,比如網絡第 層有 個參數,那么 。該層的 mask 通過從 中采樣得到,mask 將完整網絡轉化為子網絡(= 0 即忽略第 層的第 個參數);
我們對模型進行初始化然后使用 ERM 的目標進行訓練 個 step;
我們從整個網絡中采樣子網絡,結合交叉熵和稀疏正則化作為損失函數來學習有效的子網結構;
最后只需要簡單地只使用所得到的子網中的權值重新進行訓練,并將其他權值固定為零。
文章最大的亮點就在于 MRM 和目前主流的研究方向(修改目標函數)是正交的,無論目標函數是什么,MRM 都能找到這么些泛化能力更強的子網絡。
ICLR 2021
論文標題:
Understanding the failure modes of out-of-distribution generalization
論文鏈接:
https://arxiv.org/abs/2010.15775
代碼鏈接:
https://github.com/google-research/OOD-failures
3.1 Motivation
現有的理論可以解釋為什么當不變性特征本身信息不足時,分類器依賴于虛假特征(下圖 a)。但是,當不變特征完全能夠預測標簽時,這些解釋就不成立了。
比如在下圖 b 中,顯然我們只需要 max-margin 就可以很好的識別 camel 和 cow,但是真實情況是,分類器依然會依賴于這些虛假特征,其實在很多現實設置中,形狀或者說輪廓信息是完全可以預測對應的標簽的,但是我們的分類器總會依賴于類似于背景或者顏色等虛假的信息。作者發現了兩個影響因素,這是本文的核心貢獻。
3.2 Easy to learn domain generalization tasks
之前的探索 ERM 失敗原因的工作主要基于兩種模式:
不變性特征和虛假特征都只能部分預測標簽,因此一個優化負對數似然的分類器當然不能錯過虛假特征包含的信息;
不變性特征和虛假特征都能完全預測標簽,但是虛假特征更容易學習(更加線性),因此梯度下降會選擇更容易學習的特征進行分類。
本文對這些假設進行了質疑,構造任務時針對以上每一點進行了回應,其任務有以下特點:
不變性的特征有足夠的能力完成對標簽的預測,虛假特征不能完全預測標簽;
不變性特征有一個線性的分類 boundary,很好學習。
作者驗證了,即使在這樣一個易學習的任務中,ERM 仍然還是依賴于虛假特征。作者總結并驗證了兩個原因。
3.3 Failure due to geometric skews
首先文中指出了一點,隨著數據量的增大,max-margin 分類器的 norm 也隨之增大,geometric skews 是一個非正式的稱謂,用于形容過參數化的神經網絡,考慮上圖 c 中的場景,我們可以把數據集分為兩類:
多數類 ,對應 cow/camel 在 green/yellow 背景下。
少數類 ,對應 cow/camel 在 yellow/green 背景下。
假設我們使用不變性特征來進行預測將花費 的 norm,使用不變性特征對少數類進行預測的花費為 ,因為數據量的原因,。
那么此時我們還可以這樣進行預測:使用虛假特征作為捷徑來預測多數類,然后使用不變性特征來預測少數類,這種方式需要更少的 norm,因此更容易成為我們max-margin 分類器選擇的策略。
3.4 Failure due to statistical skews
上文研究的是 max-margin 分類器,我們知道對于一個線性分類器而言,在一定條件下,在線性可分離的數據集上,這些分類器會在無限時間下收斂到 max-margin 分類器,也就是說即使他收斂了也會收到 geometric skew 的影響。那么這個收斂過程本身會不會引入一定的 spurious correlation 呢。
作者通過觀察發現,即使我們的數據集不存在 geometric skew,即 max-margin 分類器不會失敗,我們花費超長時間訓練一個線性分類器使他收斂,他依然會依賴于虛假特征。作者在文章推導出了一個收斂性隨偽相關而變化的 bound 來討論使用梯度下降訓練的過程中引入的偽相關。
總結一下,目前大部分注意力都集中在實用主義或啟發式解決方案(設計或學習“不變”特性的各種技巧)上,而我們對 OOD 情況中出錯原因的基本理解仍然不完整。本文旨在通過研究簡化的設置來填補這些理解上的空白,并提出這樣一個問題:當任務可以只使用安全的(“不變的”)特性來解決時,為什么統計模型要學習使用易變化的特性(“虛假的”特性)。在制定了多個約束條件(保證對容易學習的任務適用)后,他們表明失敗有兩種形式:幾何傾斜和統計傾斜。他們依次進行分析和解釋,同時也提供了說明性的實證結果。
ICML 2021 Oral
論文標題:
Domain Generalization using Causal Matching
論文鏈接:
https://arxiv.org/abs/2006.07500
代碼鏈接:
https://github.com/microsoft/robustdg
這篇文章乍一看非常簡單,但是細看之后發現其實有很多地方理解起來并不容易。
這篇文章的主要貢獻在于:
作者 argue 了一件事情,我們以往學習的不變性特征表達包括 與 domain 無關還是 與 domain 無關其實都是有問題的,根據文中假設的因果圖來看,要真正捕捉到域不變特征,我們需要約束 不變,其中 是圖像的 object 信息。
作者加了一項看著很簡單的約束:擁有相同的對象(object)的跨域圖像也應該有相同的表示。
文中涉及的證明比較多,這里只闡述 high-level 的觀點。首先我們來分析一下傳統的 ERM 算法:
分別是我們可見的數據分布和總體的數據分布,文中證明了當可見域的數目等于所有域,樣本數量趨于正無窮時,ERM 能夠收斂到最優分類器。然而正常情況下我們的可見域只是數據域的一部分 ,因此 ERM 就會過擬合。
這是文中提出的結構化因果圖,對因果不了解的朋友們無需擔心,總之一大堆證明就是為了提出我們需要的 object 特征 要滿足 ,這個條件其實不難理解,就是說對同一個對象而言,它的特征不應該隨著 domain 變化,所以文中在 ERM 的基礎上添加了如下約束:
作者證明了:
滿足上述約束的分類器中包含了最有分類器;
在具有虛假相關性的數據集中,優化如下的損失函數能夠帶來最優分類器。
到這里文章的內容好像已經完整了?其實不然,考慮一個數據非常不平衡的數據集,一個 domain 中擁有超多 object A,其他 domain 基本沒有,那么上述的 match 其實是在不斷地減小同一個 domain 下同一類的特征距離,這對泛化是沒有太大好處的。
對于 Rotated MNIST 這類的數據集,因為是通過數據增強的方式構造的,因此非常的 balance,但是對于更加真實的數據集,這個關系顯然是不成立的,這就是我對于文中?object information is not always available, and in many datasets there maynot be a perfect “counterfactual” match based on same object across domain?這句話的理解。
那么如何避免我們對 class-balance 的過度依賴,在沒有非常好的?counterfactual sample?的情況下也能近似上述的約束呢?答案是學習一個 matching,這才是文章的關鍵。
因此接下來文章的邏輯就比較清晰了,作者將算法分成了兩部分,首先學習一個 match,然后再利用這個 match 近似上面的約束:
具體的實現過程是這樣的:
Initialization(構造 random match):首先我們對每一個類選擇一個基域(包含該類元素最多的類),對基類的所有數據點進行遍歷。對每個數據點,我們隨機的在剩下 K-1 個域中給他匹配標簽相同的元素,因此會構造出一個 (N',K) 大小的數據矩陣,這里 N' 即所有類的基域大小之和,K 是總共的域的數目。
Phase 1:采樣一個 batch 的數據 (B,K),對 batch 中的每個數據點最小化對比損失,和他具有相同 object 不同域的樣本作為正樣本,不同 object 樣本作為負樣本。
每 t 個 epoch 使用通過對比學習學到的 representation 更新一次我們的 match。首先還是要選基域,但是在基域選定后,我們不再隨機的在剩下域中挑選 sample,我們為基域中的該類的每個樣本在其他域中找 representation 距離最近的點作為正樣本。
在 Phase 1 結束時,我們根據學習到的最終表示的 距離更新匹配的數據矩陣。我們稱這些匹配為推論匹配。
Phase 2:我們使用下列損失函數,但是 match 使用我們第一階段學到的。網絡從頭開始訓練(第一階段學到的網絡只是用來做匹配而已)。但是第一階段學到的匹配可能不能包含所有的數據點,因此作者在每次訓練除了從數據矩陣采樣(B,K)的數據外,還通過隨機匹配再產生(B,K)的數據。
簡單看一下實驗效果,對 MNIST 類的任務,存在 perfect match,效果非常顯著。
對 PACS 這類任務不存在 perfect match,作者將 MatchDG 結合數據增強進一步提升(MDGHybrid),效果也是挺不錯的。
ICML 2021 Spotlight
論文標題:
Environment Inference for Invariant Learning
論文鏈接:
https://arxiv.org/abs/2010.07249
代碼鏈接:
https://github.com/ecreager/eiil
沒有 domain label 怎么做 OOD 泛化?這篇文章就回答了這樣一個有趣的問題。給出的答案也非常的 interesting:我們自己推斷 domain label 甚至能達到比使用真實域標簽更好的性能。
首先文章的 motivation 在于,無論是從隱私還是標簽的獲取來看,域標簽都是難以取得的。除此之外,在某些情況下,相關的信息或元數據(例如,人的注釋、用于拍攝醫療圖像的設備 ID、醫院或部門 ID 等)可能非常豐富,但目前還不清楚如何最好地基于這些信息指定環境。設計算法避免人工定義環境是這篇文章的出發點。
所以很直觀的,算法應該分成兩部走:
推斷環境標簽;
利用環境標簽學習域不變性特征。
文章的模型很有意思,我們先選擇一個已有的學習域不變特征的算法(模型 ),文中用了 IRM 和 GroupDRO。
在第一步推斷標簽的時候,我們選擇最違背域不變特征的標簽分配方式,分配標簽使得 IRM,GroupDRO 這些算法的分類性能最差。即固定住模型 ,然后優化 EI(environment inference EI)目標,估計標簽變量 最違背域不變特征。
固定住我們 inference 的標簽 ,優化 invariant learning(IL)目標來產生新模型 。
那么現在未知量就剩下了 EI, IL 這兩個目標如何構造。IL 其實就是 IRM,GroupDRO 的優化目標本身。對于 IRM 來說既是:
那么對于 EI 目標,差別只是在于 EI 的時候沒有現成的 enviroment label 可以用,也就是說傳統 IRM 的逐環境損失可以寫作如下形式:
其中 是環境自帶的 domain label,這時候我們沒有這個東西,因此將其替換為我們的概率分布 ,一個 soft 版本的 domain label。我們只需要優化這個概率分布,使得本輪固定的分類器 更差即可。
有趣的是,即使 domain label 是可用的,他也不比我們推斷出來的 label 好。EIIL 使用推斷出來的 label,比直接使用 IRM 好更多。
CVPR 2021 Oral
論文標題:
Reducing Domain Gap by Reducing Style Bias
論文鏈接:
https://arxiv.org/abs/1910.11645
代碼鏈接:
https://github.com/hyeonseobnam/sagnet
CNN 對圖像紋理這類的風格元素具有很強的歸納偏置,因此對域變化非常敏感。相反其對物體形狀這類真正和標簽相關的元素卻不敏感。本文提出了一種將 style和 content 分離開的簡單方法,可以作為一種新的 backbone。
文章結構非常簡單,一個 feature extractor 兩個 head。content-bias head 想要做的事是將 style 信息打亂,同時還確保分類結果正確,也就是讓這個 head 更關注于 content 信息。相反 style-bias head 將風格信息打亂,讓這個 head 更關注于 style 信息,與此同時一個對抗學習就可以讓 backbone 產生更少的 style-bias representation。
看到這里其實難點已經很明確了,如何將 style/content 信息打亂? 文章基于這樣一個假設,channel-wise 的均值和方差作為風格信息,spatial configuration 作為 style 信息,這樣一個假設已經被以往很多工作采用了,不過本文提出了一個更新的使用方式。首先我們來看如何打亂 style 信息。
先求一個 channel-wise 的均值和方差:
然后文中提出了一個 SR 模塊,通過將 和另一個隨機的 的風格信息進行插值,構造一個隨機的 ,然后通過 AdaIN 將 的風格信息替換成這個隨機的風格,這樣就完成了風格的 shuffle。
接下來我們只需要將這個通過 SR 模塊的 representation 喂給內容分類器 進行分類正常計算分類損失即可。
對于 style-bias head,我們反其道而行之,構造一個 CR 模塊:
同樣通過 CR 模塊的 representation 喂給風格分類器,然后風格分類器的優化就是利用 style 信息來預測標簽:
很直觀的,我們可以想到,插入一個 GRL 來訓練 backbone,使得 backbone 產生更少的 style 信息,文中采取了類似的策略,只不過不是插入 GRL,而是用了最大熵:
最大熵其實是具有很不錯的性質的,在我最近的一篇工作中我簡單的分析了這個類型的損失函數,他能起到風格信息和 representation 互信息最小化的作用。
https://arxiv.org/abs/2103.15890
具體如下面的式子, 即互信息, 即熵。
簡單看一下在 domain generalization 上的實驗結果:
文章選擇的 baseline 其實并不多,也沒有 resnet50 這種大型 backbone 的結果,但是從文中展示的內容來看,SagNet 相比于現有的大多數方法還是有一定優勢的。對我而言我覺得難得的是,它提供了一種 style/content 信息新的提取方式,以往的工作往往需要兩個 encoder 來提取 content/style 信息。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
?????稿件基本要求:
? 文章確系個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題
? PaperWeekly 尊重原作者署名權,并將為每篇被采納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
?????投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯系方式(微信),以便我們在稿件選用的第一時間聯系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的从2021年多篇顶会论文看OOD泛化新理论、新方法和新讨论的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 基金回调提前知道的技巧 一定要注意股市
- 下一篇: ACL 2021 | ConSERT:基