Bert系列(三)——源码解读之Pre-train
https://www.jianshu.com/p/22e462f01d8c
pre-train是遷移學習的基礎,雖然Google已經發布了各種預訓練好的模型,而且因為資源消耗巨大,自己再預訓練也不現實(在Google Cloud TPU v2 上訓練BERT-Base要花費近500刀,耗時達到兩周。在GPU上可想而知只會更貴),但是學習bert的預訓練方法可以為我們弄懂整個bert的運行流程提供莫大的幫助。預訓練涉及到的模塊有點多,所以這也將會是一篇長文,在能簡略的地方我盡量簡略,還是那句話,我的文章只能是起到一個導讀的作用,如果想摸清里面的各種細節還是要自己把源碼過一遍的。
pre-train涉及到的模塊分為以下三個,我將為大家一一介紹:
1.tokenization.py
2.create_pretraining_data.py
3.run_pretraining.py
其中tokenization是對原始句子內容的解析,分為BasicTokenizer和WordpieceTokenizer兩個,不只是在預訓練中,在fine-tune和推斷過程同樣要用到它;create_pretraining_data顧名思義就是將原始語料轉換成適合模型預訓練的輸入數據;run_pretraining就是預訓練的執行代碼了。
一、tokenization.py
1、BasicTokenizer
class BasicTokenizer(object):"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""def __init__(self, do_lower_case=True):self.do_lower_case = do_lower_casedef tokenize(self, text):"""Tokenizes a piece of text."""text = convert_to_unicode(text)text = self._clean_text(text)text = self._tokenize_chinese_chars(text)orig_tokens = whitespace_tokenize(text)split_tokens = []for token in orig_tokens:if self.do_lower_case:token = token.lower()token = self._run_strip_accents(token)split_tokens.extend(self._run_split_on_punc(token))output_tokens = whitespace_tokenize(" ".join(split_tokens))return output_tokensdef _run_strip_accents(self, text):"""Strips accents from a piece of text."""text = unicodedata.normalize("NFD", text)output = []for char in text:cat = unicodedata.category(char)if cat == "Mn":continueoutput.append(char)return "".join(output)def _run_split_on_punc(self, text):"""Splits punctuation on a piece of text."""chars = list(text)i = 0start_new_word = Trueoutput = []while i < len(chars):char = chars[i]if _is_punctuation(char):output.append([char])start_new_word = Trueelse:if start_new_word:output.append([])start_new_word = Falseoutput[-1].append(char)i += 1return ["".join(x) for x in output]def _tokenize_chinese_chars(self, text):"""Adds whitespace around any CJK character."""output = []for char in text:cp = ord(char)if self._is_chinese_char(cp):output.append(" ")output.append(char)output.append(" ")else:output.append(char)return "".join(output)def _is_chinese_char(self, cp):"""Checks whether CP is the codepoint of a CJK character."""if ((cp >= 0x4E00 and cp <= 0x9FFF) or #(cp >= 0x3400 and cp <= 0x4DBF) or #(cp >= 0x20000 and cp <= 0x2A6DF) or #(cp >= 0x2A700 and cp <= 0x2B73F) or #(cp >= 0x2B740 and cp <= 0x2B81F) or #(cp >= 0x2B820 and cp <= 0x2CEAF) or(cp >= 0xF900 and cp <= 0xFAFF) or #(cp >= 0x2F800 and cp <= 0x2FA1F)): #return Truereturn Falsedef _clean_text(self, text):"""Performs invalid character removal and whitespace cleanup on text."""output = []for char in text:cp = ord(char)if cp == 0 or cp == 0xfffd or _is_control(char):continueif _is_whitespace(char):output.append(" ")else:output.append(char)return "".join(output)
BasicTokenizer的主要是進行unicode轉換、標點符號分割、小寫轉換、中文字符分割、去除重音符號等操作,最后返回的是關于詞的數組(中文是字的數組)
2、WordpieceTokenizer
class WordpieceTokenizer(object):"""Runs WordPiece tokenziation."""def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):self.vocab = vocabself.unk_token = unk_tokenself.max_input_chars_per_word = max_input_chars_per_worddef tokenize(self, text):text = convert_to_unicode(text)output_tokens = []for token in whitespace_tokenize(text):chars = list(token)if len(chars) > self.max_input_chars_per_word:output_tokens.append(self.unk_token)continueis_bad = Falsestart = 0sub_tokens = []while start < len(chars):end = len(chars)cur_substr = Nonewhile start < end:substr = "".join(chars[start:end])if start > 0:substr = "##" + substrif substr in self.vocab:cur_substr = substrbreakend -= 1if cur_substr is None:is_bad = Truebreaksub_tokens.append(cur_substr)start = endif is_bad:output_tokens.append(self.unk_token)else:output_tokens.extend(sub_tokens)return output_tokens
WordpieceTokenizer的目的是將合成詞分解成類似詞根一樣的詞片。例如將"unwanted"分解成["un", "##want", "##ed"]這么做的目的是防止因為詞的過于生僻沒有被收錄進詞典最后只能以[UNK]代替的局面,因為英語當中這樣的合成詞非常多,詞典不可能全部收錄。
3、FullTokenizer
class FullTokenizer(object):"""Runs end-to-end tokenziation."""def __init__(self, vocab_file, do_lower_case=True):self.vocab = load_vocab(vocab_file)self.inv_vocab = {v: k for k, v in self.vocab.items()}self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)def tokenize(self, text):split_tokens = []for token in self.basic_tokenizer.tokenize(text):for sub_token in self.wordpiece_tokenizer.tokenize(token):split_tokens.append(sub_token)return split_tokensdef convert_tokens_to_ids(self, tokens):return convert_by_vocab(self.vocab, tokens)def convert_ids_to_tokens(self, ids):return convert_by_vocab(self.inv_vocab, ids)
FullTokenizer的作用就很顯而易見了,對一個文本段進行以上兩種解析,最后返回詞(字)的數組,同時還提供token到id的索引以及id到token的索引。這里的token可以理解為文本段處理過后的最小單元。
二、create_pretraining_data.py
1、配置
flags.DEFINE_string("input_file", None,"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string("output_file", None,"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", True,"Whether to lower case the input text. Should be True for uncased ""models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer("dupe_factor", 10,"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float("short_seq_prob", 0.1,"Probability of creating sequences which are shorter than the ""maximum length.")
配置input_file、output_file分別代表輸入的源語料文件和處理過的預料文件地址;
do_lower_case:是否全部轉為小寫字母,是否轉換成小寫字母的意義在Bert系列(一)——demo運行里面已經說過了。
dupe_factor:默認重復10次,目的是可以生成不同情況的masks;
short_seq_prob:構造長度小于指定"max_seq_length"的樣本比例。因為在fine-tune過程里面輸入的target_seq_length是可變的(小于等于max_seq_length),那么為了防止過擬合也需要在pre-train的過程當中構造一些短的樣本。
2、main入口
def main(_):tf.logging.set_verbosity(tf.logging.INFO)tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)input_files = []for input_pattern in FLAGS.input_file.split(","):input_files.extend(tf.gfile.Glob(input_pattern))tf.logging.info("*** Reading from input files ***")for input_file in input_files:tf.logging.info(" %s", input_file)rng = random.Random(FLAGS.random_seed)instances = create_training_instances(input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,rng)output_files = FLAGS.output_file.split(",")tf.logging.info("*** Writing to output files ***")for output_file in output_files:tf.logging.info(" %s", output_file)write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,FLAGS.max_predictions_per_seq, output_files)
從入口開始看,步驟很簡單:1)構造tokenizer ;2)構造instances ;3)保存instances
3、構造instances
def create_training_instances(input_files, tokenizer, max_seq_length,dupe_factor, short_seq_prob, masked_lm_prob,max_predictions_per_seq, rng):"""Create `TrainingInstance`s from raw text."""all_documents = [[]]for input_file in input_files:with tf.gfile.GFile(input_file, "r") as reader:while True:line = tokenization.convert_to_unicode(reader.readline())if not line:breakline = line.strip()# Empty lines are used as document delimitersif not line:all_documents.append([])tokens = tokenizer.tokenize(line)if tokens:all_documents[-1].append(tokens)# Remove empty documentsall_documents = [x for x in all_documents if x]rng.shuffle(all_documents)vocab_words = list(tokenizer.vocab.keys())instances = []for _ in range(dupe_factor):for document_index in range(len(all_documents)):instances.extend(create_instances_from_document(all_documents, document_index, max_seq_length, short_seq_prob,masked_lm_prob, max_predictions_per_seq, vocab_words, rng))rng.shuffle(instances)return instances
這一步是閱讀數據,數據的輸入文本可以是一個文件也可以是用逗號分割的若干文件;
文件里用換行來表示句子的邊界,即一句一行,同理段落之間用空一行來表示段落的邊界,一個段落表示成一個document;具體的構造方法在create_instances_from_document函數里面。
def create_instances_from_document(all_documents, document_index, max_seq_length, short_seq_prob,masked_lm_prob, max_predictions_per_seq, vocab_words, rng):"""Creates `TrainingInstance`s for a single document."""document = all_documents[document_index]# Account for [CLS], [SEP], [SEP]max_num_tokens = max_seq_length - 3target_seq_length = max_num_tokensif rng.random() < short_seq_prob:target_seq_length = rng.randint(2, max_num_tokens)instances = []current_chunk = []current_length = 0i = 0while i < len(document):segment = document[i]current_chunk.append(segment)current_length += len(segment)if i == len(document) - 1 or current_length >= target_seq_length:if current_chunk:# `a_end` is how many segments from `current_chunk` go into the `A`# (first) sentence.a_end = 1if len(current_chunk) >= 2:a_end = rng.randint(1, len(current_chunk) - 1)tokens_a = []for j in range(a_end):tokens_a.extend(current_chunk[j])tokens_b = []# Random nextis_random_next = Falseif len(current_chunk) == 1 or rng.random() < 0.5:is_random_next = Truetarget_b_length = target_seq_length - len(tokens_a)for _ in range(10):random_document_index = rng.randint(0, len(all_documents) - 1)if random_document_index != document_index:breakrandom_document = all_documents[random_document_index]random_start = rng.randint(0, len(random_document) - 1)for j in range(random_start, len(random_document)):tokens_b.extend(random_document[j])if len(tokens_b) >= target_b_length:breaknum_unused_segments = len(current_chunk) - a_endi -= num_unused_segments# Actual nextelse:is_random_next = Falsefor j in range(a_end, len(current_chunk)):tokens_b.extend(current_chunk[j])truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)assert len(tokens_a) >= 1assert len(tokens_b) >= 1tokens = []segment_ids = []tokens.append("[CLS]")segment_ids.append(0)for token in tokens_a:tokens.append(token)segment_ids.append(0)tokens.append("[SEP]")segment_ids.append(0)for token in tokens_b:tokens.append(token)segment_ids.append(1)tokens.append("[SEP]")segment_ids.append(1)(tokens, masked_lm_positions,masked_lm_labels) = create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)instance = TrainingInstance(tokens=tokens,segment_ids=segment_ids,is_random_next=is_random_next,masked_lm_positions=masked_lm_positions,masked_lm_labels=masked_lm_labels)instances.append(instance)current_chunk = []current_length = 0i += 1return instances
這一段算是整個模塊的核心了。
instance = TrainingInstance(tokens=tokens,segment_ids=segment_ids,is_random_next=is_random_next,masked_lm_positions=masked_lm_positions,masked_lm_labels=masked_lm_labels)
1)一個instance 包含一個tokens,實際上就是輸入的詞序列;該序列表現形式為:
[CLS] A [SEP] B [SEP]
A=[token_0, token_1, ...,token_i]
B=[token_i+1, token_i+2, ...,token_n-1]
其中:
2<= n < max_seq_length - 3 (in short_seq_prob)
n=max_seq_length - 3 (in 1-short_seq_prob)
token 最后表現形式如下圖所示:
? tokens示意圖segment_ids 指的形式為[0,0,0...1,1,111] 0的個數為i+1個,1的個數為max_seq_length - (i+1)
對應到模型輸入就是token_type
is_random_next:其實就是上圖的Label,0.5的概率為True(和當只有一個segment的時候),如果為True則B和A不屬于同一document。剩下的情況為False,則B為A同一document的后續句子。
masked_lm_positions:序列里被[MASK]的位置;
masked_lm_labels:序列里被[MASK]的token
2)在create_masked_lm_predictions函數里,一個序列在指定MASK數量之后,有80%被真正MASK,10%還是保留原來token,10%被隨機替換成其他token。
4、保存instance
def write_instance_to_example_files(instances, tokenizer, max_seq_length,max_predictions_per_seq, output_files):"""Create TF example files from `TrainingInstance`s."""writers = []for output_file in output_files:writers.append(tf.python_io.TFRecordWriter(output_file))writer_index = 0total_written = 0for (inst_index, instance) in enumerate(instances):input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)input_mask = [1] * len(input_ids)segment_ids = list(instance.segment_ids)assert len(input_ids) <= max_seq_lengthwhile len(input_ids) < max_seq_length:input_ids.append(0)input_mask.append(0)segment_ids.append(0)assert len(input_ids) == max_seq_lengthassert len(input_mask) == max_seq_lengthassert len(segment_ids) == max_seq_lengthmasked_lm_positions = list(instance.masked_lm_positions)masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)masked_lm_weights = [1.0] * len(masked_lm_ids)while len(masked_lm_positions) < max_predictions_per_seq:masked_lm_positions.append(0)masked_lm_ids.append(0)masked_lm_weights.append(0.0)next_sentence_label = 1 if instance.is_random_next else 0features = collections.OrderedDict()features["input_ids"] = create_int_feature(input_ids)features["input_mask"] = create_int_feature(input_mask)features["segment_ids"] = create_int_feature(segment_ids)features["masked_lm_positions"] = create_int_feature(masked_lm_positions)features["masked_lm_ids"] = create_int_feature(masked_lm_ids)features["masked_lm_weights"] = create_float_feature(masked_lm_weights)features["next_sentence_labels"] = create_int_feature([next_sentence_label])tf_example = tf.train.Example(features=tf.train.Features(feature=features))writers[writer_index].write(tf_example.SerializeToString())writer_index = (writer_index + 1) % len(writers)total_written += 1if inst_index < 20:tf.logging.info("*** Example ***")tf.logging.info("tokens: %s" % " ".join([tokenization.printable_text(x) for x in instance.tokens]))for feature_name in features.keys():feature = features[feature_name]values = []if feature.int64_list.value:values = feature.int64_list.valueelif feature.float_list.value:values = feature.float_list.valuetf.logging.info("%s: %s" % (feature_name, " ".join([str(x) for x in values])))for writer in writers:writer.close()tf.logging.info("Wrote %d total instances", total_written)
instance保存沒什么好說的,只有兩點:
while len(input_ids) < max_seq_length:input_ids.append(0)input_mask.append(0)segment_ids.append(0)
1)之前不是有short_seq_prob的概率導致樣本的長度小于max_predictions_per_seq嗎,這里把這些樣本補齊,padding為0,同樣的還有input_mask和segment_ids;
2) 把instance的is_random_next轉化成變量next_sentence_label保存。
為了驗證這個數據模塊對中文輸入輸出的支持,我做了個測試:
python3 create_pretraining_data.py --input_file=/tmp/zh_test.txt --output_file=/tmp/output.txt --vocab_file=$BERT_ZH_DIR/vocab.txt
zh_test.txt是我臉滾鍵盤隨意輸入的一些漢字,共有兩段,每段兩句話:
酒店附近開房的艱苦的飛機飛抵發窘惹風波,覺得覅奇偶均衡能否v不。
極度瘋狂減肥的人能否打開v高科技就而后就覅哦冏結構i惡如桂萼黑人牙膏覅u我也【發票未開u俄日附件二我就佛i額外階級感v,我為何軍方的我i和服i好熱哦iu均為輻9為uui和覅文化覅哦佛為進度覅u蠱蛾i巨乳古人規格i兼顧如果我是破看到v個ui就火熱i今年的付款了幾個vi哦素問。就覺發給金佛i為借口破碎的夢
i覺得覅u而非各位i風格較為哦個粉色哦i多發幾個v二哥i文件哦i怪獸決斗盤可加熱管覅u個人文集狗哥
vocab.txt是下載的bert中文預訓練模型里的詞典
最后的部分輸出如下所示:
INFO:tensorflow:*** Example ***
INFO:tensorflow:tokens: [CLS] i 覺 得 [UNK] u [MASK] 非 [MASK] 位 i 風 格 較 ##by 哦 個 驅 色 哦 i 多 發 [MASK] 個 v 二 哥 i 文 件 哦 i 怪 [MASK] 決 斗 盤 可 加 熱 管 [MASK] u [MASK] [MASK] 文 集 狗 哥 [SEP] [MASK] [UNK] 奇 偶 均 衡 能 否 v 不 。 極 [MASK] 瘋 狂 減 肥 的 人 能 否 打 開 v 高 科 技 就 而 [MASK] 就 [UNK] 哦 冏 結 構 i 惡 如 桂 萼 黑 人 牙 膏 [UNK] u 我 也 【 發 票 未 開 [MASK] 俄 日 [MASK] 件 二 我 就 佛 i 額 [MASK] 階 [MASK] 感 v [MASK] 我 為 [MASK] 軍 方 [SEP]
INFO:tensorflow:input_ids: 101 151 6230 2533 100 163 103 7478 103 855 151 7599 3419 6772 8684 1521 702 7705 5682 1521 151 1914 1355 103 702 164 753 1520 151 3152 816 1521 151 2597 103 1104 3159 4669 1377 1217 4178 5052 103 163 103 103 3152 7415 4318 1520 102 103 100 1936 981 1772 6130 5543 1415 164 679 511 3353 103 4556 4312 1121 5503 4638 782 5543 1415 2802 2458 164 7770 4906 2825 2218 5445 103 2218 100 1521 1087 5310 3354 151 2626 1963 3424 5861 7946 782 4280 5601 100 163 2769 738 523 1355 4873 3313 2458 103 915 3189 103 816 753 2769 2218 867 151 7583 103 7348 103 2697 164 103 2769 711 103 1092 3175 102
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:masked_lm_positions: 6 8 14 17 23 34 42 44 45 46 51 63 80 105 108 116 118 121 124 0
INFO:tensorflow:masked_lm_ids: 5445 1392 711 5106 1126 1077 100 702 782 3152 2533 2428 1400 163 7353 1912 5277 8024 862 0
INFO:tensorflow:masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0
INFO:tensorflow:next_sentence_labels: 1
可以看到token序列里的中文確實是以字的形式出現的
三、run_pretraining.py
終于到預訓練的執行模塊了,里面大部分都是tensorflow訓練的常規代碼,感覺沒什么好分析的。
看過前面的內容和我前兩章內容的朋友我想已經初步知道預訓練的整個邏輯了,這里作一個簡單的介紹:
1、X和Y的確定
input_ids = features["input_ids"]input_mask = features["input_mask"]segment_ids = features["segment_ids"]masked_lm_positions = features["masked_lm_positions"]masked_lm_ids = features["masked_lm_ids"]masked_lm_weights = features["masked_lm_weights"]next_sentence_labels = features["next_sentence_labels"]model = modeling.BertModel(config=bert_config,is_training=is_training,input_ids=input_ids,input_mask=input_mask,token_type_ids=segment_ids,use_one_hot_embeddings=use_one_hot_embeddings)
其中input_ids、input_mask 、segment_ids 作為X,剩下的masked_lm_positions、masked_lm_ids 、masked_lm_weights 、next_sentence_labels 共同作為Y
2、 loss
(masked_lm_loss,masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(bert_config, model.get_sequence_output(), model.get_embedding_table(),masked_lm_positions, masked_lm_ids, masked_lm_weights)(next_sentence_loss, next_sentence_example_loss,next_sentence_log_probs) = get_next_sentence_output(bert_config, model.get_pooled_output(), next_sentence_labels)total_loss = masked_lm_loss + next_sentence_loss
可以看到loss 分別由masked_lm_loss和next_sentence_loss組成,masked_lm_loss針對的是語言模型對MASK起來的標簽的預測,即上下文語境預測當前詞;而next_sentence_loss是對于句子關系的預測。前者在遷移學習中可以用于標注類任務(分詞、NER等),后者可以用于句子關系任務(QA、自然語言推理等)。
需要多說一句的是,masked_lm_loss,用到了模型的sequence_output和embedding_table,這是因為對多個MASK的標簽進行預測是一個標注問題,所以需要獲取最后一層的整個sequence,而embedding_table用來反embedding,這樣就映射到token的學習了。而next_sentence_loss用到的是pooled_output,對應的是第一個token [CLS],它一般用于分類任務的學習。
總結:
本文介紹了以下幾個內容:
1、tokenization模塊:我把它叫做對原始文本段的解析,只有解析過后才能標準化輸入;
2、create_pretraining_data模塊:對原始數據進行轉換,原始數據本是無標簽的數據,通過句子的拼接可以產生句子關系的標簽,通過MASK可以產生標注的標簽,其本質是語言模型的應用;
3、run_pretraining模塊:在執行預訓練的時候針對以上兩種標簽分別利用bert模型的不同輸出部件,計算loss,然后進行梯度下降優化。
本文系列
Bert系列(一)——demo運行
Bert系列(二)——模型主體源碼解讀
Bert系列(四)——源碼解讀之Fine-tune
Bert系列(五)——中文分詞實踐 F1 97.8%(附代碼)
Reference
1.https://github.com/google-research/bert
2.BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding
作者:西溪雷神
鏈接:https://www.jianshu.com/p/22e462f01d8c
來源:簡書
簡書著作權歸作者所有,任何形式的轉載都請聯系作者獲得授權并注明出處。
總結
以上是生活随笔為你收集整理的Bert系列(三)——源码解读之Pre-train的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: BERT-Pytorch demo初探
- 下一篇: BERT大火却不懂Transformer