[深度学习] 自然语言处理 --- Self-Attention(三) 知识点与源码解析
在當(dāng)前的 NLP 領(lǐng)域,Transformer / BERT 已然成為基礎(chǔ)應(yīng)用,而 Self-Attention? 則是兩者的核心部分,下面嘗試用 Q&A 和源碼的形式深入 Self-Attention 的細(xì)節(jié)。
?
一 Q&A
1. Self-Attention 的核心是什么?
Self-Attention 的核心是用文本中的其它詞來增強(qiáng)目標(biāo)詞的語義表示,從而更好的利用上下文的信息。
2. Self-Attention 的時(shí)間復(fù)雜度是怎么計(jì)算的?
Self-Attention 時(shí)間復(fù)雜度:O(*d),這里n 是序列的長度,d 是 embedding 的維度,不考慮 batch 維。
Self-Attention 包括三個(gè)步驟:相似度計(jì)算,softmax 和加權(quán)平均。
它們分別的時(shí)間復(fù)雜度是:
- 相似度計(jì)算 可以看作大小為(n, d) 和 (d, n) 的兩個(gè)矩陣相乘, (n, d) * (d, n) = O(*d), 得到一個(gè)(n, n)的矩陣
- softmax 就是直接計(jì)算了,時(shí)間復(fù)雜度為 O()
- 加權(quán)平均 可以看作大小為(n, n)和(n, d)的兩個(gè)矩陣相乘 (n, n) * (n, d) = O(*d), 得到一個(gè)(n, d) 的矩陣。
因此,Self-Attention 的時(shí)間復(fù)雜度是 O(*d)
2.1? 為什么在進(jìn)行softmax之前需要對attention進(jìn)行scaled(為什么除以 sdk的平方根)?
????? 假設(shè) Q 和 K 的均值為0,方差為1。它們的矩陣乘積將有均值為0,方差為dk,因此使用dk的平方根被用于縮放,因?yàn)?#xff0c;Q 和 K 的矩陣乘積的均值本應(yīng)該為 0,方差本應(yīng)該為1,這樣可以獲得更平緩的softmax。當(dāng)維度很大時(shí),點(diǎn)積結(jié)果會(huì)很大,會(huì)導(dǎo)致softmax的梯度很小。為了減輕這個(gè)影響,對點(diǎn)積進(jìn)行縮放。
2.3? 計(jì)算attention的時(shí)候?yàn)楹芜x擇點(diǎn)乘而不是加法?兩者計(jì)算復(fù)雜度和效果上有什么區(qū)別?
K和Q的點(diǎn)乘是為了得到一個(gè)attention score 矩陣,用來對V進(jìn)行提純。K和Q使用了不同的W_k, W_Q來計(jì)算,可以理解為是在不同空間上的投影。正因?yàn)?有了這種不同空間的投影,增加了表達(dá)能力,這樣計(jì)算得到的attention score矩陣的泛化能力更高。
?
3. Transformer為何使用多頭注意力機(jī)制?
In practice, the multi-headed attention are done with transposes and reshapes rather than actual separate tensors. —— 來自 google BERT 源代碼注釋
Tansformer 中的 Multi-Head Attention,簡單來說就是多個(gè) Self-Attention 的組合,它的作用類似于 CNN 中的多核。多頭的實(shí)現(xiàn)不是循環(huán)的計(jì)算每個(gè)頭,而是通過 transposes and reshapes,用矩陣乘法來完成的。多頭可以使參數(shù)矩陣形成多個(gè)子空間,矩陣整體的size不變,只是改變了每個(gè)head對應(yīng)的維度大小,這樣做使矩陣對多方面信息進(jìn)行學(xué)習(xí),但是計(jì)算量和單個(gè)head差不多。
將原有的高維空間轉(zhuǎn)化為多個(gè)低維空間并再最后進(jìn)行拼接,形成同樣維度的輸出,借此豐富特性信息,降低了計(jì)算量
?
Transformer/BERT 中把 d ,也就是 hidden_size/embedding_size 這個(gè)維度做了 reshape 拆分,可以去看 Google 的 TF 源碼或者上面的 pytorch 源碼
?hidden_size (d) = num_attention_heads (m) * attention_head_size (a),也即 d=m*a。
并將 num_attention_heads 維度 transpose 到前面,使得 Q 和 K 的維度都是 (m,n,a),這里不考慮 batch 維度。
這樣點(diǎn)積可以看作大小為 (m,n,a) 和 (m,a,n) 的兩個(gè)張量相乘,得到一個(gè) (m,n,n) 的矩陣,其實(shí)就相當(dāng)于 m 個(gè)頭,時(shí)間復(fù)雜度是:
=
張量乘法時(shí)間復(fù)雜度分析參見:矩陣、張量乘法的時(shí)間復(fù)雜度分析 [1]。
因此 Multi-Head Attention 時(shí)間復(fù)雜度就是 ,而實(shí)際上,張量乘法可以加速,因此實(shí)際復(fù)雜度會(huì)更低一些。
?
4. 不考慮多頭的原因,self-attention中詞向量不乘QKV參數(shù)矩陣,會(huì)怎么樣?
對于 Attention 機(jī)制,都可以用統(tǒng)一的 query/key/value 模式去解釋,而對于? self-attention,一般會(huì)說它的 q=k=v,這里的相等實(shí)際上是指它們來自同一個(gè)基礎(chǔ)向量,而在實(shí)際計(jì)算時(shí),它們是不一樣的,因?yàn)檫@三者都是乘了 QKV 參數(shù)矩陣的。那如果不乘,每個(gè)詞對應(yīng)的 q,k,v 就是完全一樣的。
在 self-attention 中,sequence 中的每個(gè)詞都會(huì)和 sequence 中的每個(gè)詞做點(diǎn)積去計(jì)算相似度,也包括這個(gè)詞本身。
在相同量級(jí)的情況下,qi 與 ki 點(diǎn)積的值會(huì)是最大的(可以從“兩數(shù)和相同的情況下,兩數(shù)相等對應(yīng)的積最大”類比過來)。
那在 softmax 后的加權(quán)平均中,該詞本身所占的比重將會(huì)是最大的,使得其他詞的比重很少,無法有效利用上下文信息來增強(qiáng)當(dāng)前詞的語義表示。
而乘以 QKV 參數(shù)矩陣,會(huì)使得每個(gè)詞的 q,k,v 都不一樣,能很大程度上減輕上述的影響。
當(dāng)然,QKV 參數(shù)矩陣也使得多頭,類似于 CNN 中的多核,去捕捉更豐富的特征/信息成為可能。
?
5. 在常規(guī) attention 中,一般有 K=V,那 self-attention 可以嘛?
Self-Attention 實(shí)際只是 attention 中的一種特殊情況,因此 k=v 是沒有問題的,也即 K,V 參數(shù)矩陣相同。
擴(kuò)展到 Multi-Head Attention 中,乘以 Q、K 參數(shù)矩陣之后,其實(shí)就已經(jīng)保證了多頭之間的差異性了,在 q 和 k 點(diǎn)積 +softmax 得到相似度之后,從常規(guī) attention 的角度,覺得再去乘以和 k 相等的 v 會(huì)更合理一些。
在 Transformer / BERT 中,完全獨(dú)立的 QKV 參數(shù)矩陣,可以擴(kuò)大模型的容量和表達(dá)能力。
但采用 Q,K=V 這樣的參數(shù)模式,我認(rèn)為也是沒有問題的,也能減少模型的參數(shù),又不影響多頭的實(shí)現(xiàn)。
當(dāng)然,上述想法并沒有做過實(shí)驗(yàn),為個(gè)人觀點(diǎn),僅供參考。
?
6. Q和K使用不同的權(quán)重矩陣生成,為何不能使用同一個(gè)值進(jìn)行自身的點(diǎn)乘?
答:請求和鍵值初始為不同的權(quán)重是為了解決可能輸入句長與輸出句長不一致的問題。并且假如QK維度一致,如果不用Q,直接拿K和K點(diǎn)乘的話,你會(huì)發(fā)現(xiàn)attention score 矩陣是一個(gè)對稱矩陣。因?yàn)槭峭瑯右粋€(gè)矩陣,都投影到了同樣一個(gè)空間,所以泛化能力很差。
?
?
?
二 源碼
在整個(gè) Transformer / BERT 的代碼中,(Multi-Head Scaled Dot-Product) Self-Attention 的部分是相對最復(fù)雜的,也是 Transformer / BERT 的精髓所在,這里給出 Pytorch 版本的實(shí)現(xiàn) [2],并對重要的代碼加上了注釋和維度說明。
話不多說,都在代碼里,它主要有三個(gè)部分:
初始化:包括有幾個(gè)頭,每個(gè)頭的大小,并初始化 Q K V 三個(gè)參數(shù)矩陣。
class SelfAttention(nn.Module):def __init__(self, config):super(SelfAttention, self).__init__()if config.hidden_size % config.num_attention_heads != 0:raise ValueError("The hidden size (%d) is not a multiple of the number of attention ""heads (%d)" % (config.hidden_size, config.num_attention_heads))# 在Transformer/BERT中,這里的 all_head_size 就等于 config.hidden_size# 應(yīng)該是一種簡化,為了從embedding到最后輸出維度都保持一致# 這樣使得多個(gè)attention頭合起來維度還是config.hidden_size# 而 attention_head_size 就是每個(gè)attention頭的維度,要保證可以整除self.num_attention_heads = config.num_attention_headsself.attention_head_size = int(config.hidden_size / config.num_attention_heads)self.all_head_size = self.num_attention_heads * self.attention_head_size# 三個(gè)參數(shù)矩陣self.query = nn.Linear(config.hidden_size, self.all_head_size)self.key = nn.Linear(config.hidden_size, self.all_head_size)self.value = nn.Linear(config.hidden_size, self.all_head_size)self.dropout = nn.Dropout(config.attention_probs_dropout_prob)transposes and reshapes:這個(gè)函數(shù)主要是把維度大小為 [batch_size * seq_length * hidden_size] 的 q,k,v 向量變換成 [batch_size * num_attention_heads * seq_length * attention_head_size],便于后面做 Multi-Head Attention。
def transpose_for_scores(self, x):"""shape of x: batch_size * seq_length * hidden_size這個(gè)操作是把hidden_size分解為 self.num_attention_heads * self.attention_head_size然后再交換 seq_length 維度 和 num_attention_heads 維度為什么要做這一步:因?yàn)閍ttention是要對query中的每個(gè)字和key中的每個(gè)字做點(diǎn)積,即是在 seq_length 維度上query和key的點(diǎn)積是 [seq_length * attention_head_size] * [attention_head_size * seq_length]=[seq_length * seq_length]"""# 這里是一個(gè)維度拼接:(1,2)+(3,4) -> (1, 2, 3, 4)new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)x = x.view(*new_x_shape)return x.permute(0, 2, 1, 3)前向計(jì)算: 乘以 QKV 參數(shù)矩陣 —> transposes and reshapes —> 做 scaled —> 加 attention mask —> Softmax —> 加權(quán)平均 —> 維度恢復(fù)。
def forward(self, hidden_states, attention_mask):# shape of hidden_states and mixed_*_layer: batch_size * seq_length * hidden_sizemixed_query_layer = self.query(hidden_states)mixed_key_layer = self.key(hidden_states)mixed_value_layer = self.value(hidden_states)# shape of *_layer: batch_size * num_attention_heads * seq_length * attention_head_sizequery_layer = self.transpose_for_scores(mixed_query_layer)key_layer = self.transpose_for_scores(mixed_key_layer)value_layer = self.transpose_for_scores(mixed_value_layer)# Take the dot product between "query" and "key" to get the raw attention scores.# shape of attention_scores: batch_size * num_attention_heads * seq_length * seq_lengthattention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))# 這里就是做 Scaled,將方差統(tǒng)一到1,避免維度的影響attention_scores /= math.sqrt(self.attention_head_size)# shape of attention_mask: batch_size * 1 * 1 * seq_length. 它可以自動(dòng)廣播到和attention_scores一樣的維度# 我們初始輸入的attention_mask是:batch_size * seq_length,做了兩次unsqueeze之后得到當(dāng)前的attention_maskattention_scores = attention_scores + attention_mask# Normalize the attention scores to probabilities. Softmax 不改變維度# shape of attention_scores: batch_size * num_attention_heads * seq_length * seq_lengthattention_probs = nn.Softmax(dim=-1)(attention_scores)attention_probs = self.dropout(attention_probs)# shape of value_layer: batch_size * num_attention_heads * seq_length * attention_head_size# shape of first context_layer: batch_size * num_attention_heads * seq_length * attention_head_size# shape of second context_layer: batch_size * seq_length * num_attention_heads * attention_head_size# context_layer 維度恢復(fù)到:batch_size * seq_length * hidden_sizecontext_layer = torch.matmul(attention_probs, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape)return context_layerAttention is all you need ! 希望這篇文章能讓你對 Self-Attention 有更深的理解。
參考文獻(xiàn)
[1] https://liwt31.github.io/2018/10/12/mul-complexity/
[2] https://github.com/hichenway/CodeShare/tree/master/bert_pytorch_source_code
https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/109212878
21個(gè)Transformer面試題的簡單回答
?
?
?
總結(jié)
以上是生活随笔為你收集整理的[深度学习] 自然语言处理 --- Self-Attention(三) 知识点与源码解析的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 再也不怕手一抖跳广告了!规范App乱跳转
- 下一篇: [链表] --- 反转链表(leetco