TransE
網(wǎng)絡(luò)上已經(jīng)存在了大量知識庫(KBs),比如OpenCyc,WordNet,Freebase,Dbpedia等等。這些知識庫是為了各種各樣的目的建立的,因此很難用到其他系統(tǒng)上面。為了發(fā)揮知識庫的圖(graph)性,也為了得到統(tǒng)計學(xué)習(xí)(包括機器學(xué)習(xí)和深度學(xué)習(xí))的優(yōu)勢,我們需要將知識庫嵌入(embedding)到一個低維空間里(比如10、20、50維)。我們都知道,獲得了向量后,就可以運用各種數(shù)學(xué)工具進行分析。深度學(xué)習(xí)的輸入也是向量。(考慮一下,word2vec,我們訓(xùn)練出一個向量后,可以做好多事情,深度學(xué)習(xí)的輸入也往往是一個矩陣。
TransE的直觀含義,就是TransE基于實體和關(guān)系的分布式向量表示,將每個三元組實例(head,relation,tail)中的關(guān)系relation看做從實體head到實體tail的翻譯。通過不斷調(diào)整h、r和t(head、relation和tail的向量),使(h + r) 盡可能與 t 相等,即 h + r = t。
以前有很多種訓(xùn)練三元組的方法,但是參數(shù)過多,以至于模型過于復(fù)雜難以理解(作者表達的意思就是,我們的工作效果和你們一樣,但我們的簡單易擴展)。(ps:作者以前也做過類似的工作,叫做Structured Embeddings,簡稱SE,只是將實體轉(zhuǎn)為向量,關(guān)系是一個矩陣,利用矩陣的不可逆性反映關(guān)系的不可逆性。距離表達公式是1-norm)。也就是:
d(h,r,t)=∥h+r?t∥22d(h,r,t) = \left \| h+r-t \right \|_2^2 d(h,r,t)=∥h+r?t∥22?
TransE的訓(xùn)練
這里注意,entity需要在每次更新前進行歸一化,這是通過人為增加embedding的norm來防止Loss在訓(xùn)練過程中極小化。
TransE的Loss function (Hinge Loss Function) 為:
L=∑(h,l,t)∈S∑(h′,l,t′)∈S′(h,l,t)[γ+d(h+l,t)?d(h′+l,t′)]+\mathcal{L} = \sum_{(h,l,t) \in S }\sum_{(h^\prime,l,t^\prime) \in S^\prime(h,l,t)} [\gamma + d(h+l, t)-d(h^\prime+l, t^\prime)]_+ L=(h,l,t)∈S∑?(h′,l,t′)∈S′(h,l,t)∑?[γ+d(h+l,t)?d(h′+l,t′)]+?
其中[x]+[x]_+[x]+?表示x最小值為0, 即[x]+=max(0,x)[x]_+=max(0,x)[x]+?=max(0,x),S′S^\primeS′表示負例的集合。
直觀上,我們要前面的項(原三元組)變小(positive),后面的項(打碎的三元組corrupt triple)變大(negative)。就跟喂小狗一樣,它做對了,就給骨頭吃;做錯了,就打兩下。前面的項是對的(來自于訓(xùn)練集),后面的項是錯的(我們隨機生成的)。不同時打碎主體和客體,隨機挑選一個打碎,另一個保持不變,這樣才能夠有對照性。
合頁損失函數(shù)(hinge loss function),這種訓(xùn)練方法叫做margin-based ranking criterion。是不是聽起來很熟悉?對的,就是來自SVM。支持向量機也是如此,要將正和負盡可能分開,找出最大距離的支持向量。同理,TransE也是如此,我們盡可能將對的和錯的分開。margin值一般設(shè)為1了
L2L_2L2?范數(shù)形式如下:
∣∣x∣∣2=(∣x1∣2+∣x2∣2+...+∣xn∣2)||x||_2=\sqrt{(|x_1|^2+|x_2|^2+...+|x_n|^2)} ∣∣x∣∣2?=(∣x1?∣2+∣x2?∣2+...+∣xn?∣2)?
因此以norm=L2norm = L2norm=L2范數(shù)為例,求解正確三元組的hhh的相對于hinge loss function的梯度:
?loss?h=?[γ+(h+r?t)2?(h′+r?t′)2]?h={2(h+r?t),ifγ+(h+r?t)2?(h′+r?t′)2≥00,ifγ+(h+r?t)2?(h′+r?t′)2<0\frac{\partial loss}{\partial h}= \frac {\partial [\gamma + (h+r-t)^2- (h^\prime + r -t^\prime)^2]}{\partial h}= \begin{cases} 2(h+r-t) , if \gamma + (h+r-t)^2- (h^\prime + r -t^\prime)^2 \ge 0 \\ 0, if \gamma + (h+r-t)^2- (h^\prime + r -t^\prime)^2 < 0 \end{cases} ?h?loss?=?h?[γ+(h+r?t)2?(h′+r?t′)2]?={2(h+r?t),ifγ+(h+r?t)2?(h′+r?t′)2≥00,ifγ+(h+r?t)2?(h′+r?t′)2<0?
L1L_1L1?范數(shù)形式如下:
∣∣x∣∣1=∣x1∣+∣x2∣+...+∣xn∣||x||_1=|x_1|+|x_2|+...+|x_n| ∣∣x∣∣1?=∣x1?∣+∣x2?∣+...+∣xn?∣
而L1范數(shù)的梯度則以[1,1,-1,1…]形式出現(xiàn)。
對于模型中Margin的個人理解如下:margin 的作用相當(dāng)于是一個正確triple與錯誤triple之前的間隔修正,margin越大,則兩個triple之前被修正的間隔就越大,則對于詞向量的修正就越嚴(yán)格。
開始閱讀論文的時候在糾結(jié)一個問題,”這個模型的參數(shù)是什么?如何更新?“。后來通過讀原文發(fā)現(xiàn)其實文章后續(xù)中有說明,第3章提到了參數(shù)的總量為O(nek+nrk)O(n_ek+n_rk)O(ne?k+nr?k),也就是說Loss更新的參數(shù),是所有entities和relations的Embedding數(shù)據(jù),每一次SGD更新的參數(shù)就是一個Batch中所有embedding的值。TransE里面SGD和一般機器學(xué)習(xí)方法或者深度學(xué)習(xí)中SGD中的參數(shù)還是有些區(qū)別的。
關(guān)于參數(shù)的更新:我們使用的是隨機梯度下降(Stochastic Gradient Descent,SGD)訓(xùn)練方法。SGD不用對所有的和求梯度,而是對一個batch求梯度之后就立即更新theta值。
對于數(shù)據(jù)集大的情況下,有速度。但是每一次更新都是針對這一個batch里的三元組的向量更新的,也就是意味著,一次更新最多更新(3+2)*batch_size*d 個參數(shù)(設(shè)一個batch的長度為batch_size)。并不是把所有的theta值都更新了, 或者說不用更新整個( |E| + |R| ) * d 矩陣,只需要更新sample里抽出來的batch里的向量即可。為什么可以這樣呢(也就是為什么可以不用把參數(shù)全更新了,而是只更新一部分)?因為參數(shù)之間并沒有依賴(或者說沖突conflict),對于此,可以參考論文 Hogwild!: A Lock-Free Approach to Parallelizing Stochastic。
另外,距離公式d(h+r?t)d(h+r-t)d(h+r?t)可以取L1L_1L1?或者L2L_2L2?范數(shù),對于L1L_1L1?范數(shù),ddd是求絕對值結(jié)果d(h+r,t)=∣h+r?t∣d(h+r,t)=|h+r-t|d(h+r,t)=∣h+r?t∣;而對于L2L_2L2?范數(shù),d(h+r,t)=(h+r?t)2d(h+r,t)=(h+r-t)^2d(h+r,t)=(h+r?t)2。其中要注意的是,L1L_1L1?范數(shù)在x=0x=0x=0處不可導(dǎo),所以需要使用次微分概念。另一方面,Loss function 希望達到的理想情況是,正確的triple的d(h+r,t)d(h+r,t)d(h+r,t)盡可能小, 而錯誤triple的d(h′+r,t′)d(h^\prime+r, t^\prime)d(h′+r,t′)盡可能大,這樣才能讓總體的losslossloss趨向于0。因此,在SDG的update過程中,正例中hhh和rrr逐漸減小,但ttt要逐漸增大;反例中h′h^\primeh′和rrr要增大,但t′t^\primet′要減小。
測試環(huán)節(jié)中,可將測試集分為Raw及Filter兩種情況,Filter是指過濾corrupted triplets中在training, validation,test三個數(shù)據(jù)集中出現(xiàn)的正確的三元組。這是因為只是圖譜中存在1對N的情況,當(dāng)在測試一個三元組的時,用其他實體去替換頭實體或者尾實體,這個新生成的反例corrupted triple確可能是一個正確triple,因此當(dāng)遇見這種情況時,將這個triple從測試中過濾掉,從而得到Filter測試結(jié)果。
歸一化公式的分母是向量的平方和再開方;而對于距離公式,是向量的平方和(沒有開方)。公式的錯誤書寫,會引起收斂的失敗。
對于每一次迭代,每一次的歸一化約束(constraint)實體長度為1(減少任意度量(scaling freedoms(SE)),使得收斂有效(避免 trivially minimize(transE)),但對關(guān)系不做此要求。(然而我自己試驗的結(jié)果是,歸一化關(guān)系,會使精度加大和收斂加強)
代碼設(shè)計
有一點需要注意,在進行Corrupted_triple的entity替換中時,最開始的設(shè)想是保證替換后的triple不能是原有triple_List中任意一個。這個設(shè)計思路并沒有錯,但在在代碼實現(xiàn)過程中會導(dǎo)致代碼的運行速度變得奇慢無比。這是因為,FB15k數(shù)據(jù)集中有超過48萬個Triple,每次都要遍歷整個List,在對大數(shù)據(jù)集進行訓(xùn)練時,這個環(huán)節(jié)會消耗大量的資源,因此在效率和性能中間進行平衡,最終放棄了這個步驟。
另一方面,通過SGD更新詞向量時,會同時更新Correct triple和Corrupted triple。這兩個triple中其實只有一個實體不同,因此另一個實體就需要更新兩次,需要使用同一個,不然后一次更新的結(jié)果會將前一次的更新結(jié)果覆蓋掉。下面是關(guān)于SGD 參數(shù)更新的代碼:
correct_copy_head -= self.learning_rate * correct_gradient relation_copy -= self.learning_rate * correct_gradient correct_copy_tail -= -1 * self.learning_rate * correct_gradientrelation_copy -= -1 * self.learning_rate * corrupted_gradient if correct_sample[0] == corrupted_sample[0]:# if corrupted_triples replaces the tail entity, the head entity's embedding need to be updated twicecorrect_copy_head -= -1 * self.learning_rate * corrupted_gradientcorrupted_copy_tail -= self.learning_rate * corrupted_gradient elif correct_sample[1] == corrupted_sample[1]:# if corrupted_triples replaces the head entity, the tail entity's embedding need to be updated twicecorrupted_copy_head -= -1 * self.learning_rate * corrupted_gradientcorrect_copy_tail -= self.learning_rate * corrupted_gradient# normalising these new embedding vector, instead of normalising all the embedding together copy_entity[correct_sample[0]] = self.normalization(correct_copy_head) copy_entity[correct_sample[1]] = self.normalization(correct_copy_tail) if correct_sample[0] == corrupted_sample[0]:# if corrupted_triples replace the tail entity, update the tail entity's embeddingcopy_entity[corrupted_sample[1]] = self.normalization(corrupted_copy_tail) elif correct_sample[1] == corrupted_sample[1]:# if corrupted_triples replace the head entity, update the head entity's embeddingcopy_entity[corrupted_sample[0]] = self.normalization(corrupted_copy_head) # the paper mentions that the relation's embedding doesn't need to be normalised copy_relation[correct_sample[2]] = relation_copy對FB15k數(shù)據(jù)集進行100個epoch的訓(xùn)練,超參數(shù)設(shè)置為dimension = 50, margin = 1.0, norm = L1,總耗時約為2小時, 模型后期訓(xùn)練一次epoch的total loss穩(wěn)定在14000左右。
指標(biāo)
Mean rank
對于測試集的每個三元組,以預(yù)測tail實體為例,我們將(h,r,t)(h,r,t)(h,r,t)中的t用知識圖譜中的每個實體來代替,然后通過distance(h,r,t)distance(h, r, t)distance(h,r,t)函數(shù)來計算距離,這樣我們可以得到一系列的距離,之后按照升序?qū)⑦@些分數(shù)排列。
distance(h,r,t)distance(h, r, t)distance(h,r,t)函數(shù)值是越小越好,那么在上個排列中,排的越前越好。
現(xiàn)在重點來了,我們?nèi)タ疵總€三元組中正確答案也就是真實的t到底能在上述序列中排多少位,比如說t1排100,t2排200,t3排60…,之后對這些排名求平均,mean rank就得到了。
Hit@10
還是按照上述進行函數(shù)值排列,然后去看每個三元組正確答案是否排在序列的前十,如果在的話就計數(shù)+1
最終 排在前十的個數(shù)/總個數(shù) 就是Hit@10
結(jié)論
經(jīng)過transE建模后,在測試集的13584個實體,961個關(guān)系的 59071個三元組中,測試結(jié)果如下:
mean rank: 353.06935721419984 hit@3: 0.12181950534103028 hit@10: 0.2754989758087725一方面可以看出訓(xùn)練后的結(jié)果是有效的,但不是十分優(yōu)秀,可能與transE模型的局限性有關(guān),transE只能處理一對一的關(guān)系,不適合一對多/多對一關(guān)系。
雖然TransE模型的參數(shù)較少,計算的復(fù)雜度顯著降低,并且在大規(guī)模稀疏知識庫上也同樣具有較好的性能與可擴展性。但是TransE 模型不能用在處理復(fù)雜關(guān)系上 ,原因如下:以一對多為例,對于給定的事實,以姜文拍的民國三部曲電影為例,即《讓子彈飛》、《一步之遙》和《邪不壓正》。可以得到三個事實三元組即(姜文,導(dǎo)演,讓子彈飛)、(姜文,導(dǎo)演,一步之遙)和(姜文,導(dǎo)演,邪不壓正)。按照上面對于TransE模型的介紹,可以得到,讓子彈飛≈一步之遙≈邪不壓正,但實際上這三部電影是不同的實體,應(yīng)該用不同的向量來表示。多對一和多對多也類似。
總結(jié)
- 上一篇: 拔牙大概要多少钱一颗
- 下一篇: 【转】自然语言系列学习之表示学习与知识获