十分钟读懂Beam Search(1/2)
最近研究了一下用基于BERT的encoder-decoder結(jié)構(gòu)做文本生成任務(wù),碰巧管老師昨天的文章也介紹了以生成任務(wù)見長的GPT模型,于是決定用兩篇文章大家介紹一下在文本生成任務(wù)中常用的解碼策略Beam Search(集束搜索)。
解碼及貪心搜索
生成式任務(wù)相比普通的分類、tagging等NLP任務(wù)會復(fù)雜不少。在生成的時候,模型的輸出是一個時間步一個時間步依次獲得的,而且前面時間步的結(jié)果還會影響后面時間步的結(jié)果。也就是說,每一個時間步,模型給出的都是基于歷史生成結(jié)果的條件概率。為了生成完整的句子,需要一個稱為解碼的額外動作來融合模型多個時間步的輸出,而且使得最終得到的序列的每一步條件概率連乘起來最大。
在文本生成任務(wù)中,每一個時間步可能的輸出種類稱為字典大小(vocabulary size,我們用 表示),進(jìn)行T步隨機(jī)的生成可能獲得的結(jié)果總共有 種。拿中文文本生成來說, 的值大約是5000-6000,即常用漢字的個數(shù)。在如此大的基數(shù)下,遍歷整個生成空間是不現(xiàn)實(shí)的。
最容易想到的策略是貪心搜索,即每一個時間步都取出一個條件概率最大的輸出,再將從開始到當(dāng)前步的結(jié)果作為輸入去獲得下一個時間步的輸出,直到模型給出生成結(jié)束的標(biāo)志。例如下圖,每一個時間步都取出了條件概率最大一個結(jié)果,生成了序列[A,B,C]。
貪心搜索示意圖很明顯,這樣做將原來指數(shù)級別的求解空間直接壓縮到了與長度線性相關(guān)的大小。由于丟棄了絕大多數(shù)的可能解,這種關(guān)注當(dāng)下的策略無法保證最終得到的序列概率是最優(yōu)的。
Beam Search
而beam search是對貪心策略一個改進(jìn)。思路也很簡單,就是稍微放寬一些考察的范圍。在每一個時間步,不再只保留當(dāng)前分?jǐn)?shù)最高的1個輸出,而是保留num_beams個。當(dāng)num_beams=1時集束搜索就退化成了貪心搜索。
下圖是一個實(shí)際的例子,每個時間步有ABCDE共5種可能的輸出,即 ,圖中的num_beams=2,也就是說每個時間步都會保留到當(dāng)前步為止條件概率最優(yōu)的2個序列。
beam search示意圖在第一個時間步,A和C是最優(yōu)的兩個,因此得到了兩個結(jié)果[A],[C],其他三個就被拋棄了;
第二步會基于這兩個結(jié)果繼續(xù)進(jìn)行生成,在A這個分支可以得到5個候選人,[AA],[AB],[AC],[AD],[AE],C也同理得到5個,此時會對這10個進(jìn)行統(tǒng)一排名,再保留最優(yōu)的兩個,即圖中的[AB]和[CE];
第三步同理,也會從新的10個候選人里再保留最好的兩個,最后得到了[ABD],[CED]兩個結(jié)果。
可以發(fā)現(xiàn),beam search在每一步需要考察的候選人數(shù)量是貪心搜索的num_beams倍,因此是一種犧牲時間換性能的方法。
以上就是Beam Search的基本概念,下面我們解析一種高效率實(shí)現(xiàn)方式。
Beam Search代碼解析
Beam Search的原理雖然簡單,但實(shí)際實(shí)現(xiàn)的時候卻有很多細(xì)節(jié)要考慮。下面要解析這個實(shí)現(xiàn)出自于NLP界著名Python包Transformers[1],我為了說明方便做了一些改動。
一個正確且高效的算法需要處理的問題大概有兩個:
充分利用硬件,可以處理批量數(shù)據(jù),且盡量使用并行計算少用循環(huán)
處理好長短不同的生成結(jié)果
下面是基礎(chǔ)版的beam search函數(shù)定義。其中context是編碼器編碼獲得的向量,batch_size是每批數(shù)據(jù)中包含的樣本量,bos_token_id是句子開頭標(biāo)志的token id,pad_token_id是用于填充的token id,eos_token_id是句子結(jié)束標(biāo)志的token id。這里給參數(shù)填上的默認(rèn)值和我們后面講解時使用的例子是一致的。
def beam_search_generate(context,batch_size=3,max_length=20,min_length=2,num_beams=2,bos_token_id=101,pad_token_id=0,eos_token_id=102,):pass在函數(shù)中主要執(zhí)行以下三個步驟:
準(zhǔn)備初始輸入
在當(dāng)前生成的序列長度未達(dá)到max_length時擴(kuò)展生成序列
準(zhǔn)備最終輸出的序列
下面我們分別解析。
準(zhǔn)備初始輸入
# 建立beam容器,每個樣本一個 generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)for _ in range(batch_size) ]# 每個beam容器的得分,共batch_size*num_beams個 beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=encoder_input_ids.device) beam_scores = beam_scores.view(-1)# 每個樣本是否完成生成,共batch_size個 done = [False for _ in range(batch_size)]# 為了并行計算,一次生成batch_size*num_beams個序列 # 第一步自動填入bos_token input_ids = torch.full((batch_size*num_beams, 1),bos_token_id,dtype=torch.long,device=next(self.parameters()).device, )# 當(dāng)前長度設(shè)為1 cur_len = 1其中BeamHypotheses是一個容器類,每個樣本綁定一個。每個容器中會維護(hù)num_beams個當(dāng)前最優(yōu)的序列。當(dāng)往容器中添加一個序列而導(dǎo)致序列數(shù)大于num_beams的時候,它會自動踢掉分?jǐn)?shù)最低的那個序列。類代碼如下。
class BeamHypotheses(object):def __init__(self, num_beams, max_length, length_penalty):self.max_length = max_length - 1 # ignoring bos_tokenself.num_beams = num_beamsself.beams = []self.worst_score = 1e9def __len__(self):return len(self.beams)def add(self, hyp, sum_logprobs):score = sum_logprobs / len(hyp) ** self.length_penaltyif len(self) < self.num_beams or score > self.worst_score:# 可更新的情況:數(shù)量未飽和或超過最差得分self.beams.append((score, hyp))if len(self) > self.num_beams:# 數(shù)量飽和需要刪掉一個最差的sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])del self.beams[sorted_scores[0][1]]self.worst_score = sorted_scores[1][0]else:self.worst_score = min(score, self.worst_score)def is_done(self, best_sum_logprobs, cur_len=None):"""相關(guān)樣本是否已經(jīng)完成生成。best_sum_logprobs是新的候選序列中的最高得分。"""if len(self) < self.num_beams:return Falseelse:if cur_len is None:cur_len = self.max_lengthcur_score = best_sum_logprobs / cur_len ** self.length_penalty# 是否最高分比當(dāng)前保存的最低分還差ret = self.worst_score >= cur_scorereturn ret序列擴(kuò)展
序列擴(kuò)展是beam search的核心過程,我們特地畫了一張圖來解釋這個版本的實(shí)現(xiàn)策略。
序列擴(kuò)展示意圖下面對照這個圖來講解代碼。
while cur_len < max_length:# 將編碼器得到的上下文向量和當(dāng)前結(jié)果輸入解碼器,即圖中1output = decoder.decode_next_step(context, input_ids)# 輸出矩陣維度為:(batch*num_beams)*cur_len*vocab_size# 取出最后一個時間步的各token概率,即當(dāng)前條件概率# (batch*num_beams)*vocab_sizescores = next_token_logits = output[:, -1, :]############################ 這里可以做一大堆操作減少重復(fù) ############################# 計算序列條件概率的,因?yàn)槿×薼og,所以直接相加即可。得到圖中2矩陣# (batch_size * num_beams, vocab_size)next_scores = scores + beam_scores[:, None].expand_as(scores)# 為了提速,將結(jié)果重排成圖中3的形狀next_scores = next_scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)# 取出分?jǐn)?shù)最高的token(圖中黑點(diǎn))和其對應(yīng)得分# sorted=True,保證返回序列是有序的next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)# 下一個時間步整個batch的beam列表# 列表中的每一個元素都是三元組# (分?jǐn)?shù), token_id, beam_id)next_batch_beam = []# 對每一個樣本進(jìn)行擴(kuò)展for batch_idx in range(batch_size):# 檢查樣本是否已經(jīng)生成結(jié)束if done[batch_idx]:# 對于已經(jīng)結(jié)束的句子,待添加的是pad tokennext_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batchcontinue# 當(dāng)前樣本下一個時間步的beam列表next_sent_beam = []# 對于還未結(jié)束的樣本需要找到分?jǐn)?shù)最高的num_beams個擴(kuò)展# 注意,next_scores和next_tokens是對應(yīng)的# 而且已經(jīng)按照next_scores排好順序for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):# get beam and word IDs# 這兩行可參考圖中3進(jìn)行理解beam_id = beam_token_id // vocab_sizetoken_id = beam_token_id % vocab_sizeeffective_beam_id = batch_idx * num_beams + beam_id# 如果出現(xiàn)了EOS token說明已經(jīng)生成了完整句子if (eos_token_id is not None) and (token_id.item() == eos_token_id):# if beam_token does not belong to top num_beams tokens, it should not be addedis_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beamsif is_beam_token_worse_than_top_num_beams:continue# 往容器中添加這個序列g(shù)enerated_hyps[batch_idx].add(input_ids[effective_beam_id].clone(), beam_token_score.item(),)else:# add next predicted word if it is not eos_tokennext_sent_beam.append((beam_token_score, token_id, effective_beam_id))# 擴(kuò)展num_beams個就夠了if len(next_sent_beam) == num_beams:break# 檢查這個樣本是否已經(jīng)生成完了,有兩種情況# 1. 已經(jīng)記錄過該樣本結(jié)束# 2. 新的結(jié)果沒有使結(jié)果改善done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(next_scores[batch_idx].max().item(), cur_len=cur_len)# 把當(dāng)前樣本的結(jié)果添加到batch結(jié)果的后面next_batch_beam.extend(next_sent_beam)# 如果全部樣本都已經(jīng)生成結(jié)束便可以直接退出了if all(done):break# 把三元組列表再還原成三個獨(dú)立列表beam_scores = beam_scores.new([x[0] for x in next_batch_beam])beam_tokens = input_ids.new([x[1] for x in next_batch_beam])beam_idx = input_ids.new([x[2] for x in next_batch_beam])# 準(zhǔn)備下一時刻的解碼器輸入# 取出實(shí)際被擴(kuò)展的beaminput_ids = input_ids[beam_idx, :]# 在這些beam后面接上新生成的tokeninput_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)# 更新當(dāng)前長度cur_len = cur_len + 1# end of length while乍一看是不是有些復(fù)雜,我感覺關(guān)鍵的有以下幾點(diǎn):
只有出現(xiàn)了EOS token才會將生成的序列裝進(jìn)該樣本對應(yīng)的容器中
當(dāng)前input_ids保存著當(dāng)前得分最高的num_beams個序列
準(zhǔn)備輸出
上面那個while循環(huán)跳出意味著已經(jīng)生成了長度為max_length的文本,比較理想的情況是所有的句子都已經(jīng)生成出了eos_token_id,即句子生成結(jié)束了。但并不是所有情況都這樣,對于那些”意猶未盡“的樣本,我們需要先手動結(jié)束。
# 將未結(jié)束的生成結(jié)果結(jié)束,并置入容器中 for batch_idx in range(batch_size):# 已經(jīng)結(jié)束的樣本不需處理if done[batch_idx]:continue# 把結(jié)果加入到generated_hyps容器for beam_id in range(num_beams):effective_beam_id = batch_idx * num_beams + beam_idfinal_score = beam_scores[effective_beam_id].item()final_tokens = input_ids[effective_beam_id]generated_hyps[batch_idx].add(final_tokens, final_score)經(jīng)過上面的處理,所有生成好的句子都已經(jīng)保存在generated_hyps容器中,每個容器內(nèi)保存著num_beams個序列,最后就是輸出期望個數(shù)的句子。
# select the best hypotheses,最終輸出 # 每個樣本返回幾個句子 output_num_return_sequences_per_batch = 1 # 記錄每個返回句子的長度,用于后面pad sent_lengths = input_ids.new(output_batch_size) best = []# 對每個樣本取出最好的output_num_return_sequences_per_batch個句子 for i, hypotheses in enumerate(generated_hyps):sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])for j in range(output_num_return_sequences_per_batch):effective_batch_idx = output_num_return_sequences_per_batch * i + jbest_hyp = sorted_hyps.pop()[1]sent_lengths[effective_batch_idx] = len(best_hyp)best.append(best_hyp)# 如果長短不一則pad句子,使得最后返回結(jié)果的長度一樣 if sent_lengths.min().item() != sent_lengths.max().item():sent_max_len = min(sent_lengths.max().item() + 1, max_length)# 先把輸出矩陣填滿PAD tokendecoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)# 填入真正的內(nèi)容for i, hypo in enumerate(best):decoded[i, : sent_lengths[i]] = hypo# 填上eos tokenif sent_lengths[i] < max_length:decoded[i, sent_lengths[i]] = eos_token_id else:# 所有生成序列都還沒結(jié)束,直接堆疊即可decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)# 返回的結(jié)果包含BOS token return decoded總結(jié)
好了,上面就是最基礎(chǔ)的beam search算法。這樣生成出來的結(jié)果已經(jīng)會比貪心搜索好一些,但還是會遇到諸如詞語重復(fù)這樣的問題。其實(shí)已經(jīng)有很多針對重復(fù)問題的研究,我們在代碼中也已經(jīng)留出了位置,下期再見咯。
參考資料
[1]
Transformers: https://github.com/huggingface/transformers
個人微信:加時請注明 (昵稱+公司/學(xué)校+方向)
總結(jié)
以上是生活随笔為你收集整理的十分钟读懂Beam Search(1/2)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 基于HTML电商购物项目的设计与实现——
- 下一篇: 2023年Digitalocean权威评