5-RNN-03_双向rnn_英文小短文
生活随笔
收集整理的這篇文章主要介紹了
5-RNN-03_双向rnn_英文小短文
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
import os
import numpy as np
import tensorflow as tf


def load_data(file_path):
    """Load the raw corpus.

    :param file_path: path of the text file to read.
    :return: list of lines as produced by ``readlines()`` (newlines kept).
    """
    with open(file_path, 'r') as reader:
        return reader.readlines()


def create_lookup_table(text):
    """Build the vocabulary tables {word: id} and {id: word}.

    :param text: iterable of word tokens; duplicates are allowed.
    :return: ``(word2int, int2word)``. Ids are the positions of the distinct
        tokens in sorted order.
    """
    words = sorted(set(text))
    word2int = {word: idx for idx, word in enumerate(words)}
    int2word = dict(enumerate(words))
    return word2int, int2word


def create_X_and_Y(data, word2int, number_time_steps=3):
    """Build the training inputs X and targets Y from the raw lines.

    Every line is split on single spaces; each window of
    ``number_time_steps`` consecutive words becomes one row of X and the word
    that follows the window becomes the matching entry of Y. Lines with no
    word after a full window contribute nothing.

    :param data: iterable of text lines.
    :param word2int: mapping from word to integer id.
    :param number_time_steps: window length (number of input words).
    :return: ``(X, Y)`` numpy arrays, X reshaped to
        ``[-1, number_time_steps]`` and Y flattened to one dimension.
    :raises KeyError: if a word inside a window is missing from ``word2int``.
    """
    X, Y = [], []
    for content in data:
        # strip() drops the surrounding whitespace (including the newline).
        words = content.strip().split(' ')
        for offset in range(len(words) - number_time_steps):
            window = words[offset: offset + number_time_steps]
            target = words[offset + number_time_steps]
            X.append([word2int[word] for word in window])
            Y.append(word2int[target])
    # Convert the python lists to numpy ndarrays.
    X = np.asarray(X).reshape([-1, number_time_steps])
    Y = np.asarray(Y).reshape(-1)
    return X, Y


def create_model(vocab_size, num_units=32, number_time_steps=3):
    """Build the bidirectional LSTM next-word classifier.

    :param vocab_size: vocabulary size (number of output classes).
    :param num_units: number of hidden units in each LSTM cell.
    :param number_time_steps: number of time steps (input words).
    :return: ``(_x, _y, logits, predictions)`` where ``_x`` is the int32
        input placeholder of shape ``[None, number_time_steps]``, ``_y`` the
        int32 label placeholder of shape ``[None]``, ``logits`` the output
        layer result and ``predictions`` its argmax over axis 1.
    """
    with tf.variable_scope('Network', initializer=tf.truncated_normal_initializer(stddev=0.1)):
        with tf.variable_scope('input'):
            # Example of the expected data layout:
            #   x: [[2, 3, 4], [7, 8, 9]]
            #   y: [5, 10]
            _x = tf.placeholder(tf.int32, shape=[None, number_time_steps], name='x')
            _y = tf.placeholder(tf.int32, shape=[None], name='y')
            # Keep the placeholder and its float view under separate names so
            # the function returns the real input placeholder, not the cast.
            float_x = tf.cast(_x, tf.float32)
            # Split the input along the time axis into a list of per-step
            # tensors: [[N, 1], [N, 1], ...].
            # TODO: the raw word ids are used as features here; a real
            # project should use one-hot vectors or an embedding instead.
            input_x = tf.split(float_x, num_or_size_splits=number_time_steps, axis=1)

        with tf.variable_scope('rnn'):
            # a. Define the forward and backward cells.
            cell_fw = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
            cell_bw = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
            # b. Run the static bidirectional RNN to get the hidden outputs.
            rnn_outputs, _, _ = tf.nn.static_bidirectional_rnn(
                cell_fw=cell_fw, cell_bw=cell_bw, inputs=input_x, dtype=tf.float32)
            # rnn_outputs: [[N, 2*num_units], [N, 2*num_units], ...]

        with tf.variable_scope('logits'):
            # a. Take the hidden output of the last time step.
            rnn_output = rnn_outputs[-1]
            # b. Build the output layer variables.
            softmax_w = tf.get_variable(
                'w', shape=[2 * num_units, vocab_size], dtype=tf.float32)
            softmax_b = tf.get_variable(
                'b', shape=[vocab_size], dtype=tf.float32,
                initializer=tf.zeros_initializer())
            logits = tf.nn.xw_plus_b(rnn_output, softmax_w, softmax_b)

        with tf.variable_scope('Predict'):
            predictions = tf.argmax(logits, axis=1)
    return _x, _y, logits, predictions


