Knowledge Graph Embedding: TransE Algorithm Principles and Code Walkthrough
Contents
KGE
TransE
TransE Code Walkthrough
KGE
In a knowledge graph, discrete symbolic knowledge cannot be used directly for semantic computation. To let computers compute over knowledge and alleviate data sparsity, the entities and relations of a knowledge graph can be mapped into a low-dimensional continuous vector space. This family of methods is called Knowledge Graph Embedding (KGE).
TransE
Inspired by the translation invariance observed in word vectors, TransE interprets a relation's vector as a translation between the head and tail entity vectors. The algorithm is simple and efficient, and training captures a degree of semantic information. The basic idea: if a triple (h, l, t) holds, the corresponding vectors should satisfy h + l ≈ t. For example:
vec(Rome) + vec(is-capital-of) ≈ vec(Italy)
vec(Paris) + vec(is-capital-of) ≈ vec(France)
[Figure: TransE translation distance]
Based on this, missing triples such as (Beijing, is-capital-of, ?), (Beijing, ?, China), and (?, is-capital-of, China) can be completed, i.e., link prediction.
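The translation idea can be checked numerically with toy vectors. The 2-d embeddings below are made up purely for illustration; real TransE embeddings are learned, not hand-picked:

```python
import numpy as np

# Hypothetical 2-d embeddings chosen so that head + relation ≈ tail.
rome = np.array([1.0, 0.0])
italy = np.array([1.2, 1.0])
is_capital_of = np.array([0.2, 1.0])

# A true triple yields a small translation distance ||h + l - t|| ...
assert np.linalg.norm(rome + is_capital_of - italy) < 0.01

# ... while a corrupted tail yields a larger one.
france = np.array([-2.0, 3.0])
assert np.linalg.norm(rome + is_capital_of - france) > 1.0
```

Link prediction then amounts to ranking candidate entities by this distance and picking the closest one.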
TransE was the earliest translation-based model. Later variants such as TransD, TransR, TransH, and TransA follow the same recipe, mainly improving and extending TransE.
Advantages:
Alleviates the data sparsity problem and improves the efficiency of knowledge computation.
Captures inference features automatically, with no hand-crafted feature design.
Simple algorithm, few parameters to learn, low computational complexity.
Disadvantages:
Cannot effectively model complex relations such as one-to-many, many-to-one, many-to-many, and reflexive relations.
Considers only one-hop relations, ignoring long-range latent relations.
The embedding model does not converge quickly.
Pseudocode:
Input: training set S, entity set E, relation set L, margin γ, embedding dimension k
1: initialize  l ← uniform(−6/√k, 6/√k) for each relation l ∈ L
2:             l ← l / ‖l‖ for each relation l ∈ L
3:             e ← uniform(−6/√k, 6/√k) for each entity e ∈ E
4: loop:
5:     e ← e / ‖e‖ for each entity e ∈ E
6:     sample a minibatch S_batch of b triples from the training set S
7:     initialize the set of triple pairs T_batch as an empty list
8:     for each (h, l, t) ∈ S_batch do
9:         build a negative sample (h′, l, t) or (h, l, t′) by replacing the head or tail entity of the correct triple
10:        put the positive triple and the negative triple into T_batch
11:    end for
12:    update the entity and relation vectors by gradient descent on Σ [γ + d(h + l, t) − d(h′ + l, t′)]₊
13: end loop
TransE Code Walkthrough
1. Loading the data
Pass in the file paths of the three data files: the training set, the entity set E, and the relation set L.
Returns three lists: entities, relations, and triples (entities and relations are represented by their ids).
```python
import codecs
import copy
import random
import time

import numpy as np


def dataloader(file1, file2, file3):
    print("load file...")
    entity = []
    relation = []
    entities2id = {}
    relations2id = {}
    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            entities2id[line[0]] = line[1]
            entity.append(line[1])
        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            relations2id[line[0]] = line[1]
            relation.append(line[1])

    triple_list = []
    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue
            h_ = entities2id[triple[0]]
            r_ = relations2id[triple[1]]
            t_ = entities2id[triple[2]]
            triple_list.append([h_, r_, t_])

    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity), len(relation), len(triple_list)))
    return entity, relation, triple_list
```

2. Constructor parameters
Pass in the entity id list entity, the relation id list relation, the triple list triple_list, the embedding dimension embedding_dim=50, the learning rate lr=0.01, margin (the separation enforced between positive and negative triples), norm (which norm to use for the score), and loss (the accumulated loss value).
```python
class TransE:
    def __init__(self, entity, relation, triple_list,
                 embedding_dim=50, lr=0.01, margin=1.0, norm=1):
        self.entities = entity
        self.relations = relation
        self.triples = triple_list
        self.dimension = embedding_dim
        self.learning_rate = lr
        self.margin = margin
        self.norm = norm
        self.loss = 0.0
```

3. Initialization
This corresponds to steps 1-3 of the pseudocode.
It turns the entity id list and relation id list into two dictionaries: {entity id: entity vector} and {relation id: relation vector}.
```python
class TransE:
    def data_initialise(self):
        entityVectorList = {}
        relationVectorList = {}
        for entity in self.entities:
            entity_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension),
                                              6.0 / np.sqrt(self.dimension),
                                              self.dimension)
            entityVectorList[entity] = entity_vector

        for relation in self.relations:
            relation_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension),
                                                6.0 / np.sqrt(self.dimension),
                                                self.dimension)
            relation_vector = self.normalization(relation_vector)
            relationVectorList[relation] = relation_vector

        self.entities = entityVectorList
        self.relations = relationVectorList

    def normalization(self, vector):
        return vector / np.linalg.norm(vector)
```

4. The training loop
This corresponds to steps 4-13 of the pseudocode.
nbatches=100 splits the dataset into 100 batches trained in sequence; each batch holds batch_size samples. epochs=1 is the number of full passes over all 100 batches.
First, the entity vectors are normalized.
For each batch, batch_size triples are sampled at random as S_batch, i.e. batch_samples in the code.
The set of triple pairs is initialized as an empty list.
For each sample in batch_samples, the head or tail entity is randomly replaced to build a negative triple.
Here, while corrupted_sample[0] == sample[0] filters out the positive triple itself, ensuring that the entity sampled from the entity set is not the original entity. Strictly speaking, one should use while corrupted_sample in self.triples, which also guards against the case where the sampled entity h2 differs from the original entity h1 yet the corrupted triple is still a positive one (i.e., both (h1, l, t) and (h2, l, t) appear in the triple list and both hold). But that check scans the whole triple list and makes training roughly 10x slower, so it is simplified here.
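The stricter filter need not be slow: storing the training triples once in a set of tuples makes the membership test O(1) instead of a full list scan. A minimal sketch under that idea; the toy data and the names triple_set and corrupt_head are additions for illustration, not part of the original code:

```python
import random

# Assumed toy data: entity ids and the known positive triples.
entities = ["beijing", "rome", "china", "italy"]
triples = [["beijing", "is-capital-of", "china"],
           ["rome", "is-capital-of", "italy"]]

# Hash the positives once; tuples are hashable, lists are not.
triple_set = {tuple(t) for t in triples}

def corrupt_head(sample, rng=random):
    """Replace the head entity until the result is not a known positive triple."""
    corrupted = list(sample)
    corrupted[0] = rng.choice(entities)
    while tuple(corrupted) in triple_set:
        corrupted[0] = rng.choice(entities)
    return corrupted

negative = corrupt_head(triples[0])
assert tuple(negative) not in triple_set
```

Building the set costs one pass over the training data up front, after which every negative sample is validated in constant time.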
Both the positive and negative triples are put into the list.
update_triple_embedding is then called to compute the loss for this batch and update the vectors by gradient descent, after which training moves on to the next batch.
After all 100 batches are trained, the learned entity and relation vectors are written under the out_file_title prefix (an empty string means the current directory).
```python
class TransE:
    def training_run(self, epochs=1, nbatches=100, out_file_title=''):
        batch_size = int(len(self.triples) / nbatches)
        print("batch size: ", batch_size)
        for epoch in range(epochs):
            start = time.time()
            self.loss = 0.0
            # Normalise the embedding of the entities to 1
            for entity in self.entities.keys():
                self.entities[entity] = self.normalization(self.entities[entity])

            for batch in range(nbatches):
                batch_samples = random.sample(self.triples, batch_size)
                Tbatch = []
                # list() needed: random.sample requires a sequence, not a dict view,
                # on newer Python versions
                entity_ids = list(self.entities.keys())
                for sample in batch_samples:
                    corrupted_sample = copy.deepcopy(sample)
                    pr = np.random.random(1)[0]
                    if pr > 0.5:
                        # change the head entity
                        corrupted_sample[0] = random.sample(entity_ids, 1)[0]
                        while corrupted_sample[0] == sample[0]:
                            corrupted_sample[0] = random.sample(entity_ids, 1)[0]
                    else:
                        # change the tail entity
                        corrupted_sample[2] = random.sample(entity_ids, 1)[0]
                        while corrupted_sample[2] == sample[2]:
                            corrupted_sample[2] = random.sample(entity_ids, 1)[0]
                    if (sample, corrupted_sample) not in Tbatch:
                        Tbatch.append((sample, corrupted_sample))
                self.update_triple_embedding(Tbatch)
            end = time.time()
            print("epoch: ", epoch, "cost time: %s" % (round((end - start), 3)))
            print("running loss: ", self.loss)

        with codecs.open(out_file_title + "TransE_entity_" + str(self.dimension)
                         + "dim_batch" + str(batch_size), "w") as f1:
            for e in self.entities.keys():
                f1.write(e + "\t")
                f1.write(str(list(self.entities[e])))
                f1.write("\n")

        with codecs.open(out_file_title + "TransE_relation_" + str(self.dimension)
                         + "dim_batch" + str(batch_size), "w") as f2:
            for r in self.relations.keys():
                f2.write(r + "\t")
                f2.write(str(list(self.relations[r])))
                f2.write("\n")
```

5. Gradient descent
First, deepcopy is called to copy the entity and relation vectors, the vectors for the given entity and relation ids are looked up, and the score is computed with the L1 or L2 norm.
L1 score: np.sum(np.fabs(h + r - t))
L2 score: np.sum(np.square(h + r - t))
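The update code below calls two helpers, norm_l1 and norm_l2, that the post never lists; presumably they simply wrap the two score expressions above. A sketch under that assumption:

```python
import numpy as np

def norm_l1(h, r, t):
    # L1 score: sum of absolute coordinate-wise differences of h + r - t
    return np.sum(np.fabs(h + r - t))

def norm_l2(h, r, t):
    # Squared-L2 score: sum of squared coordinate-wise differences of h + r - t
    return np.sum(np.square(h + r - t))
```

Note that norm_l2 returns the squared distance, which matches the 2(h + r − t) gradient used below.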
The loss is then computed as loss = [γ + d(h + l, t) − d(h′ + l, t′)]₊, where γ is the margin value, d is the score above, (h′, l, t′) is the corrupted triple, and [x]₊ = max(0, x).
For the L2 norm, the gradient of d(h + l, t) with respect to h and l is 2(h + l − t), and with respect to t it is −2(h + l − t).
For the L1 norm, each element of the gradient vector is −1 or 1, the sign of the corresponding coordinate of h + l − t.
Finally, the entity and relation vectors are updated with these gradients and normalized.
```python
class TransE:
    def update_triple_embedding(self, Tbatch):
        copy_entity = copy.deepcopy(self.entities)
        copy_relation = copy.deepcopy(self.relations)

        for correct_sample, corrupted_sample in Tbatch:
            correct_copy_head = copy_entity[correct_sample[0]]
            correct_copy_tail = copy_entity[correct_sample[2]]
            relation_copy = copy_relation[correct_sample[1]]

            corrupted_copy_head = copy_entity[corrupted_sample[0]]
            corrupted_copy_tail = copy_entity[corrupted_sample[2]]

            correct_head = self.entities[correct_sample[0]]
            correct_tail = self.entities[correct_sample[2]]
            relation = self.relations[correct_sample[1]]

            corrupted_head = self.entities[corrupted_sample[0]]
            corrupted_tail = self.entities[corrupted_sample[2]]

            # calculate the distance of the triples
            if self.norm == 1:
                correct_distance = norm_l1(correct_head, relation, correct_tail)
                corrupted_distance = norm_l1(corrupted_head, relation, corrupted_tail)
            else:
                correct_distance = norm_l2(correct_head, relation, correct_tail)
                corrupted_distance = norm_l2(corrupted_head, relation, corrupted_tail)

            loss = self.margin + correct_distance - corrupted_distance
            if loss > 0:
                self.loss += loss

                correct_gradient = 2 * (correct_head + relation - correct_tail)
                corrupted_gradient = 2 * (corrupted_head + relation - corrupted_tail)

                if self.norm == 1:
                    # for L1, the gradient is the sign of each coordinate
                    for i in range(len(correct_gradient)):
                        if correct_gradient[i] > 0:
                            correct_gradient[i] = 1
                        else:
                            correct_gradient[i] = -1
                        if corrupted_gradient[i] > 0:
                            corrupted_gradient[i] = 1
                        else:
                            corrupted_gradient[i] = -1

                correct_copy_head -= self.learning_rate * correct_gradient
                relation_copy -= self.learning_rate * correct_gradient
                correct_copy_tail -= -1 * self.learning_rate * correct_gradient
                relation_copy -= -1 * self.learning_rate * corrupted_gradient

                if correct_sample[0] == corrupted_sample[0]:
                    # if corrupted_triples replaces the tail entity,
                    # the head entity's embedding needs to be updated twice
                    correct_copy_head -= -1 * self.learning_rate * corrupted_gradient
                    corrupted_copy_tail -= self.learning_rate * corrupted_gradient
                elif correct_sample[2] == corrupted_sample[2]:
                    # if corrupted_triples replaces the head entity,
                    # the tail entity's embedding needs to be updated twice
                    corrupted_copy_head -= -1 * self.learning_rate * corrupted_gradient
                    correct_copy_tail -= self.learning_rate * corrupted_gradient

                # normalise these new embedding vectors,
                # instead of normalising all the embeddings together
                copy_entity[correct_sample[0]] = self.normalization(correct_copy_head)
                copy_entity[correct_sample[2]] = self.normalization(correct_copy_tail)
                if correct_sample[0] == corrupted_sample[0]:
                    # if corrupted_triples replaces the tail entity,
                    # update the tail entity's embedding
                    copy_entity[corrupted_sample[2]] = self.normalization(corrupted_copy_tail)
                elif correct_sample[2] == corrupted_sample[2]:
                    # if corrupted_triples replaces the head entity,
                    # update the head entity's embedding
                    copy_entity[corrupted_sample[0]] = self.normalization(corrupted_copy_head)
                # the paper mentions that the relation embeddings
                # don't need to be normalised
                copy_relation[correct_sample[1]] = relation_copy
                # copy_relation[correct_sample[1]] = self.normalization(relation_copy)

        self.entities = copy_entity
        self.relations = copy_relation
```

6. __main__
```python
if __name__ == '__main__':
    # file1 = "FB15k\\train.txt"
    # file2 = "FB15k\\entity2id.txt"
    # file3 = "FB15k\\relation2id.txt"
    file1 = "WN18\\wordnet-mlj12-train.txt"
    file2 = "WN18\\entity2id.txt"
    file3 = "WN18\\relation2id.txt"
    entity_set, relation_set, triple_list = dataloader(file1, file2, file3)

    transE = TransE(entity_set, relation_set, triple_list,
                    embedding_dim=50, lr=0.01, margin=1.0, norm=2)
    transE.data_initialise()
    transE.training_run(out_file_title="WN18_")
```

Reference:
The code comes from: 論文筆記(一):TransE論文詳解及代碼復現 - 知乎, where the full code can be downloaded.