多任务学习漫谈:以损失之名
?PaperWeekly 原創 ·?作者 | 蘇劍林
單位 | 追一科技
研究方向 | NLP、神經網絡
能提升模型性能的方法有很多,多任務學習(Multi-Task Learning)也是其中一種。簡單來說,多任務學習是希望將多個相關的任務共同訓練,希望不同任務之間能夠相互補充和促進,從而獲得單任務上更好的效果(準確率、魯棒性等)。然而,多任務學習并不是所有任務堆起來就能生效那么簡單,如何平衡每個任務的訓練,使得各個任務都盡量獲得有益的提升,依然是值得研究的課題。
最近,筆者機緣巧合之下,也進行了一些多任務學習的嘗試,借機也學習了相關內容,在此挑部分結果與大家交流和討論。
加權求和
從損失函數的層面看,多任務學習就是有多個損失函數 ,一般情況下它們有大量的共享參數、少量的獨立參數,而我們的目標是讓每個損失函數都盡可能地小。為此,我們引入權重,通過加權求和的方式將它轉化為如下損失函數的單任務學習:
在這個視角下,多任務學習的主要難點就是如何確定各個 了。
初始狀態
按道理,在沒有任務先驗和偏見的情況下,最自然的選擇就是平等對待每個任務,即。然而,事實上每個任務可能有很大差別,比如不同類別數的分類任務混合、分類與回歸任務混合、分類與生成任務混合等等,從物理的角度看,每個損失函數的量綱和量級都不一樣,直接相加是沒有意義的。
如果我們將每個損失函數看成具有不同量綱的物理量,那么從“無量綱化”的思想出發,我們可以用損失函數的初始值倒數作為權重,即
其中 表示任務 的初始損失值。該式關于每個 是“齊次”的,所以它的一個明顯優點是縮放不變性,即如果讓任務 的損失乘上一個常數,那么結果不會變化。此外,由于每個損失都除以了自身的初始值,較大的損失會縮小,較小的損失會放大,從而使得每個損失能夠大致得到平衡。
那么,怎么估計 呢?最直接的方法當然是直接拿幾個 batch 的數據來估算一下。除此之外,我們可以基于一些假設得到一個理論值。比如,在主流的初始化之下,我們可以認為初始模型(加激活函數之前)的輸出是一個零向量,如果加上 softmax 則是均勻分布,那么對于一個“ 分類+交叉熵”問題,它的初始損失就是 ;對于“回歸+ L2 損失”問題,則可以用零向量來估計初始損失,即 , 是訓練集的全體標簽。
先驗狀態
用初始損失的一個問題是初始狀態不一定能很好地反應當前任務的學習難度,更好的方案應該是將“初始狀態”改為“先驗狀態”:
比如,如果 分類中每個類的頻率分別是 (先驗分布),那么雖然初始狀態的預測分布為均勻分布,但我們可以合理地認為模型可以很容易學會將每個樣本的結果都預測為 ,此時模型的損失為熵
某種意義上來說,“先驗分布”比“初始分布”更能體現出“初始”的本質,它是“就算模型啥都學不會,也知道按照先驗分布來隨機出結果”的體現,所以此時的損失值更能代表當前任務的初始難度,因此用 代替 應該更加合理;類似地,對于“回歸+L2損失”問題,它的先驗結果應該是全體標簽的期望 ,所以我們用 代替 ,有望取得更合理的結果。
動態調節
不管是用初始狀態的式(2)還是先驗狀態的式(3),它們的任務權重在確定之后就保持不變了,并且它們確定權重的方法不依賴于學習過程。然而,盡管我們可以通過先驗分布等信息簡單感知一下學習難度,但究竟有多難其實要真正去學習才知道,所以更合理的方案應該是根據訓練進程動態地調整權重。
實時狀態
縱觀前文,式(2)和式(3)的核心思想都是用損失值的倒數來作為任務權重,那么能不能干脆用“實時”的損失值倒數來實現動態調整權重?即:
這里的 是 的簡寫。在這個方案中,每個任務的損失函數都被調整恒為 1,所以不管是量綱還是量級上都是一致的。由于 算子的存在,雖然損失恒為 1,但梯度并非恒為 0:
簡單來說就是加上 算子后,它的值不變,但是導數為 0,所以最終結果就是以動態權重 來實時調整了梯度的比例。很多“民間實驗”表明,式(5)確實在多數情況下都可以作為一個相當不錯的 baseline。
等價梯度
我們可以從另一個角度來看該方案。從式(6)我們可以得到:
因此從梯度上看,式(5)與 沒有實質區別,而我們進一步有:
由于 是單調遞增的,所以式(5)與下式在梯度方向上是一致:
廣義平均
顯然,上式正是 的“幾何平均”,而如果我們約定 恒等于 ,那么原始的式(1)就是 的“代數平均”。也就是說,我們發現這一系列的推導其實隱藏了從代數平均到幾何平均的轉變,這啟發我們或許可以考慮“廣義平均”:
也就是將每個損失函數算 次方后再平均最后再開 次方,這里的 可以是任意實數,代數平均對應 ,而幾何平均對應 (需要取極限)。可以證明, 是關于 的單調遞增函數,并且有:
這就意味著,當 增大時,模型愈發關心損失中的最大值,反之則更關心損失中的最小值。這樣一來,雖然依然存在超參數 要調整,但是相比于原始的式(1),超參數的個數已經從 個變為只有 1 個,簡化了調參過程。
平移不變
重新回顧式(2)、式(3)和式(5),它們都是通過每個任務損失除以自身的某個狀態來調節權重,并且獲得了縮放不變性。然而,盡管它們都具備了縮放不變性,但卻失去了更基本的“平移不變性”,也就是說,如果每個損失都加上一個常數,(2)、式(3)和式(5)的梯度方向是有可能改變的,這對于優化來說并不是一個好消息,因為原則上來說常數沒有帶來任何有意義的信息,優化結果不應該隨之改變。
理想目標
一方面,我們用損失函數(的某個狀態)的倒數作為當前任務的權重,但損失函數的導數不具備平移不變性;另一方面,損失函數可以理解為當前模型與目標狀態的距離,而梯度下降本質上是在尋找梯度為 0 的點,所以梯度的模長其實也能起到類似作用,因此我們可以用梯度的模長來替換掉損失函數,從而將式(5)變成:
跟損失函數的一個明顯區別是,梯度模長顯然具備平移不變性,并且分子分母關于 依然是齊次的,所以上式還保留了縮放不變性。因此,這是一個能同時具備平移和縮放不變性的理想目標。
梯度歸一
對式(12)求梯度,我們得到:
可以看到,式(12)本質上是將每個任務損失的梯度進行歸一化后再把梯度累加起來。它同時也告訴了我們一種實現方案,即可以讓每個任務依次訓練,每次只訓練一個任務,然后將每個任務的梯度歸一化后累積起來再更新,這樣就免除了在定義損失函數的時候就要算梯度的麻煩了。
關于梯度歸一化,筆者能找到相關工作是《GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks》[1],它本質上是式(2)和式(13)的混合,里邊也包含了對梯度模長重新標定的思想,但卻要通過額外的優化來確定任務權重,個人認為顯得繁瑣和冗余了。
本文小結
在損失函數的視角下,多任務學習的關鍵問題是如何調節每個任務的權重來平衡各自的損失,本文從縮放不變和平移不變兩個角度介紹了一些參考做法,并補充了“廣義平均”的概念,將多個任務的權重調節轉化為單個參數的調節問題,可以簡化調參難度。
參考文獻
[1]?https://arxiv.org/abs/1711.02257
特別鳴謝
感謝 TCCI 天橋腦科學研究院對于 PaperWeekly 的支持。TCCI 關注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝?稿件基本要求:
? 文章確系個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題
? PaperWeekly 尊重原作者署名權,并將為每篇被采納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯系方式(微信),以便我們在稿件選用的第一時間聯系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
·
總結
以上是生活随笔為你收集整理的多任务学习漫谈:以损失之名的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 当特种兵需要什么条件?
- 下一篇: 坐p22路去黄陂一中,在哪里下车?