def create_loss(logits, labels):
    """Create the mean sparse softmax cross-entropy loss.

    :param logits: output layer tensor.
    :param labels: integer class labels; flattened to one dimension here.
    :return: scalar loss tensor.
    """
    with tf.name_scope('loss'):
        # Flatten the labels to a 1-D vector.
        labels = tf.reshape(labels, shape=[-1])
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    return loss


def create_optimizer(loss, lr=1e-3):
    """Create the Adam training op.

    :param loss: scalar loss tensor to minimize.
    :param lr: learning rate.
    :return: the op returned by ``AdamOptimizer.minimize``.
    """
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_opt = optimizer.minimize(loss)
    return train_opt


def train(checkpoint_dir, max_steps=10000, batch_size=64, num_units=32,
          number_time_steps=10, file_path='../datas/belling_the_cat.txt'):
    """Train the bidirectional RNN and checkpoint it periodically.

    Logs the training loss and one random sample prediction every 200 steps
    and saves a checkpoint every 1000 steps.

    :param checkpoint_dir: directory the checkpoints are written to.
    :param max_steps: number of optimizer steps to run; steps are numbered
        ``1..max_steps`` inclusive, so the last step is also logged/saved.
    :param batch_size: number of samples per step.
    :param num_units: hidden units per LSTM cell.
    :param number_time_steps: number of input words per sample.
    :param file_path: path of the training corpus.
    :raises ValueError: if the corpus yields no training sample.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Load the corpus and build the vocabulary from every token.
        data = load_data(file_path=file_path)
        text = []
        for line in data:
            text.extend(line.strip().split(' '))
        word2int, int2word = create_lookup_table(text)
        x, y = create_X_and_Y(data, word2int, number_time_steps=number_time_steps)

        total_samples = x.shape[0]
        if total_samples == 0:
            raise ValueError(
                'no training samples: every line has at most {} words'.format(
                    number_time_steps))

        # 1. Build the network.
        _x, _y, logits, predictions = create_model(
            len(word2int), num_units=num_units, number_time_steps=number_time_steps)
        # 2. Model loss.
        loss = create_loss(logits, _y)
        # 3. Optimizer.
        train_opt = create_optimizer(loss)
        saver = tf.train.Saver()

        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())

            # At least one batch per epoch, otherwise a corpus smaller than
            # batch_size would never reset the batch counter and every later
            # slice would be empty.
            n_batches = max(1, total_samples // batch_size)
            batch_idx = 0
            # A random permutation of the sample indices, i.e. a shuffle.
            random_index = np.random.permutation(total_samples)

            # Inclusive upper bound so that step == max_steps really runs.
            for step in range(1, max_steps + 1):
                # Pick the training data of the current batch.
                start_idx = batch_idx * batch_size
                end_idx = start_idx + batch_size
                idx = random_index[start_idx: end_idx]
                feed = {_x: x[idx], _y: y[idx]}
                sess.run(train_opt, feed)

                if step % 200 == 0:
                    train_loss = sess.run(loss, feed)
                    print('step:{} - Train loss:{}'.format(step, train_loss))

                    # Run one prediction on a randomly chosen sample.
                    index = np.random.randint(low=0, high=total_samples)
                    sample_in = np.reshape(x[index], newshape=[-1, number_time_steps])
                    sample_out = sess.run(predictions, feed_dict={_x: sample_in})
                    print('輸入:{} - 預(yù)測(cè):{} VS 真實(shí)值:{}'.format(
                        x[index], int2word[sample_out[0]], int2word[y[index]]))

                if step % 1000 == 0:
                    # Persist the model.
                    files = 'model.ckpt'
                    save_files = os.path.join(checkpoint_dir, files)
                    saver.save(sess, save_path=save_files, global_step=step)
                    print('model saved!!')

                # Advance to the next batch; reshuffle once an epoch is done.
                batch_idx += 1
                if batch_idx == n_batches:
                    batch_idx = 0
                    random_index = np.random.permutation(total_samples)


if __name__ == '__main__':
    checkpoint_dir = './models'
    os.makedirs(checkpoint_dir, exist_ok=True)
    train(checkpoint_dir, max_steps=10000, batch_size=64, num_units=32, number_time_steps=10)
D:\Anaconda\python.exe D:/AI20/HJZ/04-深度學(xué)習(xí)/4-RNN/20191228___AI20_RNN/03_雙向rnn_英文小短文.py
2020-02-18 10:42:51.076290: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
step:200 - Train loss:3.9305477142333984
輸入:[ 8 7 1 76 39 1 85 66 24 30] - 預(yù)測(cè):, VS 真實(shí)值:consists
step:400 - Train loss:3.4807467460632324
輸入:[ 77 100 44 86 60 2 91 69 56 109] - 預(yù)測(cè):, VS 真實(shí)值:general
step:600 - Train loss:2.976116418838501
輸入:[ 86 59 63 86 23 2 21 91 53 101] - 預(yù)測(cè):the VS 真實(shí)值:should
step:800 - Train loss:2.6706745624542236
輸入:[65 12 11 61 83 2 88 86 64 58] - 預(yù)測(cè):the VS 真實(shí)值:said
step:1000 - Train loss:2.3042399883270264
輸入:[ 39 92 110 55 86 22 2 111 108 8] - 預(yù)測(cè):agree VS 真實(shí)值:agree
model saved!!
step:1200 - Train loss:1.834552526473999
輸入:[64 58 37 96 11 76 0 85 45 8] - 預(yù)測(cè):very VS 真實(shí)值:very
step:1400 - Train loss:1.6621453762054443
輸入:[ 40 14 1 101 28 31 34 35 40 2] - 預(yù)測(cè):i VS 真實(shí)值:i
step:1600 - Train loss:1.3067286014556885
輸入:[ 29 93 26 103 54 90 28 84 93 67] - 預(yù)測(cè):their VS 真實(shí)值:their
step:1800 - Train loss:1.2478928565979004
輸入:[ 1 86 57 38 4 36 29 93 26 103] - 預(yù)測(cè):measures VS 真實(shí)值:measures
step:2000 - Train loss:0.9508646726608276
輸入:[ 92 110 55 86 22 2 111 108 8 7] - 預(yù)測(cè):, VS 真實(shí)值:,
model saved!!
step:2200 - Train loss:0.8384971618652344
輸入:[61 83 2 88 86 64 58 76 46 45] - 預(yù)測(cè):easy VS 真實(shí)值:easy
step:2400 - Train loss:0.5928971171379089
輸入:[64 58 37 96 11 76 0 85 45 8] - 預(yù)測(cè):very VS 真實(shí)值:very
step:2600 - Train loss:0.5592891573905945
輸入:[31 34 35 40 2 41 98 1 89 1] - 預(yù)測(cè):to VS 真實(shí)值:to
step:2800 - Train loss:0.4942898452281952
輸入:[ 14 1 101 28 31 34 35 40 2 41] - 預(yù)測(cè):venture VS 真實(shí)值:venture
step:3000 - Train loss:0.35549503564834595
輸入:[ 77 100 44 86 60 2 91 69 56 109] - 預(yù)測(cè):general VS 真實(shí)值:general
model saved!!
step:3200 - Train loss:0.3749193549156189
輸入:[93 19 86 23 3 86 57 50 16 65] - 預(yù)測(cè):another VS 真實(shí)值:another
step:3400 - Train loss:0.29113298654556274
輸入:[86 23 2 82 76 91 1 11 82 76] - 預(yù)測(cè):that VS 真實(shí)值:that
step:3600 - Train loss:0.25510722398757935
輸入:[93 70 85 4 81 19 18 68 1 11] - 預(yù)測(cè):attached VS 真實(shí)值:attached
step:3800 - Train loss:0.2347201555967331
輸入:[ 26 103 54 90 28 84 93 67 87 25] - 預(yù)測(cè):enemy VS 真實(shí)值:enemy
step:4000 - Train loss:0.18923482298851013
輸入:[ 1 95 10 64 58 37 96 11 76 0] - 預(yù)測(cè):that VS 真實(shí)值:that
model saved!!
step:4200 - Train loss:0.16195067763328552
輸入:[23 3 86 57 50 16 65 12 11 61] - 預(yù)測(cè):spoke VS 真實(shí)值:spoke
step:4400 - Train loss:0.13978148996829987
輸入:[ 86 59 63 86 23 2 21 91 53 101] - 預(yù)測(cè):should VS 真實(shí)值:should
step:4600 - Train loss:0.1490730196237564
輸入:[103 54 90 28 84 93 67 87 25 33] - 預(yù)測(cè):, VS 真實(shí)值:,
step:4800 - Train loss:0.10961226373910904
輸入:[ 1 11 82 76 85 20 16 48 4 112] - 預(yù)測(cè):mouse VS 真實(shí)值:mouse
step:5000 - Train loss:0.11105622351169586
輸入:[ 86 22 2 111 108 8 7 1 76 39] - 預(yù)測(cè):, VS 真實(shí)值:,
model saved!!
step:5200 - Train loss:0.0975622832775116
輸入:[ 78 9 47 104 77 100 5 1 11 28] - 預(yù)測(cè):easily VS 真實(shí)值:easily
step:5400 - Train loss:0.0716937854886055
輸入:[18 68 1 11 17 21 4 74 75 86] - 預(yù)測(cè):neck VS 真實(shí)值:neck
step:5600 - Train loss:0.07302534580230713
輸入:[ 91 53 101 78 9 47 104 77 100 5] - 預(yù)測(cè):, VS 真實(shí)值:,
step:5800 - Train loss:0.05743904039263725
輸入:[66 24 30 27 44 86 80 11 94 52] - 預(yù)測(cè):in VS 真實(shí)值:in
step:6000 - Train loss:0.05397602543234825
輸入:[ 1 85 66 24 30 27 44 86 80 11] - 預(yù)測(cè):treacherous VS 真實(shí)值:treacherous
model saved!!
step:6200 - Train loss:0.054213933646678925
輸入:[87 25 33 1 86 23 2 82 76 91] - 預(yù)測(cè):, VS 真實(shí)值:,
step:6400 - Train loss:0.0373719185590744
輸入:[23 2 82 76 91 1 11 82 76 85] - 預(yù)測(cè):but VS 真實(shí)值:but
step:6600 - Train loss:0.046218693256378174
輸入:[ 42 101 28 71 82 79 63 40 14 1] - 預(yù)測(cè):we VS 真實(shí)值:we
step:6800 - Train loss:0.03185339272022247
輸入:[ 20 16 48 4 112 58 37 96 11 76] - 預(yù)測(cè):he VS 真實(shí)值:he
step:7000 - Train loss:0.026730481535196304
輸入:[ 1 11 17 21 4 74 75 86 59 63] - 預(yù)測(cè):the VS 真實(shí)值:the
model saved!!
step:7200 - Train loss:0.02903711423277855
輸入:[111 108 8 7 1 76 39 1 85 66] - 預(yù)測(cè):chief VS 真實(shí)值:chief
step:7400 - Train loss:0.026526065543293953
輸入:[ 57 38 4 36 29 93 26 103 54 90] - 預(yù)測(cè):could VS 真實(shí)值:could
step:7600 - Train loss:0.02054942026734352
輸入:[107 45 93 19 86 23 3 86 57 50] - 預(yù)測(cè):at VS 真實(shí)值:at
step:7800 - Train loss:0.01777688041329384
輸入:[ 85 20 16 48 4 112 58 37 96 11] - 預(yù)測(cè):said VS 真實(shí)值:said
step:8000 - Train loss:0.014596270397305489
輸入:[112 58 37 96 11 76 39 38 4 69] - 預(yù)測(cè):to VS 真實(shí)值:to
model saved!!
step:8200 - Train loss:0.015546170994639397
輸入:[107 45 93 19 86 23 3 86 57 50] - 預(yù)測(cè):at VS 真實(shí)值:at
step:8400 - Train loss:0.01338121946901083
輸入:[16 65 12 11 61 83 2 88 86 64] - 預(yù)測(cè):mouse VS 真實(shí)值:mouse
step:8600 - Train loss:0.014673653990030289
輸入:[ 1 85 66 24 30 27 44 86 80 11] - 預(yù)測(cè):treacherous VS 真實(shí)值:treacherous
step:8800 - Train loss:0.010602903552353382
輸入:[75 86 59 63 86 23 2 21 91 53] - 預(yù)測(cè):we VS 真實(shí)值:we
step:9000 - Train loss:0.01917443238198757
輸入:[ 1 20 107 45 93 19 86 23 3 86] - 預(yù)測(cè):mice VS 真實(shí)值:mice
model saved!!
step:9200 - Train loss:0.012528151273727417
輸入:[ 1 89 1 93 70 85 4 81 19 18] - 預(yù)測(cè):procured VS 真實(shí)值:procured
step:9400 - Train loss:0.009897150099277496
輸入:[ 26 103 54 90 28 84 93 67 87 25] - 預(yù)測(cè):enemy VS 真實(shí)值:enemy
step:9600 - Train loss:0.007559692487120628
輸入:[ 20 16 48 4 112 58 37 96 11 76] - 預(yù)測(cè):he VS 真實(shí)值:he
step:9800 - Train loss:0.008503235876560211
輸入:[33 1 86 23 2 82 76 91 1 11] - 預(yù)測(cè):some VS 真實(shí)值:some
Process finished with exit code 0
總結(jié)
以上是生活随笔為你收集整理的5-RNN-03_双向rnn_英文小短文的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 1.HTML技术
- 下一篇: 永磁同步电机控制笔记:foc控制原理通俗