KBQA-Bert学习记录-CRF模型
目錄
一、整體架構
1.定義CRF類,初始化相關參數
2.定義forward函數
3.forword調用的函數:_validate
4.forward調用的函數:_conputer_score
5.forward調用的函數:_compute_normalizer
6.forward調用的函數:_viterbi_decode
7.外部調用的函數:decode
該項目中,使用BERT+CRF進行NER任務,因此先構造CRF模型。具體實現過程中需要注意的細節已在代碼中包含。
一、整體架構
通過bert生成序列之后(其他的模型比如LSTM什么的也一樣,都會生成一個預測序列),我們得到了形狀是(batch_size, sentence_length, number_of_tags)的結果,也就是,對每一句話,的每一個字,有number_of_tags這么多的預測結果。假如我們的實體類型有三個"B", "I", "O",一個batch有32句話,一句話被統一成了64個單詞,那么生成的結果就是(32, 64, 3)。注意這里的batch_size和sentence_length的位置,可能會由于代碼的不同,調換順序。
生成的結果就是我們的發射分數。要計算損失,我們還需要計算發射分數中,正確路徑對應的分數;以及發射分數中,所有路徑合起來的分數。
同時,我們還需要對所有路徑合起來的分數進行處理,由于計算損失的時候,會讓這個總分作為分母,因此,采用的是先取exp(),再求和sum(),再取對數log(),而這個運算只需要pytorch的一行代碼即可完成:torch.logsumexp()。
最后,我們還希望得到一條最佳路徑,于是需要維特比解碼得到。
因此,在這個類中,我們需要定義不同的函數來實現不同的功能:
1. __init__必須定義,初始化參數
2.forward必須定義,前向傳播,得到損失值。這里面會調用其他函數,用于計算損失。
3.計算正確路徑分數的函數
4.計算所有路徑總分的函數
5.維特比解碼函數
6.能讓外接調用,得到最佳路徑的函數
注意:下面所有函數,都在CRF類里面,這里以分段的形式記錄。
1.定義CRF類,初始化相關參數
class CRF(nn.Module):def __init__(self, num_tags : int = 2, batch_first : bool = True) -> None:super(CRF, self).__init__()self.num_tags = num_tagsself.batch_first = batch_first# start到其他(不含end)的得分self.start_transitions = nn.Parameter(torch.empty(num_tags))# 其他(不含start)到end的得分self.end_transitions = nn.Parameter(torch.empty(num_tags))# 轉移分數矩陣self.transitions = nn.Parameter(torch.empty((num_tags, num_tags)))self.reset_parameters()def reset_parameters(self):'''將初始化的分數限定在-0.1到0.1之間'''init_range = 0.1nn.init.uniform_(self.start_transitions, -init_range, init_range)nn.init.uniform_(self.end_transitions, -init_range, init_range)nn.init.uniform_(self.transitions, -init_range, init_range)2.定義forward函數
forward函數所需要的其他函數,后面補充。通過forward函數之后,返回的是我們所需要的損失值。
def forward(self, emissions: torch.Tensor,tags: torch.Tensor = None,mask: Optional[torch.ByteTensor] = None,reduction: str = 'mean') -> torch.Tensor:self._validate(emissions, tags=tags, mask=mask)# reduction:損失值模式,是均值還是求和作為損失reduction = reduction.lower()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f"invalid reduction {reduction}")if mask is None:mask = torch.ones_like(tags, dtype=torch.uint8)if self.batch_first:# 發射分數形狀:(seq_length, batch_size, tag_num)emissions = emissions.transpose(0, 1)tags = tags.transpose(0, 1)mask = mask.transpose(0, 1)# 計算正確標簽序列的發射分數和轉移分數之和, shape: (batch_size, )numerator = self._cumputer_score(emissions=emissions, tags=tags, mask=mask)# 計算所有序列發射分數和轉移分數之和, shape: (batch_size, )denominator = self._compute_normalizer(emissions=emissions, mask=mask)# 二者相減, shape: (batch_size, )llh = denominator - numerator# 根據不同的設定返回不同形式的分數if reduction == 'none':return llhif reduction == 'sum':return llh.sum()if reduction == 'mean':return llh.mean()if reduction == 'token_mean':return llh.sum() / mask.float().sum()3.forword調用的函數:_validate
主要是用來確保所有輸入數據的維度應該是我們所要求的維度。
def _validate(self, emissions: torch.Tensor,tags: Optional[torch.LongTensor] = None,mask: Optional[torch.ByteTensor] = None) -> None:if emissions.dim() != 3:raise ValueError(f"emissions must have dimension of 3, got{emissions.dim()}")if emissions.size(2) != self.num_tags:raise ValueError(f"expected last dimission of emission is {self.num_tags},"f"got {emissions.size(2)}")if tags is not None:if emissions.shape[:2] != mask.shape:raise ValueError(f"the first two dimensions of mask and emissions must match"f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}")no_empty_seq = not self.batch_first and mask[0].all()no_empty_seq_bf = self.batch_first and mask[:, 0].all()if not no_empty_seq and not no_empty_seq_bf:raise ValueError('mask of the first timestep must all be on.')4.forward調用的函數:_computer_score
該函數用來計算最佳路徑的分數。
def _computer_score(self, emissions: torch.Tensor,tags: torch.LongTensor,mask: torch.ByteTensor) -> torch.Tensor:# batch secondassert emissions.dim() == 3 and tags.dim() == 2assert emissions.shape[:2] == tags.shapeassert emissions.size(2) == self.num_tagsassert mask.shape == tags.shape# 每個mask,開頭一定是1,否則相當于句子就沒了。assert mask[0].all()seq_length, batch_size = tags.shapemask = mask.float()# start,轉移到其他所有標簽的分數,不包含end# 根據實際的tag的開頭的詞,得到從start到每句話開頭的類型的分數。# 這里是start到第一個詞的轉移分數,shape: (batch_size,)score = self.start_transitions[tags[0]]# 接下來是預測的每句話的開頭應當是什么tag,如果有3個tag,那么每個詞都會有對應的三個分數,分別對應每一個tag# 但是我們實際的tag是在tags[0]里面的,而預測的三個值,分數不一定是多少# 比如實際的第一個詞tag是B,預測的BIO的三個分數分別為:(0.1,0.5,04)# 那么我們把0.1這個分數加上。這個就是發射分數,也就是預測的分數。score += emissions[0, torch.arange(batch_size), tags[0]]# 至此,我們完成了從start轉移到第一個詞的轉移分數+發射分數# 接下來是每個詞到下一個詞的轉移分數+發射分數,全加到一塊for i in range(1, seq_length):# 轉移分數score += self.transitions[tags[i-1], tags[i]] * mask[i]# 發射分數score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]# 取到最后一個詞的tag# 使用的mask是形如:[1,1,1,1,0,0,0],后面的0是padding的,因此沒字了# 因此通過下面的方式,取到1的和,減去1,就是最后一個詞的索引了。seq_end = mask.long().sum(dim=0) - 1last_tag = tags[seq_end, torch.arange(batch_size)]# 最后一個詞轉移到end的分數score += self.end_transitions[last_tag]return score5.forward調用的函數:_compute_normalizer
這里計算所有路徑的分數之和。并取一個logsumexp
def _compute_normalizer(self, emissions: torch.Tensor,mask: torch.ByteTensor) -> torch.Tensor:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length = emissions.size(0)# emissions[0],因為第一個維度是句子長度,因此emissions[0]就是每一個句子的開頭的詞,對應的發射分數# 并且每一個分數是有num_tags這么多。因此emissions[0]就是對所有開頭的詞,對每一個標簽預測的分數。# 再加上start標志到每一個標簽的分數,就是一個整體的開頭分數之和。score = self.start_transitions + emissions[0]# 接下來把所有的轉移分數,發射分數全部加起來。for i in range(1, seq_length):# 原來是(batch_size, num_tags), 現在是(batch_size, num_tags, 1)broadcast_score = score.unsqueeze(dim=2)# 對于第i個詞,原來是(batch_size, num_tags), 現在是(batch_size, 1, num_tags)broadcast_emission = emissions[i].unsqueeze(1)# 先把開頭的分數和轉移矩陣加起來,便得到了開頭的每一個tag,轉移到其他每一個tag的概率# 再把發射矩陣加上,便得到了該單詞的總分數。其中會自動使用broad cast機制next_score = broadcast_score + self.transitions + broadcast_emission# 對總分數,在第二個維度求一個對數域的分數。第二個維度,也就是轉移矩陣的行# 我們求的是所有路徑的總分數,要對這個分數求和。# 假如對第二個詞來說,可能由第一個詞的num_tags那么多的可能性過來,那么就把所有的可能性加起來# 這樣得到的就是對于第二個詞來說的總分數。因此,把轉移矩陣的行,也就是前一個詞可能的tag,全部加起來即可# 也就是在第二個維度上求和。這樣就得到了總分數,我們對這個總分數進行對數域計算即可(取e,求和,取對數)。next_score = torch.logsumexp(next_score, dim=1)# 通過mask,如果對應的單詞位置有值,也就是我們需要這個分數,那么就使用next_score# 如果對應的位置沒值,那么這個分數不需要加上,就取原來的scorescore = torch.where(mask[i].unsqueeze(1), next_score, score)# 最后把單詞轉移到end的分數加上score += self.end_transitions# 返回值取對數域的值,把所有的詞的分數再求和一遍return torch.logsumexp(score, dim=1)6.forward調用的函數:_viterbi_decode
維特比解碼,得到最佳路徑。
def _viterbi_decode(self, emissions: torch.FloatTensor,mask: torch.ByteTensor) -> List[List[int]]:assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length, batch_size = mask.shapescore = self.start_transitions + emissions[0]history = []for i in range(1, seq_length):broadcast_score = score.unsqueeze(2)broadcast_emission = emissions[i].unsqueeze(1)next_score = broadcast_score + self.transitions + broadcast_emission# 在第一個維度上面求最大,消掉第一個維度,那么剩下的就是"到下一個類型概率最大的那個類型"# 這個max返回值有2個,一個是求完最大值后的結果,形狀是(B, tag_num),一個是每個最大值所在的索引# 兩個返回結果形狀一致# 選最好的轉移分數next_score, indices = next_score.max(dim=1)score = torch.where(mask[i].unsqueeze(1), next_score, score)# 上一個詞轉移到這個詞時,分數最高的那些值的索引history.append(indices)score += self.end_transitionsseq_ends = mask.long().sum(dim=0) - 1best_tags_list = []for idx in range(batch_size):# 取到分數最高的標簽,就是最后一個詞的標簽的索引# 選最好的發射分數_, best_last_tag = score[idx].max(dim=0)best_tags = [best_last_tag.item()]# seq_ends存了每個句子序列的最后一個詞的索引。for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(best_last_tag.item())best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list7.外部調用的函數:decode
該函數調用了上面的維特比解碼,外部可通過model.decode調用,返回最佳路徑。
def decode(self, emissions: torch.Tensor,mask: Optional[torch.ByteTensor]=None) -> List[List[int]]:self._validate(emissions=emissions, mask=mask)if mask is None:mask = emissions.new_ones(emissions.shape[:2],dtype=torch.uint8)if self.batch_first:emissions = emissions.transpose(0, 1)mask = mask.transpose(0, 1)return self._viterbi_decode(emissions, mask)總結
以上是生活随笔為你收集整理的KBQA-Bert学习记录-CRF模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 微信小程序中使用tabBar
- 下一篇: 华为rh2285 v1的装上独立显卡,并