18-TFRecord 数据格式化存储工具(CDBmax 数据国度)
一、寫在前面
??這篇是拖了很久才正式寫的一篇,也算是工程化的應(yīng)用中比較重要的一個(gè)部分,所以在這里做一個(gè)簡(jiǎn)要的分享。TFRecord是TensorFlow官方推薦使用的數(shù)據(jù)格式化存儲(chǔ)工具,可以很大程度上提高TensorFlow訓(xùn)練過程中的IO效率。我們之前在做一些簡(jiǎn)單的訓(xùn)練的時(shí)候都是使用文本存儲(chǔ)的方式,比如之前所作的花卉分類使用的兩萬張圖片存儲(chǔ)之文本文件中大約耗費(fèi)了70個(gè)G的存儲(chǔ)空間,而且一次性讀入這么大的數(shù)據(jù)量很顯然是一個(gè)不現(xiàn)實(shí)的事情,就算數(shù)據(jù)部分成功讀入之后,模型所占內(nèi)存空間也許將無法滿足,就容易出現(xiàn)IOError的狀況,而出現(xiàn)了這種問題解決得辦法不外乎兩種方法,一種是加裝內(nèi)存條或者是加裝同型號(hào)顯卡,另外一種就是使用一種合理的IO讀寫的方式,來降低內(nèi)存或者顯存得開銷。首先第一種方法需要考慮的問題就比較多了,它不僅與個(gè)人的財(cái)力掛鉤,還與機(jī)器可拓展空間和真實(shí)空間掛鉤,所以學(xué)習(xí)一種優(yōu)化存儲(chǔ)空間的方式就顯得尤為重要。
二、相關(guān)原理
??1.TFRecord相對(duì)優(yōu)于一般方法的原因主要有以下幾點(diǎn):1.TFRecord內(nèi)部采用“Protocol Buffer”的二進(jìn)制數(shù)據(jù)編碼方案,其單個(gè)字節(jié)相對(duì)于漢字的“utf-8”編碼所占用的字節(jié)數(shù)要少的多,因而在生成一次TFRecord之后,模型訓(xùn)練過程中的數(shù)據(jù)讀取、加工處理的效率都會(huì)有很大的提升空間。2.利用了Threading(線程)和Queues(隊(duì)列)從TFRecord中分批次讀取數(shù)據(jù),這種方法可以實(shí)現(xiàn)一邊讀取數(shù)據(jù)至隊(duì)列,然后一邊使用隊(duì)首數(shù)據(jù)訓(xùn)練模型的目的。這樣就起到降低了內(nèi)存空間的高占用率的作用。
??2.簡(jiǎn)要介紹一下本篇博客的邏輯和任務(wù),本文主要編寫針對(duì)一個(gè)簡(jiǎn)單的分類數(shù)據(jù)集通過TFRecord的文件存儲(chǔ)方式存取至磁盤上,然后通過相應(yīng)讀寫方法,讀取TFRecord文件,并做簡(jiǎn)單的訓(xùn)練來熟悉整個(gè)流程。
三、相關(guān)代碼
??1.使用TFRecord的時(shí)候,數(shù)據(jù)單位一般是tf.train.Example或者是tf.train.SequenceExample,Example一般是用于處理數(shù)值、圖像大小固定的數(shù)據(jù),可使用該方法指定各特征數(shù)值的名稱和數(shù)據(jù)類型。
示例1:
??該代碼塊中int64_list(整數(shù)列表)可替換為BytesList(字符串列表)和FloatList(實(shí)數(shù)列表),也就是說Features支持存儲(chǔ)如上三種類型的數(shù)據(jù)。
??SequenceExample一般是用于處理文本、時(shí)間序列沒有固定長度的數(shù)據(jù),該部分在NLP中是比較常用的一種數(shù)據(jù)存儲(chǔ)方式。
示例2:
??2.寫入數(shù)據(jù)至文件
??首先創(chuàng)建一個(gè)協(xié)議內(nèi)存塊(Protocol Buffer),該協(xié)議內(nèi)存塊中將用于存放特征屬性[features]。
??然后將獲取到的數(shù)據(jù)填入到Example內(nèi)存協(xié)議塊中,再將協(xié)議內(nèi)存塊序列化為一個(gè)字符串并且通過tf.python_io.TFRecordWriter()方法寫入至TFRecords文件。
#IO寫入 def writer(train_file,test_file):#讀取data_X,data_Y數(shù)據(jù)集train_X,train_Y,test_X,test_Y=divide_data()#定義寫train_file IO對(duì)象writer1=tf.python_io.TFRecordWriter(train_file)for data_X,data_Y in zip(train_X,train_Y):#print(data_X,'*************')#print(data_Y,'*************')#轉(zhuǎn)換數(shù)據(jù)格式print(data_X.shape)print(data_Y.shape)data_X=data_X.astype(np.float32)data_Y=data_Y.astype(np.float32)mk_em=examples(data_X.tobytes(),data_Y.tobytes())writer1.write(mk_em.SerializeToString())writer1.close()writer2=tf.python_io.TFRecordWriter(test_file)for data_X,data_Y in zip(test_X,test_Y):data_X=data_X.astype(np.float32)data_Y=data_Y.astype(np.float32)mk_em=examples(data_X.tobytes(),data_Y.tobytes())writer2.write(mk_em.SerializeToString())writer2.close()if __name__=='__main__':writer('train_data.tfrecord','test_data.tfrecord')??運(yùn)行完整文件之后,將可以看到在同級(jí)目錄下面會(huì)生成兩個(gè)tfRecord文件,一個(gè)名為’train_data.tfrecord’,該文件用于模型訓(xùn)練,另一個(gè)名為’test_data.tfrecord’,該文件作為測(cè)試集而存在。
??3.從tfRecord文件中讀取數(shù)據(jù)
??首先獲取隊(duì)列,并對(duì)隊(duì)列中的內(nèi)存協(xié)議塊進(jìn)行讀取和解碼,然后將轉(zhuǎn)換之后的數(shù)據(jù)組合成一個(gè)batch的數(shù)據(jù),傳入至模型。用于優(yōu)化模型參數(shù)。代碼塊如下所示:
??4.構(gòu)建DNN網(wǎng)絡(luò)模型
??這里構(gòu)建了三隱藏層、一輸出層、一BN層以及兩層dropout層的DNN網(wǎng)絡(luò)模型結(jié)構(gòu),非線性變換中使用sigmoid函數(shù)對(duì)數(shù)據(jù)進(jìn)行非線性變換的處理,具體代碼如下:
??5.分批次讀取數(shù)據(jù)并訓(xùn)練DNN模型
#train def train():checkpoint_dir='./model'save_time=500training_epochs=100000display_time=5# 讀取訓(xùn)練集 TFRecord文件Tensor對(duì)象train_X, train_Y = read_tfrecord('train_data.tfrecord')#構(gòu)建返回的訓(xùn)練器train,loss,correct_rate,global_step=model(train_X,train_Y)# 讀取測(cè)試集 TFRecord文件Tensor對(duì)象test_X,test_Y=read_tfrecord('test_data.tfrecord',batch_size=1000)_,loss_test,correct_rate_test,_=model(test_X,test_Y)saver=tf.train.Saver(max_to_keep=2)with tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) as sess:sess.run(tf.global_variables_initializer())ckpt=Noneif True:#加載模型繼續(xù)訓(xùn)練ckpt=tf.train.latest_checkpoint(checkpoint_dir)if ckpt:print("load model …………")saver.restore(sess,ckpt)#開啟線程coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)#訓(xùn)練for epoch in range(sess.run(global_step),training_epochs):_,loss_,correct_rate_=sess.run([train,loss,correct_rate])if (epoch+1)%display_time==0:print('step:{},loss:{},correct_rate:{}'.format(epoch+1,loss_,correct_rate_))loss_test_,correct_rate_test_=sess.run([loss_test,correct_rate_test])print('testing:loss:{},correct_rate:{}'.format(loss_test_,correct_rate_test_))sess.run(tf.assign(global_step, epoch + 1))if (epoch+1)%save_time==0:print('save model …………')saver.save(sess,'./model/model.ckpt',global_step=global_step)coord.request_stop()coord.join(threads)if __name__=='__main__':train()三、廣而告之
當(dāng)你在進(jìn)行數(shù)據(jù)統(tǒng)計(jì)分析,模型建立遇到困難的時(shí)候,那么請(qǐng)點(diǎn)開這個(gè)鏈接吧:
https://shop163287636.taobao.com/?spm=a230r.7195193.1997079397.2.b79b4e98VwGtpt
四、總結(jié)
??1.使用該方法是具有一定局限性的,因?yàn)槠鋬H可以順序從tfRecord文件中讀取,所以需要對(duì)樣本的數(shù)據(jù)量有一定的要求,這個(gè)要求當(dāng)然是數(shù)據(jù)量越大越好,不然模型很容易過擬合。
??2.該部分完整代碼見 https://download.csdn.net/download/qq_37972530/10887231
總結(jié)
以上是生活随笔為你收集整理的18-TFRecord 数据格式化存储工具(CDBmax 数据国度)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: matlab无法用mcc,使用matla
- 下一篇: 一月英语