文章目錄
TransE
知識圖譜基礎
三元組(h,r,t)
知識表示
即將實體和關係向量化(embedding),把每個實體、每個關係都表示為一個低維實數向量。
算法描述
思想:一個正確的三元組的embedding會滿足:h+r=t
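定義距離d表示向量之間的距離,一般取L1或者L2,期望正確的三元組的距離越小越好,而錯誤的三元組的距離越大越好。為此給出目標函數為:
$$L=\sum_{(h,r,t)\in S}\;\sum_{(h',r,t')\in S'_{(h,r,t)}}\left[\gamma + d(h+r,\,t) - d(h'+r,\,t')\right]_+$$
其中 $\gamma$ 即margin,$S$ 為正確三元組集合,$S'_{(h,r,t)}$ 為把 $h$ 或 $t$ 隨機替換後得到的錯誤三元組集合,$[x]_+=\max(0,x)$。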
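梯度求解:只有當 $\gamma + d(h+r,t) - d(h'+r,t') > 0$ 時該樣本才產生梯度。以平方L2距離 $d(h+r,t)=\lVert h+r-t\rVert_2^2$ 為例:$\partial d/\partial h=\partial d/\partial r=2(h+r-t)$,$\partial d/\partial t=-2(h+r-t)$;若取L1距離 $d(h+r,t)=\lVert h+r-t\rVert_1$,則把 $2(h+r-t)$ 換成 $\operatorname{sign}(h+r-t)$。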
代碼分析
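參數:目標函數的常數——margin;學習率——learningRate(代碼中的參數名拼寫為learingRate);向量維度——dim;實體列表——entityList(讀取文本文件,實體 + id);關係列表——relationList(讀取文本文件,關係 + id);三元關係列表——tripleList(讀取文本文件,實體 + 實體 + 關係);損失值——loss;距離公式——L1(為True時使用L1距離,否則使用L2距離)。
規定初始化維度和取值範圍(TransE算法原理中的取值範圍):向量的每個分量在 $[-6/\sqrt{dim},\;6/\sqrt{dim}]$ 內均勻隨機取值,然後對整個向量做歸一化。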
涉及的函數:
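init:隨機生成值;norm:歸一化。
getSample——隨機選取部分三元關係,得到Sbatch;getCorruptedTriplet(sbatch)——隨機替換三元組的實體,h、t中任意一個被替換,但不同時替換;update——更新向量。
L2更新向量的推導過程:記學習率為 $\eta$,對損失大於0的樣本沿負梯度方向更新:$h\leftarrow h+2\eta(t-h-r)$,$t\leftarrow t-2\eta(t-h-r)$,$r\leftarrow r+2\eta(t-h-r)-2\eta(t'-h'-r)$,$h'\leftarrow h'-2\eta(t'-h'-r)$,$t'\leftarrow t'+2\eta(t'-h'-r)$;更新後對被修改的向量重新歸一化。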
python 函數
uniform(a, b)#隨機生成a、b之間的浮點數N,滿足a <= N <= b;受浮點舍入影響,端點b可能取到也可能取不到。
求向量的模,var = linalg.norm(list)
"""
@version: 3.7
@author: jiayalu
@file: trainTransE.py
@time: 22/08/2019 10:56
@description: 用于對知識圖譜中的實體、關系基于TransE算法訓練獲取向量
數據:三元關系
實體id和關系id
結果為:兩個文本文件,即entityVector.txt和relationVector.txt 實體 [array向量]"""
from copy import deepcopy
from random import uniform, sample

from numpy import *
class TransE:
    """Train TransE embeddings for the entities and relations of a knowledge graph.

    Every training triple is a tuple ``(head, tail, relation)`` (note the
    order: the relation is the third field).  The model looks for vectors
    such that ``head + relation`` is close to ``tail`` for true triples and
    far from it for corrupted ones.

    Attributes:
        margin: the constant added to the hinge loss.
        learingRate: step size of the gradient update (the spelling is part
            of the public interface and is therefore kept).
        dim: dimension of every embedding vector.
        entityList: list of entity names; after :meth:`initialize` a dict
            mapping each name to its vector.
        relationList: same as ``entityList`` but for relations.
        tripleList: list of ``(head, tail, relation)`` tuples.
        loss: hinge loss accumulated since the last checkpoint.
        L1: use the L1 distance when true, the squared L2 distance otherwise.
    """

    # Checkpoint files written by transE() when the caller gives no paths.
    DEFAULT_RELATION_CHECKPOINT = r"E:\pythoncode\knownlageGraph\transE-master\relationVector.txt"
    DEFAULT_ENTITY_CHECKPOINT = r"E:\pythoncode\knownlageGraph\transE-master\entityVector.txt"

    def __init__(self, entityList, relationList, tripleList, margin=1,
                 learingRate=0.00001, dim=10, L1=True):
        self.margin = margin
        self.learingRate = learingRate
        self.dim = dim
        self.entityList = entityList
        self.relationList = relationList
        self.tripleList = tripleList
        self.loss = 0
        self.L1 = L1

    def _randomVectors(self, names):
        """Return a dict mapping every name to a random unit-length vector."""
        return {
            name: norm([init(self.dim) for _ in range(self.dim)])
            for name in names
        }

    def initialize(self):
        """Replace the entity and relation name lists by name -> vector dicts."""
        entityVectorList = self._randomVectors(self.entityList)
        print("entityVector初始化完成,數量是%d" % len(entityVectorList))
        relationVectorList = self._randomVectors(self.relationList)
        print("relationVectorList初始化完成,數量是%d" % len(relationVectorList))
        self.entityList = entityVectorList
        self.relationList = relationVectorList

    def transE(self, cI=20, batchSize=3, relationVectorPath=None,
               entityVectorPath=None):
        """Run ``cI`` training cycles.

        Args:
            cI: number of cycles (mini-batches) to run.
            batchSize: number of triples sampled per cycle.
            relationVectorPath: checkpoint file for the relation vectors;
                defaults to ``DEFAULT_RELATION_CHECKPOINT``.
            entityVectorPath: checkpoint file for the entity vectors;
                defaults to ``DEFAULT_ENTITY_CHECKPOINT``.

        Every 100th cycle (starting with cycle 0) the accumulated loss is
        printed, both vector files are written and the loss is reset.
        """
        if relationVectorPath is None:
            relationVectorPath = self.DEFAULT_RELATION_CHECKPOINT
        if entityVectorPath is None:
            entityVectorPath = self.DEFAULT_ENTITY_CHECKPOINT

        print("訓練開始")
        for cycleIndex in range(cI):
            # Pair every sampled triple with one corrupted version of it,
            # dropping exact duplicate pairs.
            Tbatch = []
            for sbatch in self.getSample(batchSize):
                pair = (sbatch, self.getCorruptedTriplet(sbatch))
                if pair not in Tbatch:
                    Tbatch.append(pair)
            self.update(Tbatch)

            if cycleIndex % 100 == 0:
                print("第%d次循環" % cycleIndex)
                print(self.loss)
                self.writeRelationVector(relationVectorPath)
                self.writeEntilyVector(entityVectorPath)
                self.loss = 0

    def getSample(self, size):
        """Return up to ``size`` distinct random triples from ``tripleList``.

        The size is capped at the number of available triples so that tiny
        data sets do not make ``random.sample`` raise ``ValueError``.
        """
        available = len(self.tripleList)
        return sample(self.tripleList, size if size < available else available)

    def getCorruptedTriplet(self, triplet):
        """Return ``triplet`` with either its head or its tail replaced.

        Exactly one of the two entities is replaced by a different, randomly
        chosen entity; the relation is kept.

        Args:
            triplet: a ``(head, tail, relation)`` tuple.

        Returns:
            The corrupted ``(head, tail, relation)`` tuple.

        Raises:
            ValueError: if no entity other than the one being replaced exists
                (the original code looped forever in that situation).
        """
        position = 0 if uniform(-1, 1) < 0 else 1  # 0: head, 1: tail
        replaced = triplet[position]
        # A list is built explicitly: random.sample() rejects dict key views
        # on Python >= 3.11.
        candidates = [entity for entity in self.entityList if entity != replaced]
        if not candidates:
            raise ValueError(
                "cannot corrupt %r: no other entity is available" % (triplet,)
            )
        entityTemp = sample(candidates, 1)[0]
        if position == 0:
            return (entityTemp, triplet[1], triplet[2])
        return (triplet[0], entityTemp, triplet[2])

    def update(self, Tbatch):
        """Apply one gradient step for a batch of (triple, corrupted triple) pairs.

        Distances and gradients are always computed from the vectors as they
        were before the batch started; the resulting changes are accumulated
        in working copies that replace ``entityList`` / ``relationList`` once
        the whole batch has been processed.  ``loss`` grows by the hinge loss
        of every pair that violates the margin.
        """
        # Shallow copies are enough: vectors are rebound, never mutated.
        copyEntityList = dict(self.entityList)
        copyRelationList = dict(self.relationList)
        distance = distanceL1 if self.L1 else distanceL2

        for triplet, corrupted in Tbatch:
            headKey, tailKey, relationKey = triplet[0], triplet[1], triplet[2]
            corruptedHeadKey, corruptedTailKey = corrupted[0], corrupted[1]

            head = self.entityList[headKey]
            tail = self.entityList[tailKey]
            relation = self.relationList[relationKey]
            corruptedHead = self.entityList[corruptedHeadKey]
            corruptedTail = self.entityList[corruptedTailKey]

            eg = (self.margin
                  + distance(head, tail, relation)
                  - distance(corruptedHead, corruptedTail, relation))
            if not eg > 0:
                continue  # margin already satisfied: no loss, no gradient
            self.loss += eg

            positiveDiff = tail - head - relation
            negativeDiff = corruptedTail - corruptedHead - relation
            if self.L1:
                # Sub-gradient of |x| is its sign (0 is treated as +1, as in
                # the original).  The sign must still be scaled by the
                # learning rate, otherwise every step has magnitude 1.
                tempPositive = self.learingRate * where(positiveDiff >= 0, 1.0, -1.0)
                tempNegtative = self.learingRate * where(negativeDiff >= 0, 1.0, -1.0)
            else:
                # Gradient of the squared L2 distance.
                tempPositive = 2 * self.learingRate * positiveDiff
                tempNegtative = 2 * self.learingRate * negativeDiff

            # The corrupted triple shares one entity with the original one, so
            # contributions are summed per entity before being applied;
            # writing them one after the other would drop the first one.
            entityDeltas = {}
            for key, delta in ((headKey, tempPositive),
                               (tailKey, -tempPositive),
                               (corruptedHeadKey, -tempNegtative),
                               (corruptedTailKey, tempNegtative)):
                entityDeltas[key] = entityDeltas.get(key, 0) + delta

            # Only the vectors that were touched are re-normalised.
            for key, delta in entityDeltas.items():
                copyEntityList[key] = norm(copyEntityList[key] + delta)
            copyRelationList[relationKey] = norm(
                copyRelationList[relationKey] + tempPositive - tempNegtative
            )

        self.entityList = copyEntityList
        self.relationList = copyRelationList

    @staticmethod
    def _writeVectors(path, vectors):
        """Write one ``name [v1, v2, ...]`` line per entry of ``vectors``."""
        with open(path, 'w', encoding="utf-8") as vectorFile:
            for name, vector in vectors.items():
                vectorFile.write(name + " " + str(vector.tolist()) + "\n")

    def writeEntilyVector(self, dir):
        """Write all entity vectors to the file ``dir`` (one entity per line)."""
        print("寫入實體")
        self._writeVectors(dir, self.entityList)

    def writeRelationVector(self, dir):
        """Write all relation vectors to the file ``dir`` (one relation per line)."""
        print("寫入關系")
        self._writeVectors(dir, self.relationList)


def init(dim):
    """Return one random component drawn uniformly from +-6 / sqrt(dim)."""
    bound = 6 / (dim ** 0.5)
    return uniform(-bound, bound)


def norm(list):
    """Return ``list`` scaled to unit Euclidean length as a new array.

    The argument is not modified.  A vector of length zero is returned
    unchanged (as an array) instead of being turned into NaNs.
    """
    vector = asarray(list, dtype=float)
    var = linalg.norm(vector)
    if var == 0:
        return vector.copy()
    return vector / var


def distanceL1(h, t, r):
    """Return the L1 norm of ``h + r - t``."""
    return fabs(asarray(h) + asarray(r) - asarray(t)).sum()


def distanceL2(h, t, r):
    """Return the squared L2 norm of ``h + r - t`` (no square root is taken)."""
    s = asarray(h) + asarray(r) - asarray(t)
    return (s * s).sum()


def openDetailsAndId(dir, sp=" "):
    """Read a ``name<sp>id`` file.

    Args:
        dir: path of the text file (UTF-8).
        sp: field separator.

    Returns:
        ``(count, names)`` where ``names`` holds the first field of every
        non-blank line.
    """
    names = []
    with open(dir, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # a blank line is not an entity called ''
            names.append(line.split(sp)[0])
    return len(names), names


def openTrain(dir, sp=" "):
    """Read the training triples.

    Args:
        dir: path of the text file (UTF-8), one triple per line.
        sp: field separator.

    Returns:
        ``(count, triples)``; lines with fewer than three fields are skipped
        and every kept line becomes a tuple of its fields.
    """
    triples = []
    with open(dir, "r", encoding="utf-8") as file:
        for line in file:
            triple = line.strip().split(sp)
            if len(triple) < 3:
                continue
            triples.append(tuple(triple))
    return len(triples), triples


def main():
    """Load the data files, train the model and write the final vectors."""
    dirEntity = r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\entity2id.txt"
    entityIdNum, entityList = openDetailsAndId(dirEntity)
    dirRelation = r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\relation2id.txt"
    relationIdNum, relationList = openDetailsAndId(dirRelation)
    dirTrain = r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\train.txt"
    tripleNum, tripleList = openTrain(dirTrain)

    print("打開TransE")
    model = TransE(entityList, relationList, tripleList, margin=1, dim=128)
    print("TranE初始化")
    model.initialize()
    model.transE(1500)
    model.writeRelationVector(r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\relationVector.txt")
    model.writeEntilyVector(r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\entityVector.txt")


if __name__ == '__main__':
    main()
數據
結果向量
總結
以上是生活随笔為你收集整理的TransE算法原理与案例的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。