tensorflow包_在Keras和Tensorflow中使用深度卷积网络生成Meme(表情包)文本
作者?| dylan wenzlau
來源?|?Medium
編輯?|?代碼醫(yī)生團(tuán)隊(duì)
本文介紹如何構(gòu)建深度轉(zhuǎn)換網(wǎng)絡(luò)實(shí)現(xiàn)端到端的文本生成。在這一過程中,包括有關(guān)數(shù)據(jù)清理,訓(xùn)練,模型設(shè)計(jì)和預(yù)測算法相關(guān)的內(nèi)容。
第1步:構(gòu)建訓(xùn)練數(shù)據(jù)
數(shù)據(jù)集使用了Imgflip Meme Generator(一款根據(jù)文本生成表情包的工具)用戶的~100M公共memes標(biāo)題。為了加速訓(xùn)練并降低模型的復(fù)雜性,僅使用48個(gè)最受歡迎的Meme(表情包)和每個(gè)Meme(表情包)準(zhǔn)確的20,000個(gè)字幕,總計(jì)960,000個(gè)字幕作為訓(xùn)練數(shù)據(jù)。每個(gè)角色都會有一個(gè)訓(xùn)練示例在標(biāo)題中,總計(jì)約45,000,000個(gè)訓(xùn)練樣例。這里選擇了角色級生成而不是單詞級別,因?yàn)镸eme(表情包)傾向于使用拼寫和語法。此外字符級深度學(xué)習(xí)是單詞級深度學(xué)習(xí)的超集,因此如果有足夠的數(shù)據(jù)并且模型設(shè)計(jì)足以了解所有復(fù)雜性,則可以實(shí)現(xiàn)更高的準(zhǔn)確性。如果嘗試下面的完成模型,還會看到char級別可以更有趣!
https://imgflip.com/memegenerator
以下是第一個(gè)Meme(表情包)標(biāo)題是“制作所有memes”時(shí)的訓(xùn)練數(shù)據(jù)。省略了從數(shù)據(jù)庫中讀取代碼并執(zhí)行初始清理的代碼,因?yàn)樗浅?biāo)準(zhǔn),可以通過多種方式完成。
training_data?=?[
????["000000061533??0??",?"m"],
????["000000061533??0??m",?"a"],
????["000000061533??0??ma",?"k"],
????["000000061533??0??mak",?"e"],
????["000000061533??0??make",?"|"],
????["000000061533??1??make|",?"a"],
????["000000061533??1??make|a",?"l"],
????["000000061533??1??make|al",?"l"],
????["000000061533??1??make|all",?"?"],
????["000000061533??1??make|all?",?"t"],
????["000000061533??1??make|all?t",?"h"],
????["000000061533??1??make|all?th",?"e"],
????["000000061533??1??make|all?the",?"?"],
????["000000061533??1??make|all?the?",?"m"],
????["000000061533??1??make|all?the?m",?"e"],
????["000000061533??1??make|all?the?me",?"m"],
????["000000061533??1??make|all?the?mem",?"e"],
????["000000061533??1??make|all?the?meme",?"s"],
????["000000061533??1??make|all?the?memes",?"|"],
????...?45?million?more?rows?here?...
]#?we'll?need?our?feature?text?and?labels?as?separate?arrays?later
texts?=?[row[0]?for?row?in?training_data]
labels?=?[row[1]?for?row?in?training_data]
像機(jī)器學(xué)習(xí)中的大多數(shù)事情一樣,這只是一個(gè)分類問題。將左側(cè)的文本字符串分類為~70個(gè)不同的buckets 中的一個(gè),其中buckets 是字符。
解壓縮格式:
前12個(gè)字符是Meme(表情包)模板ID。這允許模型區(qū)分正在訓(xùn)練它的48個(gè)不同的Meme(表情包)。字符串左邊用零填充,因此所有ID都是相同的長度。
0或1是被預(yù)測的當(dāng)前文本框的索引,一般0是機(jī)頂盒和1是底盒,雖然許多記因是更復(fù)雜的。這兩個(gè)空格只是額外的間距,以確保模型可以將框索引與模板ID和Meme(表情包)文本區(qū)分開來。注意:至關(guān)重要的是卷積內(nèi)核寬度(在本文后面看到)不比4個(gè)空格加上索引字符(也就是≤5)寬。
之后是meme的文本,用|作為文本框的結(jié)尾字符。
最后一個(gè)字符(第二個(gè)數(shù)組項(xiàng))是序列中的下一個(gè)字符。
在訓(xùn)練之前,數(shù)據(jù)使用了幾種清洗技術(shù):
調(diào)整前導(dǎo)和尾隨空格,并用\s+單個(gè)空格字符替換重復(fù)的空格()。
應(yīng)用最少10個(gè)字符的字符串長度,這樣就不會生成無聊的單字或單字母Memes(表情包文本)。
應(yīng)用最大字符串長度為82個(gè)字符,因此不會生成超長表情包字符,因?yàn)槟P蛯⒏斓赜?xùn)練。82是任意的,它只是使整個(gè)訓(xùn)練字符串大約100個(gè)字符。
將所有內(nèi)容轉(zhuǎn)換為小寫以減少模型必須學(xué)習(xí)的字符數(shù),并且因?yàn)樵S多Memes(表情包文本)只是全部大寫。
使用非ascii字符跳過meme標(biāo)題可以降低模型必須學(xué)習(xí)的復(fù)雜性。這意味著特征文本和標(biāo)簽都將來自一組僅約70個(gè)字符,具體取決于訓(xùn)練數(shù)據(jù)恰好包含哪些ascii字符。
跳過包含豎線字符的meme標(biāo)題,|因?yàn)樗翘厥獾奈谋究蚪Y(jié)尾字符。
通過語言檢測庫運(yùn)行文本,并跳過不太可能是英語的meme標(biāo)題。提高生成的文本的質(zhì)量,因?yàn)槟P椭恍枰獙W(xué)習(xí)一種語言,相同的字符序列可以在多種語言中有意義。
跳過已添加到訓(xùn)練集中的重復(fù)Memes(表情包文本)標(biāo)題,以減少模型簡單記憶整個(gè)Memes(表情包文本)標(biāo)題的機(jī)會。
數(shù)據(jù)現(xiàn)在已準(zhǔn)備就緒,可以輸入神經(jīng)網(wǎng)絡(luò)!
第2步:數(shù)據(jù)轉(zhuǎn)換
首先,在代碼中導(dǎo)入python庫:?
from?keras?import?Sequentialfrom?keras.preprocessing.sequence?import?pad_sequencesfrom?keras.callbacks?import?ModelCheckpointfrom?keras.layers?import?Dense,?Dropout,?GlobalMaxPooling1D,?Conv1D,?MaxPooling1D,?Embeddingfrom?keras.layers.normalization?import?BatchNormalizationimport?numpy?as?npimport?util??#?util?is?a?custom?file?I?wrote,?see?github?link?below
因?yàn)樯窠?jīng)網(wǎng)絡(luò)只能對張量(向量/矩陣/多維數(shù)組)進(jìn)行操作,所以需要對文本進(jìn)行轉(zhuǎn)化。每個(gè)訓(xùn)練文本將通過從數(shù)據(jù)中找到的約70個(gè)唯一字符的數(shù)組中用相應(yīng)的索引替換每個(gè)字符,將其轉(zhuǎn)換為整數(shù)數(shù)組(等級1張量)。字符數(shù)組的順序是任意的,但選擇按字符頻率對其進(jìn)行排序,以便在更改訓(xùn)練數(shù)據(jù)量時(shí)保持大致一致。Keras有一個(gè)Tokenizer類,可以使用它(使用char_level = True),這里使用的是自己的util函數(shù),因?yàn)樗菿eras tokenizer更快。
#?output:?{'?':?1,?'0':?2,?'e':?3,?...?}
char_to_int?=?util.map_char_to_int(texts)#?output:?[[2,?2,?27,?11,?...],?...?]
sequences?=?util.texts_to_sequences(texts,?char_to_int)
labels?=?[char_to_int[char]?for?char?in?labels]
這些是數(shù)據(jù)按頻率順序包含的字符:
0etoains|rhl1udmy2cg4p53wf6b897kv."!?j:x,*"z-q/&$)(#%+_@=>;
接下來將填充帶有前導(dǎo)零的整數(shù)序列,因此它們的長度都相同,因?yàn)槟P偷膹埩繑?shù)學(xué)要求每個(gè)訓(xùn)練示例的形狀相同。(注意:可以在這里使用低至100的長度,因?yàn)槲谋局挥?00個(gè)字符,但希望以后所有的池操作都可以被2完全整除。)
SEQUENCE_LENGTH?=?128
data?=?pad_sequences(sequences,?maxlen=SEQUENCE_LENGTH)
最后將調(diào)整訓(xùn)練數(shù)據(jù)并將其分為訓(xùn)練和驗(yàn)證集。改組(隨機(jī)化順序)確保數(shù)據(jù)的特定子集不總是用于驗(yàn)證準(zhǔn)確性的子集。將一些數(shù)據(jù)拆分成驗(yàn)證集使能夠衡量模型在不允許它用于訓(xùn)練的示例上的表現(xiàn)。
#?randomize?order?of?training?data
indices?=?np.arange(data.shape[0])
np.random.shuffle(indices)
data?=?data[indices]
labels?=?labels[indices]#?validation?set?can?be?much?smaller?if?we?use?a?lot?of?data
validation_ratio?=?0.2?if?data.shape[0]?1000000?else?0.02
num_validation_samples?=?int(validation_ratio?*?data.shape[0])
x_train?=?data[:-num_validation_samples]
y_train?=?labels[:-num_validation_samples]
x_val?=?data[-num_validation_samples:]
y_val?=?labels[-num_validation_samples:]
第3步:模型設(shè)計(jì)
這里選擇使用卷積網(wǎng)絡(luò),在Keras上構(gòu)建conv網(wǎng)絡(luò)模型的代碼如下:
EMBEDDING_DIM?=?16
model?=?Sequential()
model.add(Embedding(len(char_to_int)?+?1,?EMBEDDING_DIM,?input_length=SEQUENCE_LENGTH))
model.add(Conv1D(1024,?5,?activation='relu',?padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling1D(2))
model.add(Dropout(0.25))
model.add(Conv1D(1024,?5,?activation='relu',?padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling1D(2))
model.add(Dropout(0.25))
model.add(Conv1D(1024,?5,?activation='relu',?padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling1D(2))
model.add(Dropout(0.25))
model.add(Conv1D(1024,?5,?activation='relu',?padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling1D(2))
model.add(Dropout(0.25))
model.add(Conv1D(1024,?5,?activation='relu',?padding='same'))
model.add(BatchNormalization())
model.add(GlobalMaxPooling1D())
model.add(Dropout(0.25))
model.add(Dense(1024,?activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.25))
model.add(Dense(len(labels_index),?activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy',?optimizer='rmsprop',?metrics=['acc'])
代碼步驟如下:
首先,模型使用Keras嵌入將每個(gè)輸入示例從128個(gè)整數(shù)的數(shù)組(每個(gè)表示一個(gè)文本字符)轉(zhuǎn)換為128x16矩陣。嵌入是一個(gè)層,它學(xué)習(xí)將每個(gè)字符轉(zhuǎn)換為表示為整數(shù)的最佳方式,而不是表示為16個(gè)浮點(diǎn)數(shù)的數(shù)組[0.02, ..., -0.91]。這允許模型通過在16維空間中將它們彼此靠近地嵌入來了解哪些字符的使用類似,并最終提高模型預(yù)測的準(zhǔn)確性。
接下來,添加5個(gè)卷積層,每個(gè)層的內(nèi)核大小為5,1024個(gè)過濾器,以及ReLU激活。從概念上講,第一個(gè)轉(zhuǎn)換層正在學(xué)習(xí)如何從字符構(gòu)造單詞,后來的層正在學(xué)習(xí)構(gòu)建更長的單詞和單詞鏈(n-gram),每個(gè)單詞都比前一個(gè)更抽象。
padding='same' 用于確保圖層的輸出尺寸與輸入尺寸相同,因?yàn)榉駝t寬度5卷積會使內(nèi)核的每一側(cè)的圖層尺寸減小2。
選擇1024作為濾波器的數(shù)量,因?yàn)樗怯?xùn)練速度和模型精度之間的良好折衷,由試驗(yàn)和錯(cuò)誤確定。對于其他數(shù)據(jù)集,我建議從128個(gè)過濾器開始,然后將其增加/減少兩倍,以查看會發(fā)生什么。更多過濾器通常意味著更好的模型準(zhǔn)確性,但訓(xùn)練速度較慢,運(yùn)行時(shí)預(yù)測較慢,模型尺寸較大。但是如果數(shù)據(jù)太少或過濾器太多,模型可能會過度擬合,精度會下降,在這種情況下,應(yīng)該減少過濾器。
在測試尺寸為2,3,5和7之后選擇大小為5的卷積核。其中2和3的卷積確實(shí)更差, 7需要更多的參數(shù),這會使訓(xùn)練變慢。在研究中,其他人已經(jīng)成功地使用了3到7種不同組合的卷積大小,大小為5的卷積核通常在文本數(shù)據(jù)上表現(xiàn)得相當(dāng)不錯(cuò)。
選擇ReLU激活是因?yàn)樗焖?#xff0c;簡單,并且非常適用于各種各樣的用例。
在每個(gè)conv層之后添加批量標(biāo)準(zhǔn)化,以便基于給定批次的均值和方差對下一層的輸入?yún)?shù)進(jìn)行標(biāo)準(zhǔn)化。深度學(xué)習(xí)工程師尚未完全理解這種機(jī)制,歸一化輸入?yún)?shù)可以提高訓(xùn)練速度,并且由于消失/爆炸的梯度,對于更深的網(wǎng)絡(luò)變得更加重要。在每個(gè)轉(zhuǎn)換層之后添加一個(gè)Dropout層,以幫助防止該層簡單地記憶數(shù)據(jù)和過度擬合。Dropout(0.25)隨機(jī)丟棄25%的參數(shù)(將它們設(shè)置為零)。
在每個(gè)轉(zhuǎn)換層之間添加MaxPooling1D(2),以將128個(gè)字符的序列“擠壓”成下列層中的64,32,16和8個(gè)字符的序列。從概念上講,這允許卷積濾波器從更深層中的文本中學(xué)習(xí)更多抽象模式,因?yàn)樵诿總€(gè)最大池操作將維度減少2倍之后,寬度5內(nèi)核將跨越兩倍的字符。
在所有轉(zhuǎn)換圖層之后,使用全局最大合并圖層,它與普通的最大合并圖層相同,只是它會自動選擇縮小輸入尺寸以匹配下一圖層的大小。最后一層只是標(biāo)準(zhǔn)的密集(完全連接)層,有1024個(gè)神經(jīng)元,最后是70個(gè)神經(jīng)元,因?yàn)榉诸惼餍枰獮?0個(gè)不同的標(biāo)簽輸出概率。
model.compile步驟非常標(biāo)準(zhǔn)。RMSprop優(yōu)化器是一個(gè)不錯(cuò)的優(yōu)化器,沒有嘗試為這個(gè)神經(jīng)網(wǎng)絡(luò)改變它。loss=sparse_categorical_crossentropy告訴希望它優(yōu)化的模型,以便在一組2個(gè)或更多類別(又名標(biāo)簽)中選擇最佳類別。“稀疏”部分指的是標(biāo)簽是0到70之間的整數(shù),而不是長度為70的一個(gè)one-hot陣列。使用一個(gè)one-hot陣列作為標(biāo)簽需要更多的內(nèi)存,更多的處理時(shí)間,并且不會影響模型的準(zhǔn)確性。不要使用一個(gè)one-hot標(biāo)簽!
Keras有一個(gè)很好的model.summary()功能,可以查看模型:
_________________________________________________________________
Layer?(type)?????????????????Output?Shape??????????????Param?#
=================================================================
embedding_1?(Embedding)??????(None,?128,?16)???????????1136
_________________________________________________________________
conv1d_1?(Conv1D)????????????(None,?128,?1024)?????????82944
_________________________________________________________________
batch_normalization_1?(Batch?(None,?128,?1024)?????????4096
_________________________________________________________________
max_pooling1d_1?(MaxPooling1?(None,?64,?1024)??????????0
_________________________________________________________________
dropout_1?(Dropout)??????????(None,?64,?1024)??????????0
_________________________________________________________________
conv1d_2?(Conv1D)????????????(None,?64,?1024)??????????5243904
_________________________________________________________________
batch_normalization_2?(Batch?(None,?64,?1024)??????????4096
_________________________________________________________________
max_pooling1d_2?(MaxPooling1?(None,?32,?1024)??????????0
_________________________________________________________________
dropout_2?(Dropout)??????????(None,?32,?1024)??????????0
_________________________________________________________________
conv1d_3?(Conv1D)????????????(None,?32,?1024)??????????5243904
_________________________________________________________________
batch_normalization_3?(Batch?(None,?32,?1024)??????????4096
_________________________________________________________________
max_pooling1d_3?(MaxPooling1?(None,?16,?1024)??????????0
_________________________________________________________________
dropout_3?(Dropout)??????????(None,?16,?1024)??????????0
_________________________________________________________________
conv1d_4?(Conv1D)????????????(None,?16,?1024)??????????5243904
_________________________________________________________________
batch_normalization_4?(Batch?(None,?16,?1024)??????????4096
_________________________________________________________________
max_pooling1d_4?(MaxPooling1?(None,?8,?1024)???????????0
_________________________________________________________________
dropout_4?(Dropout)??????????(None,?8,?1024)???????????0
_________________________________________________________________
conv1d_5?(Conv1D)????????????(None,?8,?1024)???????????5243904
_________________________________________________________________
batch_normalization_5?(Batch?(None,?8,?1024)???????????4096
_________________________________________________________________
global_max_pooling1d_1?(Glob?(None,?1024)??????????????0
_________________________________________________________________
dropout_5?(Dropout)??????????(None,?1024)??????????????0
_________________________________________________________________
dense_1?(Dense)??????????????(None,?1024)??????????????1049600
_________________________________________________________________
batch_normalization_6?(Batch?(None,?1024)??????????????4096
_________________________________________________________________
dropout_6?(Dropout)??????????(None,?1024)??????????????0
_________________________________________________________________
dense_2?(Dense)??????????????(None,?70)????????????????71750
=================================================================
Total?params:?22,205,622
Trainable?params:?22,193,334
Non-trainable?params:?12,288
_________________________________________________________________
在調(diào)整上面討論的超參數(shù)時(shí),關(guān)注模型的參數(shù)計(jì)數(shù)很有用,它大致代表模型的學(xué)習(xí)能力總量。
第4步:訓(xùn)練
現(xiàn)在將讓模型訓(xùn)練并使用“檢查點(diǎn)”來保存歷史和最佳模型,以便可以在訓(xùn)練期間的任何時(shí)候檢查進(jìn)度并使用最新模型進(jìn)行預(yù)測。
#?the?path?where?you?want?to?save?all?of?this?model's?files
MODEL_PATH?=?'/home/ubuntu/imgflip/models/conv_model'#?just?make?this?large?since?you?can?stop?training?at?any?time
NUM_EPOCHS?=?48#?batch?size?below?256?will?reduce?training?speed?since#?CPU?(non-GPU)?work?must?be?done?between?each?batch
BATCH_SIZE?=?256#?callback?to?save?the?model?whenever?validation?loss?improves
checkpointer?=?ModelCheckpoint(filepath=MODEL_PATH?+?'/model.h5',?verbose=1,?save_best_only=True)#?custom?callback?to?save?history?and?plots?after?each?epoch
history_checkpointer?=?util.SaveHistoryCheckpoint(MODEL_PATH)#?the?main?training?function?where?all?the?magic?happens!
history?=?model.fit(x_train,?y_train,?validation_data=(x_val,?y_val),?epochs=NUM_EPOCHS,?batch_size=BATCH_SIZE,?callbacks=[checkpointer,?history_checkpointer])
這就是坐下來觀看神奇數(shù)字在幾個(gè)小時(shí)內(nèi)上升的地方......
Train?on?44274928?samples,?validate?on?903569?samples
Epoch?1/4844274928/44274928?[==============================]?-?16756s?378us/step?-?loss:?1.5516?-?acc:?0.5443?-?val_loss:?1.3723?-?val_acc:?0.5891
Epoch?00001:?val_loss?improved?from?inf?to?1.37226,?saving?model?to?/home/ubuntu/imgflip/models/gen_2019_04_04_03_28_00/model.h5
Epoch?2/4844274928/44274928?[==============================]?-?16767s?379us/step?-?loss:?1.4424?-?acc:?0.5748?-?val_loss:?1.3416?-?val_acc:?0.5979
Epoch?00002:?val_loss?improved?from?1.37226?to?1.34157,?saving?model?to?/home/ubuntu/imgflip/models/gen_2019_04_04_03_28_00/model.h5
Epoch?3/4844274928/44274928?[==============================]?-?16798s?379us/step?-?loss:?1.4192?-?acc:?0.5815?-?val_loss:?1.3239?-?val_acc:?0.6036
Epoch?00003:?val_loss?improved?from?1.34157?to?1.32394,?saving?model?to?/home/ubuntu/imgflip/models/gen_2019_04_04_03_28_00/model.h5
Epoch?4/4844274928/44274928?[==============================]?-?16798s?379us/step?-?loss:?1.4015?-?acc:?0.5857?-?val_loss:?1.3127?-?val_acc:?0.6055
Epoch?00004:?val_loss?improved?from?1.32394?to?1.31274,?saving?model?to?/home/ubuntu/imgflip/models/gen_2019_04_04_03_28_00/model.h5
Epoch?5/481177344/44274928?[..............................]?-?ETA:?4:31:59?-?loss:?1.3993?-?acc:?0.5869
發(fā)現(xiàn)當(dāng)訓(xùn)練損失/準(zhǔn)確性比驗(yàn)證損失/準(zhǔn)確性更差時(shí),這表明該模型學(xué)習(xí)良好且不過度擬合。
如果使用AWS服務(wù)器進(jìn)行訓(xùn)練,發(fā)現(xiàn)最佳實(shí)例為p3.2xlarge。這使用了自2019年4月以來最快的GPU(Tesla V100),并且該實(shí)例只有一個(gè)GPU,因?yàn)槟P蜔o法非常有效地使用多個(gè)GPU。確實(shí)嘗試過使用Keras的multi_gpu_model,但它需要使批量大小更大,以實(shí)際實(shí)現(xiàn)速度提升,這可能會影響模型的收斂能力,即使使用4個(gè)GPU也幾乎不會快2倍。帶有4個(gè)GPU的p3.8xlarge的成本是4倍。
第5步:預(yù)測
現(xiàn)在有一個(gè)模型可以輸出meme標(biāo)題中下一個(gè)字符應(yīng)該出現(xiàn)的概率,但是如何使用它來實(shí)際創(chuàng)建一個(gè)完整的meme(表情包)標(biāo)題?
基本前提是用想要為其生成文本的Memes(表情包標(biāo)題)初始化一個(gè)字符串,然后model.predict為每個(gè)字符調(diào)用一次,直到模型輸出結(jié)束文本字符的|次數(shù)與文本框中的文本框一樣多次。對于上面看到的“X All The Y”memes,默認(rèn)的文本框數(shù)為2,初始文本為:
"000000061533??0??"
考慮到模型輸出的70個(gè)概率,嘗試了幾種不同的方法來選擇下一個(gè)字符:
每次選擇得分最高的角色。這會生成非常單一的結(jié)果,因?yàn)樗看螢榻o定的Meme(表情包)選擇完全相同的文本,并且它在Meme(表情包)中反復(fù)使用相同的單詞。”when you find out your friends are the best party”,它會一遍又一遍地吐出"X All The Y meme"。它喜歡在其他Meme(表情包)中使用"best"和"party"這兩個(gè)詞。
給每個(gè)角色一個(gè)被選中的概率等于模型給出的分?jǐn)?shù),但只有當(dāng)分?jǐn)?shù)高于某個(gè)閾值時(shí)(≥最高分的10%才適用于該模型)。這意味著可以選擇多個(gè)字符,但偏向更高的得分字符。這種方法成功地增加了多樣性,但較長的短語有時(shí)缺乏凝聚力。這是Futurama Frymemes中的一個(gè):"not sure if she said or just put out of my day"。
給每個(gè)角色選擇相同的概率,但前提是它的分?jǐn)?shù)足夠高(≥最高分的10%適用于此模型)。此外使用beam搜索在任何給定時(shí)間保留N個(gè)文本的運(yùn)行列表,并使用所有角色分?jǐn)?shù)的乘積而不是最后一個(gè)角色的分?jǐn)?shù)。這需要花費(fèi)N倍的時(shí)間來計(jì)算,但在某些情況下似乎可以提高句子的凝聚力。
這里選擇使用方法2,因?yàn)樗俣瓤?#xff0c;效果好。以下是一些隨機(jī)生成的例子:
?
在imgflip.com/ai-meme的48個(gè)Meme(表情包)中生成。
https://imgflip.com/ai-meme
使用方法2進(jìn)行運(yùn)行時(shí)預(yù)測的代碼如下。Github上的完整實(shí)現(xiàn)是一種通用的Beam搜索算法,因此只需將波束寬度增加到1以上即可啟用Beam搜索。
#?min?score?as?percentage?of?the?maximum?score,?not?absolute
MIN_SCORE?=?0.1
int_to_char?=?{v:?k?for?k,?v?in?char_to_int.items()}def?predict_meme_text(template_id,?num_boxes,?init_text?=?''):
??template_id?=?str(template_id).zfill(12)
??final_text?=?''for?char_count?in?range(len(init_text),?SEQUENCE_LENGTH):
????box_index?=?str(final_text.count('|'))
????texts?=?[template_id?+?'??'?+?box_index?+?'??'?+?final_text]
????sequences?=?util.texts_to_sequences(texts,?char_to_int)
????data?=?pad_sequences(sequences,?maxlen=SEQUENCE_LENGTH)
????predictions_list?=?model.predict(data)
????predictions?=?[]for?j?in?range(0,?len(predictions_list[0])):
??????predictions.append({'text':?final_text?+?int_to_char[j],'score':?predictions_list[0][j]
??????})
????predictions?=?sorted(predictions,?key=lambda?p:?p['score'],?reverse=True)
????top_predictions?=?[]
????top_score?=?predictions[0]['score']
????rand_int?=?random.randint(int(MIN_SCORE?*?1000),?1000)for?prediction?in?predictions:#?give?each?char?a?chance?of?being?chosen?based?on?its?scoreif?prediction['score']?>=?rand_int?/?1000?*?top_score:
????????top_predictions.append(prediction)
????random.shuffle(top_predictions)
????final_text?=?top_predictions[0]['text']if?char_count?>=?SEQUENCE_LENGTH?-?1?or?final_text.count('|')?==?num_boxes?-?1:return?final_text
在github中,該文檔對應(yīng)的代碼如下:
https://github.com/dylanwenzlau/ml-scripts/tree/master/meme_text_gen_convnet
推薦閱讀
“Keras之父發(fā)聲:TF 2.0 + Keras 深度學(xué)習(xí)必知的12件事”
關(guān)于圖書
《深度學(xué)習(xí)之TensorFlow:入門、原理與進(jìn)階實(shí)戰(zhàn)》和《Python帶我起飛——入門、進(jìn)階、商業(yè)實(shí)戰(zhàn)》兩本圖書是代碼醫(yī)生團(tuán)隊(duì)精心編著的 AI入門與提高的精品圖書。配套資源豐富:配套視頻、QQ讀者群、實(shí)例源碼、 配套論壇:http://bbs.aianaconda.com?。更多請見:https://www.aianaconda.com
總結(jié)
以上是生活随笔為你收集整理的tensorflow包_在Keras和Tensorflow中使用深度卷积网络生成Meme(表情包)文本的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python图形用户界面pyside_P
- 下一篇: 求实数的绝对值。_例谈六种有关绝对值问题