多标签文本分类 [ALBERT](附代码)
目前,中文多標簽文本分類的方法主要有3種,今天我們來詳細介紹及實踐其中的一種,算法框架使用的是ALBERT。
一、介紹
假設個人愛好的集合一共有6個元素:運動、旅游、讀書、工作、睡覺、美食。
一般情況下,一個人的愛好有這其中的一個或者多個,那么這就是一個典型的多標簽分類任務。
二、框架及算法
1、Placeholder
首先,我們需要設置一些占位符(Placeholder),占位符的作用是在訓練和推理的過程中feed模型需要的數據。我們這里需要4個占位符,分別是input_ids、input_masks、segment_ids和label_ids。前面3個是我們了解的BERT輸入特征,最后面一個是標簽的id。
2、ALBERT token-vectors
從圖中紅色的框內可以看出,ALBERT需要傳入3個參數(input_ids、input_masks、segment_ids),就可以得到我們所需要的一個2維向量output_layer:(batch_size, hidden_size)。
有人在這里就會好奇,為什么ALBERT輸出的是一個2維向量,而不是一個3維向量(batch_size, sequence_length, hidden_size)呢?那我們來看一下源碼,弄清楚self.model.get_pooled_output()的來歷。
其中self.sequence_ouput其實就是我們所說的那個3維向量(batch_size, sequence_length, hidden_size)。我們對這個3維向量做了一個"pooler"的操作,從而使之變成了一個2維的向量,這個操作是上面藍色方框內的內容。
藍色方框內的解釋為:”We "pool" the model by simply taking the hidden state corresponding to the first token. We assume that this has been pre-trained“。這句話怎么理解呢?意思是將整個句子的特征信息投射到句子第一個字的隱藏狀態向量上面。并且,認為這個它是通過預訓練得到的。
3、Full connection
最后,就是一個全連接層了。很簡單,全連接層的作用是將output_layer投射到我們的標簽上面。
4、上面3點在多標簽文本分類和文本分類并沒有區別。那么區別在哪里呢?
主要有以下3個區別:
- 交叉熵
- 輸出概率
- 輸出標簽
4.1、交叉熵
在文本分類中,我們使用的交叉熵為tf.nn.softmax_cross_entropy_with_logits;在多標簽文本分類中,我們使用的交叉熵則為tf.nn.sigmoid_cross_entropy_with_logits。這樣做的原因:
- tf.nn.sigmoid_cross_entropy_with_logits測量離散分類任務中的概率誤差,其中每個類是獨立的而不是互斥的。這適用于多標簽分類問題。
- tf.nn.softmax_cross_entropy_with_logits測量離散分類任務中的概率誤差,其中類之間是互斥的(每個條目恰好在一個類中)。這適用多分類問題。
4.2、輸出概率
在文本分類中,輸出概率為tf.nn.softmax(logits, axis=-1);在多標簽文本分類中,輸出概率為tf.nn.sigmoid(logits)。這樣做的原因:
- 在簡單的二進制分類中,sigmoid和softmax沒有太大的區別。
- 在多分類的情況下,sigmoid允許處理非獨占標簽(也稱為多標簽),而softmax處理獨占類。
4.3、輸出標簽
在文本分類(多元文本分類)中,label_ids的維度為(batch_size);在多標簽文本分類中,它的維度為(batch_size,num_labels)。這樣做的原因:
- 在多元文本分類中,最后得到的標簽只有一個,并且必須是其中的一個。
- 在多標簽文本分類中,最后得到的標簽可能有1個或者多個。
一般的多元分類是通過tf.argmax(logits)實現,返回的是最大的那個數值所在的label_id,因為logits對應每一個label_id都有一個概率。但是,在多標簽分類中,我們需要得到的是每一個標簽是否可以作為輸出標簽,所以每一個標簽可以作為輸出標簽的概率都會量化為一個0到1之間的值。所以當某一個標簽對應輸出概率小于0.5時,我們認為它不能作為當前句子的輸出標簽;反之,如果大于等于0.5,那么它代表了當前句子的輸出標簽之一。
三、實踐及框架圖
1、框架圖
2、模型Loss和Accuracy變化曲線圖
我們可以發現,這里的Loss和Accuracy的變化趨勢和多元文本分類有較大的區別。在多標簽文本分類的訓練過程中,Loss的下降幅度非???#xff0c;但并不代表模型的收斂快。在多元文本分類的訓練過程中,Loss一般在0.1-0.2之間的時候,模型基本上已經收斂。但是,在多標簽文本分類(當前框架下)的過程中,當Loss到達0.1-0.2時,模型收斂還需較多的steps。根據訓練經驗,在多標簽文本分類(這個框架下)的情況下,Loss往往要達到0.0001-0.001之間,模型才收斂。
四、代碼鏈接
hellonlp/classifier_multi_label?github.com
其它相關文章:HelloNLP:多標簽文本分類介紹,以及對比訓練?zhuanlan.zhihu.com
編輯于 03-02
總結
以上是生活随笔為你收集整理的多标签文本分类 [ALBERT](附代码)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【功能升级】达摩盘3.0全新标签介绍前言
- 下一篇: CML 2020 | 显式引入对分类标签