细节满满!理解对比学习和SimCSE,就看这6个知识点
???????PaperWeekly 原創 · 作者 |?海晨威
研究方向?|?自然語言處理??????
2020 年的 Moco 和 SimCLR 等,掀起了對比學習在 CV 領域的熱潮,2021 年的 SimCSE,則讓 NLP 也乘上了對比學習的東風。下面就嘗試用 QA 的形式挖掘其中一些細節知識點,去更好地理解對比學習和 SimCSE。
如何去理解對比學習,它和度量學習的差別是什么?
對比學習中一般選擇一個 batch 中的所有其他樣本作為負例,那如果負例中有很相似的樣本怎么辦?
infoNCE loss 如何去理解,和 CE loss有什么區別?
對比學習的 infoNCE loss 中的溫度常數的作用是什么?
SimCSE 中的 dropout mask 指的是什么,dropout rate 的大小影響的是什么?
SimCSE 無監督模式下的具體實現流程是怎樣的,標簽生成和 loss 計算如何實現?
如何去理解對比學習,它和度量學習的差別是什么?
對比學習的思想是去拉近相似的樣本,推開不相似的樣本,而目標是要從樣本中學習到一個好的語義表示空間。
論文 [1] 給出的 “Alignment and Uniformity on the Hypersphere”,就是一個非常好的去理解對比學習的角度。
好的對比學習系統應該具備兩個屬性:Alignment和Uniformity(參考上圖)。
所謂“Alignment”,指的是相似的例子,也就是正例,映射到單位超球面后,應該有接近的特征,也即是說,在超球面上距離比較近;
所謂“Uniformity”,指的是系統應該傾向在特征里保留盡可能多的信息,這等價于使得映射到單位超球面的特征,盡可能均勻地分布在球面上,分布得越均勻,意味著保留的信息越充分。分布均勻意味著兩兩有差異,也意味著各自保有獨有信息,這代表信息保留充分(參考自 [2])。
度量學習和對比學習的思想是一樣的,都是去拉近相似的樣本,推開不相似的樣本。但是對比學習是無監督或者自監督學習方法,而度量學習一般為有監督學習方法。而且對比學習在 loss 設計時,為單正例多負例的形式,因為是無監督,數據是充足的,也就可以找到無窮的負例,但如何構造有效正例才是重點。
而度量學習多為二元組或三元組的形式,如常見的 Triplet 形式(anchor,positive,negative),Hard Negative 的挖掘對最終效果有較大的影響。
對比學習中一般選擇一個 batch 中的所有其他樣本作為負例,那如果負例中有很相似的樣本怎么辦?
在無監督無標注的情況下,這樣的偽負例,其實是不可避免的,首先可以想到的方式是去擴大語料庫,去加大 batch size,以降低 batch 訓練中采樣到偽負例的概率,減少它的影響。
另外,神經網絡是有一定容錯能力的,像偽標簽方法就是一個很好的印證,但前提是錯誤標簽數據或偽負例占較小的比例。
PS:也確有人考慮研究過這個問題,可以參考論文 [3][4]。
infoNCE loss 如何去理解,和 CE loss 有什么區別?
infoNCE loss 全稱 info Noise Contrastive Estimation loss,對于一個 batch 中的樣本 i,它的 loss 為:
要注意的是,log 里面的分母疊加項是包括了分子項的。分子是正例對的相似度,分母是正例對+所有負例對的相似度,最小化 infoNCE loss,就是去最大化分子的同時最小化分母,也就是最大化正例對的相似度,最小化負例對的相似度。
上面公式直接看可能沒那么清晰,可以把負號放進去,分子分母倒過來化簡一下就會很明了了。
CE loss,Cross Entropy loss,在輸入 p 是 softmax 的輸出時:
在分類場景下,真實標簽 y 一般為 one-hot 的形式,因此,CE loss 可以簡化成(i 位置對應標簽 1):
看的出來,info NCE loss 和在一定條件下簡化后的 CE loss 是非常相似的,但有一個區別要注意的是:
infoNCE loss 中的 K 是 batch 的大小,是可變的,是第 i 個樣本要和 batch 中的每個樣本計算相似度,而 batch 里的每一個樣本都會如此計算,因此上面公式只是樣本 i 的 loss。
CE loss 中的 K 是分類類別數的大小,任務確定時是不變的,i 位置對應標簽為 1 的位置。不過實際上,infoNCE loss 就是直接可以用 CE loss 去計算的。
注:1)info NCE loss 不同的實現方式下,它的計算方式和 K 的含義可能會有差異;2)info NCE loss 是基于 NCE loss 的,對公式推導感興趣的可以參考 [5]。
對比學習的 infoNCE loss 中的溫度常數 t 的作用是什么?
論文 [6] 給出了非常細致的分析,知乎博客 [7] 則對論文 [6] 做了細致的解讀,這里摘錄它的要點部分:
溫度系數的作用是調節對困難樣本的關注程度:越小的溫度系數越關注于將本樣本和最相似的困難樣本分開,去得到更均勻的表示。然而困難樣本往往是與本樣本相似程度較高的,很多困難負樣本其實是潛在的正樣本,過分強迫與困難樣本分開會破壞學到的潛在語義結構,因此,溫度系數不能過小。
考慮兩個極端情況,溫度系數趨向于 0 時,對比損失退化為只關注最困難的負樣本的損失函數;當溫度系數趨向于無窮大時,對比損失對所有負樣本都一視同仁,失去了困難樣本關注的特性。
還有一個角度:
可以把不同的負樣本想像成同極點電荷在不同距離處的受力情況,距離越近的點電荷受到的庫倫斥力更大,而距離越遠的點電荷受到的斥力越小。
對比損失中,越近的負例受到的斥力越大,具體的表現就是對應的負梯度值越大 [4]。這種性質更有利于形成在超球面均勻分布的特征。
對照著公式去理解:
當溫度系數很小時,越相似也即越困難的負例,對應的 就會越大,在分母疊加項中所占的比重就會越大,對整體 loss 的影響就會越大,具體的表現就是對應的負梯度值越大 [7]。
當然,這僅僅是提供了一種定性的認識,定量的認識和推導可以參見博客 [7]。
SimCSE 中的 dropout mask 指的是什么,dropout rate 的大小影響的是什么?
一般而言的 mask 是對 token 級別的 mask,比如說 BERT MLM 中的 mask,batch 訓練時對 padding 位的 mask 等。
SimCSE 中的 dropout mask,對于 BERT 模型本身,是一種網絡模型的隨機,是對網絡參數 W 的 mask,起到防止過擬合的作用。
而 SimCSE 巧妙的把它作為了一種 noise,起到數據增強的作用,因為同一句話,經過帶 dropout 的模型兩次,得到的句向量是不一樣的,但是因為是相同的句子輸入,最后句向量的語義期望是相同的,因此作為正例對,讓模型去拉近它們之間的距離。
在實現上,因為一個 batch 中的任意兩個樣本,經歷的 dropout mask 都是不一樣的,因此,一個句子過兩次 dropout,SimCSE 源碼中實際上是在一個 batch 中實現的,即 [a,a,b,b...] 作為一個 batch 去輸入。
dropout rate 大小的影響,可以理解為,這個概率會對應有 dropout 的句向量相對無 dropout 句向量,在整個單位超球體中偏移的程度,因為 BERT 是多層的結構,每一層都會有 dropout,這些 noise 的累積,會讓句向量在每個維度上都會有偏移的,只是 p 較小的情況下,兩個向量在空間中仍較為接近,如論文所說,“keeps a steady alignment”,保證了一個穩定的對齊性。
SimCSE 無監督模式下的具體實現流程是怎樣的,標簽生成和 loss 計算如何實現?
這里用一個簡單的例子和 Pytorch 代碼來說明:
前向句子 embedding 計算:
假設初始輸入一個句子集 sents = [a,b],每一句要過兩次 BERT,因此復制成? sents = [a,a,b,b]。
sents 以 batch 的形式過 BERT 等語言模型得到句向量:batch_emb = [a1,a2,b1,b2]。
batch 標簽生成:
標簽為 1 的地方是相同句子不同 embedding 對應的位置。
pytorch 中的 CE_loss,要使用一維的數字標簽,上面的 one-hot 標簽可轉換成:[1,0,3,2]。
可以把 label 拆成兩個部分:奇數部分 [1,3...] 和偶數部分 [0,2...],交替的每個奇數在偶數前面。因此實際生成的時候,可以分別生成兩個部分再 concat 并 reshape 成一維。
pytorch 中 label 的生成代碼如下:
#?構造標簽 batch_size?=?batch_emb.size(0) y_true?=?torch.cat([torch.arange(1,batch_size,step=2,dtype=torch.long).unsqueeze(1),torch.arange(0,batch_size,step=2,dtype=torch.long).unsqueeze(1)],dim=1).reshape([batch_size,])score 和 loss計算:
batch_emb 會先 norm,再計算任意兩個向量之間的點積,得到向量間的余弦相似度,維度是:[batch_size, batch_size]。
但是對角線的位置,也就是自身的余弦相似度,需要 mask 掉,因為它肯定是 1,是不產生 loss 的。
然后,要除以溫度系數,再進行 loss 的計算,loss_func 采用 CE loss,注意 CE loss 中是自帶 softmax 計算的。
????#?計算score和lossnorm_emb?=?F.normalize(batch_emb,?dim=1,?p=2)sim_score?=?torch.matmul(norm_emb,?norm_emb.transpose(0,1))sim_score?=?sim_score?-?torch.eye(batch_size)?*?1e12sim_score?=?sim_score?*?20????????#?溫度系數為?0.05,也就是乘以20loss?=?loss_func(sim_score,?y_true)完整代碼:
loss_func?=?nn.CrossEntropyLoss() def?simcse_loss(batch_emb):"""用于無監督SimCSE訓練的loss"""#?構造標簽batch_size?=?batch_emb.size(0)y_true?=?torch.cat([torch.arange(1,?batch_size,?step=2,?dtype=torch.long).unsqueeze(1),torch.arange(0,?batch_size,?step=2,?dtype=torch.long).unsqueeze(1)],dim=1).reshape([batch_size,])#?計算score和lossnorm_emb?=?F.normalize(batch_emb,?dim=1,?p=2)sim_score?=?torch.matmul(norm_emb,?norm_emb.transpose(0,1))sim_score?=?sim_score?-?torch.eye(batch_size)?*?1e12sim_score?=?sim_score?*?20loss?=?loss_func(sim_score,?y_true)return?loss注:看過論文源碼 [8] 的同學可能會發現,這個和論文源碼中的實現方式不一樣,論文源碼是為了兼容無監督 SimCSE 和有監督 SimCSE,并兼容有 hard negative 的三句輸入設計的,因此實現上有差異。
看過蘇神源碼 [9] 的同學也會發現,構造標簽的地方不一樣,那是因為 keras 的 CE loss 用的是 one-hot 標簽,pytorch 用的是數字標簽,但本質一樣。
參考文獻
[1] Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere
[2] https://zhuanlan.zhihu.com/p/367290573
[3] Debiased Contrastive Learning
[4] ADACLR: Adaptive Contrastive Learning Of Representation By Nearest Positive Expansion
[5] https://zhuanlan.zhihu.com/p/334772391
[6] Understanding the Behaviour of Contrastive Loss
[7] https://zhuanlan.zhihu.com/p/357071960
[8] https://github.com/princeton-nlp/SimCSE
[9] https://github.com/bojone/SimCSE
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
?????稿件基本要求:
? 文章確系個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題
? PaperWeekly 尊重原作者署名權,并將為每篇被采納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
?????投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯系方式(微信),以便我們在稿件選用的第一時間聯系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的细节满满!理解对比学习和SimCSE,就看这6个知识点的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 10两辆车发生碰撞责任怎么划分?
- 下一篇: 鸡西至勃利汽车坐几个小时?