何使用BERT模型实现中文的文本分类
生活随笔
收集整理的這篇文章主要介紹了
何使用BERT模型实现中文的文本分类
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
原文網址:https://blog.csdn.net/Real_Brilliant/article/details/84880528
如何使用BERT模型實現中文的文本分類
- 前言
- Pytorch
- readme
- 參數表
- 算法流程
- 1. 概述
- 2. 讀取數據
- 3. 特征轉換
- 4. 模型訓練
- 5. 模型測試
- 6. 測試結果
- 7. 總結
前言
Pytorch
readme
- 請先安裝pytorch的BERT代碼,代碼源見前言(2)pip install pytorch-pretrained-bert
- 1
參數表
| 輸入數據目錄 | 加載的bert模型,對于中文文本請輸入’bert-base-chinese’ | 輸入數據預處理模塊,最好根據應用場景自定義 |
| model_save_pth | max_seq_length* | train_batch_size |
| 模型參數保存地址 | 最大文本長度 | batch大小 |
| learning_rate | num_train_epochs | |
| Adam初始學習步長 | 最大epoch數 |
* max_seq_length = 所設定的文本長度 + 2 ,BERT會給每個輸入文本開頭和結尾分別加上[CLS]和[SEP]標識符,因此會占用2個字符空間,其作用會在后續進行詳細說明。
算法流程
1. 概述
訓練階段利用驗證集調整參數選取驗證集上得分最高的模型測試階段加載預訓練模型讀取數據特征轉換模型訓練保存最佳模型參數加載訓練階段最佳模型讀取數據特征轉換輸入模型并進行測試2. 讀取數據
- 對應于參數表中的task_name,是用于數據讀取的模塊
- 可以根據自身需要自定義新的數據讀取模塊
- 以輸入數據為json文件時為例,數據讀取模塊包含兩個部分:
- 基類DataProcessor:class DataProcessor(object): def get_train_examples(self, data_dir):raise NotImplementedError() def get_dev_examples(self, data_dir):raise NotImplementedError()def get_test_examples(self, data_dir):raise NotImplementedError()def get_labels(self):raise NotImplementedError()@classmethod def _read_json(cls, input_file, quotechar=None):"""Reads a tab separated value file."""dicts = []with codecs.open(input_file, 'r', 'utf-8') as infs:for inf in infs:inf = inf.strip()dicts.append(json.loads(inf))return dicts
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 用于數據讀取的模塊MyPro:class MyPro(DataProcessor):def get_train_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), 'train') def get_dev_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "val.json")), 'dev')def get_test_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), 'test')def get_labels(self):return [0, 1]def _create_examples(self, dicts, set_type):examples = []for (i, infor) in enumerate(dicts):guid = "%s-%s" % (set_type, i)text_a = infor['question']label = infor['label']examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examples
-
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 需要注意的幾點是:
- data_dir目錄下應包含名為train、val、test的三個文件,根據文件格式不同需要對讀取方式進行修改
- get_labels()返回的是所有可能的類別label_list,比如['數學', '英語', '語文']、[1, 2, 3]…
- 模塊最終返回一個名為examples的列表,每個列表元素中包含序號、中文文本、類別三個元素
3. 特征轉換
- convert_examples_to_features是用于將examples轉換為特征,也即features的函數。
- features包含4個數據:
- input_ids:分詞后每個詞語在vocabulary中的id,補全符號對應的id為0,[CLS]和[SEP]的id分別為101和102。應注意的是,在中文BERT模型中,中文分詞是基于字而非詞的分詞。
- input_mask:真實字符/補全字符標識符,真實文本的每個字對應1,補全符號對應0,[CLS]和[SEP]也為1。
- segment_ids:句子A和句子B分隔符,句子A對應的全為0,句子B對應的全為1。但是在多數文本分類情況下并不會用到句子B,所以基本不用管。
- label_id :將label_list中的元素利用字典轉換為index標識,即label_map = {}
for (i, label) in enumerate(label_list):label_map[label] = i
- 1
- 2
- 3
- features中一個元素的例子是:
- 轉換完成后的特征值就可以作為輸入,用于模型的訓練和測試
4. 模型訓練
- 完成讀取數據、特征轉換之后,將特征送入模型進行訓練
- 訓練算法為BERT專用的Adam算法
- 訓練集、測試集、驗證集比例為6:2:2
- 每一個epoch后會在驗證集上進行驗證,并給出相應的f1值,如果f1值大于此前最高分則保存模型參數,否則flags加1。如果flags大于6,也即連續6個epoch模型的性能都沒有繼續優化,停止訓練過程。f1 = val(model, processor, args, label_list, tokenizer, device)
if f1 > best_score:best_score = f1print('*f1 score = {}'.format(f1))flags = 0checkpoint = {'state_dict': model.state_dict()}torch.save(checkpoint, args.model_save_pth)
else:print('f1 score = {}'.format(f1))flags += 1if flags >=6:break
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 如果epoch數超過先前設定的num_train_epochs,同樣會停止迭代。
5. 模型測試
- 先加載模型
- 送數據,取得分,完事
- 暫時還沒加打印測試結果到文件的功能,后續會加上
6. 測試結果
val_F1test_F1 Fast text 0.7218 0.7094 Text rnn + bigru 0.7383 0.7194 Text cnn 0.7292 0.7088 bigru + attention 0.7335 0.7146 RCNN 0.7355 0.7213 BERT 0.7938 0.787 - 基于真實數據做的文本分類,用過不少模型,BERT的性能可以說是獨一檔
- BERT確實牛逼,不過一部分原因也是模型量級就不一樣
7. 總結
- 使用代碼的時候按照參數表修改下參數,把數據按照命名規范放data_dir目錄下一般就沒啥問題了
- 最多還要修改下讀取數據的代碼(如果數據不是.json格式的),就可以跑通了
- 最后可以根據個人需要,對模型訓練邏輯、epoch數、學習步長等地方做進一步修改
- 代碼地址已經放在前言(3)里了
總結
以上是生活随笔為你收集整理的何使用BERT模型实现中文的文本分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: jupyter notebook切换到其
- 下一篇: 使用docker部署flask项目