TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制
TensorFlow數(shù)據(jù)讀取機(jī)制:文件隊(duì)列 tf.train.slice_input_producer和tf.data.Dataset機(jī)制
? ? 之前寫了一篇博客,關(guān)于《Tensorflow生成自己的圖片數(shù)據(jù)集TFrecord》,項(xiàng)目做多了,你會(huì)發(fā)現(xiàn)將數(shù)據(jù)轉(zhuǎn)為TFrecord格式,實(shí)在是太麻煩了,靈活性太差!后面就總結(jié)一下TensorFlow數(shù)據(jù)讀取機(jī)制,主要還是介紹tf.data.Dataset的數(shù)據(jù)讀取機(jī)制(Pipeline機(jī)制)。
? ? TensorFlow數(shù)據(jù)讀取機(jī)制主要是兩種方法:
(1)一種是使用文件隊(duì)列方式,如使用slice_input_producer和string_input_producer;這種方法既可以將數(shù)據(jù)轉(zhuǎn)存為TFrecord數(shù)據(jù)格式,也可以直接讀取文件圖片數(shù)據(jù),當(dāng)然轉(zhuǎn)存為TFrecord數(shù)據(jù)格式進(jìn)行讀取,會(huì)更高效點(diǎn)
(2)另一種是TensorFlow 1.4版本后出現(xiàn)的tf.data.Dataset的數(shù)據(jù)讀取機(jī)制(Pipeline機(jī)制)。這是TensorFlow強(qiáng)烈推薦的方式,是一種更高效的讀取方式。使用tf.data.Dataset模塊的pipline機(jī)制,可實(shí)現(xiàn)CPU多線程處理輸入的數(shù)據(jù),如讀取圖片和圖片的一些的預(yù)處理,這樣GPU可以專注于訓(xùn)練過程,而CPU去準(zhǔn)備數(shù)據(jù)。
? ? ? 本博客Github源碼:https://github.com/PanJinquan/tensorflow-learning-tutorials?->tf_record_demo文件夾(覺得可以,還請(qǐng)給個(gè)“Star”哦
? ? ?之前專門寫了一篇博客關(guān)于《?Tensorflow生成自己的圖片數(shù)據(jù)集TFrecords(支持多標(biāo)簽label)》https://blog.csdn.net/guyuealian/article/details/80857228,主要實(shí)現(xiàn)的是使用自己的數(shù)據(jù)集制作TensorFlow的TFrecord數(shù)據(jù)格式。
目錄
目錄
TensorFlow數(shù)據(jù)讀取機(jī)制:文件隊(duì)列 tf.train.slice_input_producer和tf.data.Dataset機(jī)制
1. 文件隊(duì)列讀取方式:slice_input_producer和string_input_producer
1.1.生成圖片數(shù)據(jù)集TFrecords
(1)生成單個(gè)record文件 (單label)
(2)生成單個(gè)record文件 (多l(xiāng)abel)
(3)生成多個(gè)record文件的方法
1.2. 直接文件讀取方式?
2.tf.data.Dataset數(shù)據(jù)讀取機(jī)制:Pipeline機(jī)制
prefetch(必須放在最后)
map
repeat
實(shí)例代碼1:dataset.make_initializable_iterator()
實(shí)例代碼2:dataset.make_one_shot_iterator()
實(shí)例代碼3:產(chǎn)生用于訓(xùn)練的圖和label
實(shí)例代碼4:產(chǎn)生用于訓(xùn)練的原始圖和target目標(biāo)圖
實(shí)例代碼5: tf.data.Dataset.from_generator
3.?用Python循環(huán)產(chǎn)生批量數(shù)據(jù)batch
4.參考資料:
1. 文件隊(duì)列讀取方式:slice_input_producer和string_input_producer
? ? TensorFlow可以采用tf.train.slice_input_producer或者tf.train.string_input_producer兩種方法產(chǎn)生文件隊(duì)列,其區(qū)別就是:前者是輸入是tensor_list,因此,可以將多個(gè)list組合成一個(gè)tensorlist作為輸入;而后者只能是一個(gè)string_tensor了,例子如下:
image_dir ='path/to/image_dir/*.jpg'image_list = glob.glob(image_dir)label_list=...image_list = tf.convert_to_tensor(image_list, dtype=tf.string)# 可以將image_list,label_list多個(gè)list組合成一個(gè)tensor_listimage_que, label_que = tf.train.slice_input_producer([image_list,label_list], num_epochs=1)# 只能時(shí)string_tensor,所以不能組合多個(gè)listimage = tf.train.string_input_producer(image_list, num_epochs=1)1.1.生成圖片數(shù)據(jù)集TFrecords
? ? 假設(shè)train.txt保存圖片的路徑和標(biāo)簽信息,如下,以空格分割,第一項(xiàng)的圖片的路徑名,第二項(xiàng)是圖片對(duì)應(yīng)的labels
dog/1.jpg 0 dog/2.jpg 0 dog/3.jpg 0 dog/4.jpg 0 cat/1.jpg 1 cat/2.jpg 1 cat/3.jpg 1 cat/4.jpg 1? ? 這里提供三種方法將圖像數(shù)據(jù)轉(zhuǎn)存為TFrecords數(shù)據(jù)格式,當(dāng)然也包含TFrecords解析的方法,詳細(xì)的用法都會(huì)在函數(shù)參數(shù)說明,已經(jīng)封裝了很簡(jiǎn)單了,你只需要改變你圖片的路徑就可以。
- 生成單個(gè)record文件 (單label)
? ? 這種方法會(huì)將所有圖片數(shù)據(jù)和單labels轉(zhuǎn)存為一個(gè)record文件,合適單labels小批量的數(shù)據(jù)
- 生成單個(gè)record文件 (多l(xiāng)abel)
? ? 這種方法將所有圖片數(shù)據(jù)和多個(gè)labels轉(zhuǎn)存為一個(gè)record文件,合適多l(xiāng)abels的小批量的數(shù)據(jù)
- 生成多個(gè)record文件的方法
? ? 這種方法將圖片數(shù)據(jù)和labels,切分一個(gè)batch_size的大小,并轉(zhuǎn)存為多個(gè)record文件,合適大批量的數(shù)據(jù)
(1)生成單個(gè)record文件 (單label)
? ? ?下面是封裝好的py文件,可以直接生成單個(gè)record文件 ,當(dāng)然這里假設(shè)只有一個(gè)label情況。其中g(shù)et_batch_images函數(shù)會(huì)產(chǎn)生一個(gè)batch的數(shù)據(jù),這個(gè)batch的數(shù)據(jù)就可以用于CNN的網(wǎng)絡(luò)訓(xùn)練的數(shù)據(jù)。
# -*-coding: utf-8 -*- """@Project: create_tfrecord@File : create_tfrecord.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-07-27 17:19:54@desc : 將圖片數(shù)據(jù)保存為單個(gè)tfrecord文件 """##########################################################################import tensorflow as tf import numpy as np import os import cv2 import matplotlib.pyplot as plt import random from PIL import Image########################################################################## def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 生成實(shí)數(shù)型的屬性 def float_list_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=value))def get_example_nums(tf_records_filenames):'''統(tǒng)計(jì)tf_records圖像的個(gè)數(shù)(example)個(gè)數(shù):param tf_records_filenames: tf_records文件路徑:return:'''nums= 0for record in tf.python_io.tf_record_iterator(tf_records_filenames):nums += 1return numsdef show_image(title,image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def load_labels_file(filename,labels_num=1,shuffle=False):'''載圖txt文件,文件中每行為一個(gè)圖片信息,且以空格隔開:圖像路徑 標(biāo)簽1 標(biāo)簽2,如:test_image/1.jpg 0 2:param filename::param labels_num :labels個(gè)數(shù):param shuffle :是否打亂順序:return:images type->list:return:labels type->list'''images=[]labels=[]with open(filename) as f:lines_list=f.readlines()if shuffle:random.shuffle(lines_list)for lines in lines_list:line=lines.rstrip().split(' ')label=[]for i in range(labels_num):label.append(int(line[i+1]))images.append(line[0])labels.append(label)return images,labelsdef read_image(filename, resize_height, resize_width,normalization=False):'''讀取圖片數(shù)據(jù),默認(rèn)返回的是uint8,[0,255]:param filename::param resize_height::param resize_width::param normalization:是否歸一化到[0.,1.0]:return: 返回的圖片數(shù)據(jù)'''bgr_image = cv2.imread(filename)if len(bgr_image.shape)==2:#若是灰度圖則轉(zhuǎn)為三通道print("Warning:gray image",filename)bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉(zhuǎn)為RGB# show_image(filename,rgb_image)# rgb_image=Image.open(filename)if resize_height>0 and resize_width>0:rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))rgb_image=np.asanyarray(rgb_image)if normalization:# 不能寫成:rgb_image=rgb_image/255rgb_image=rgb_image/255.0# show_image("src resize image",image)return rgb_imagedef get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):''':param images:圖像:param labels:標(biāo)簽:param batch_size::param labels_nums:標(biāo)簽個(gè)數(shù):param one_hot:是否將labels轉(zhuǎn)為one_hot的形式:param shuffle:是否打亂順序,一般train時(shí)shuffle=True,驗(yàn)證時(shí)shuffle=False:return:返回batch的images和labels'''min_after_dequeue = 200capacity = min_after_dequeue + 3 * batch_size # 保證capacity必須大于min_after_dequeue參數(shù)值if shuffle:images_batch, labels_batch = tf.train.shuffle_batch([images,labels],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue,num_threads=num_threads)else:images_batch, labels_batch = tf.train.batch([images,labels],batch_size=batch_size,capacity=capacity,num_threads=num_threads)if one_hot:labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)return images_batch,labels_batchdef read_records(filename,resize_height, resize_width,type=None):'''解析record文件:源文件的圖像數(shù)據(jù)是RGB,uint8,[0,255],一般作為訓(xùn)練數(shù)據(jù)時(shí),需要?dú)w一化到[0,1]:param filename::param resize_height::param resize_width::param type:選擇圖像數(shù)據(jù)的返回類型None:默認(rèn)將uint8-[0,255]轉(zhuǎn)為float32-[0,255]normalization:歸一化float32-[0,1]standardization:標(biāo)準(zhǔn)化float32-[0,1],再減均值中心化:return:'''# 創(chuàng)建文件隊(duì)列,不限讀取的數(shù)量filename_queue = tf.train.string_input_producer([filename])# create a reader from file queuereader = tf.TFRecordReader()# reader從文件隊(duì)列中讀入一個(gè)序列化的樣本_, serialized_example = reader.read(filename_queue)# get feature from serialized example# 解析符號(hào)化的樣本features = tf.parse_single_example(serialized_example,features={'image_raw': tf.FixedLenFeature([], tf.string),'height': tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'depth': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)})tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數(shù)據(jù)tf_height = features['height']tf_width = features['width']tf_depth = features['depth']tf_label = tf.cast(features['label'], tf.int32)# PS:恢復(fù)原始圖像數(shù)據(jù),reshape的大小必須與保存之前的圖像shape一致,否則出錯(cuò)# tf_image=tf.reshape(tf_image, [-1]) # 轉(zhuǎn)換為行向量tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設(shè)置圖像的維度# 恢復(fù)數(shù)據(jù)后,才可以對(duì)圖像進(jìn)行resize_images:輸入uint->輸出float32# tf_image=tf.image.resize_images(tf_image,[224, 224])# [3]數(shù)據(jù)類型處理# 存儲(chǔ)的圖像類型為uint8,tensorflow訓(xùn)練時(shí)數(shù)據(jù)必須是tf.float32if type is None:tf_image = tf.cast(tf_image, tf.float32)elif type == 'normalization': # [1]若需要?dú)w一化請(qǐng)使用:# 僅當(dāng)輸入數(shù)據(jù)是uint8,才會(huì)歸一化[0,255]# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) # 歸一化elif type == 'standardization': # 標(biāo)準(zhǔn)化# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.per_image_standardization(tf_image) # 標(biāo)準(zhǔn)化(減均值除方差)# 若需要?dú)w一化,且中心化,假設(shè)均值為0.5,請(qǐng)使用:tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 # 中心化# 這里僅僅返回圖像和標(biāo)簽# return tf_image, tf_height,tf_width,tf_depth,tf_labelreturn tf_image,tf_labeldef create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):'''實(shí)現(xiàn)將圖像原始數(shù)據(jù),label,長,寬等信息保存為record文件注意:讀取的圖像數(shù)據(jù)默認(rèn)是uint8,再轉(zhuǎn)為tf的字符串型BytesList保存,解析請(qǐng)需要根據(jù)需要轉(zhuǎn)換類型:param image_dir:原始圖像的目錄:param file:輸入保存圖片信息的txt文件(image_dir+file構(gòu)成圖片的路徑):param output_record_dir:保存record文件的路徑:param resize_height::param resize_width:PS:當(dāng)resize_height或者resize_width=0是,不執(zhí)行resize:param shuffle:是否打亂順序:param log:log信息打印間隔'''# 加載文件,僅獲取一個(gè)labelimages_list, labels_list=load_labels_file(file,1,shuffle)writer = tf.python_io.TFRecordWriter(output_record_dir)for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):image_path=os.path.join(image_dir,images_list[i])if not os.path.exists(image_path):print('Err:no image',image_path)continueimage = read_image(image_path, resize_height, resize_width)image_raw = image.tostring()if i%log==0 or i==len(images_list)-1:print('------------processing:%d-th------------' % (i))print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))# 這里僅保存一個(gè)label,多l(xiāng)abel適當(dāng)增加"'label': _int64_feature(label)"項(xiàng)label=labels[0]example = tf.train.Example(features=tf.train.Features(feature={'image_raw': _bytes_feature(image_raw),'height': _int64_feature(image.shape[0]),'width': _int64_feature(image.shape[1]),'depth': _int64_feature(image.shape[2]),'label': _int64_feature(label)}))writer.write(example.SerializeToString())writer.close()def disp_records(record_file,resize_height, resize_width,show_nums=4):'''解析record文件,并顯示show_nums張圖片,主要用于驗(yàn)證生成record文件是否成功:param tfrecord_file: record文件路徑:return:'''# 讀取record函數(shù)tf_image, tf_label = read_records(record_file,resize_height,resize_width,type='normalization')# 顯示前4個(gè)圖片init_op = tf.initialize_all_variables()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(show_nums):image,label = sess.run([tf_image,tf_label]) # 在會(huì)話中取出image和label# image = tf_image.eval()# 直接從record解析的image是一個(gè)向量,需要reshape顯示# image = image.reshape([height,width,depth])print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))# pilimg = Image.fromarray(np.asarray(image_eval_reshape))# pilimg.show()show_image("image:%d"%(label),image)coord.request_stop()coord.join(threads)def batch_test(record_file,resize_height, resize_width):''':param record_file: record文件路徑:param resize_height::param resize_width::return::PS:image_batch, label_batch一般作為網(wǎng)絡(luò)的輸入'''# 讀取record函數(shù)tf_image,tf_label = read_records(record_file,resize_height,resize_width,type='normalization')image_batch, label_batch= get_batch_images(tf_image,tf_label,batch_size=4,labels_nums=5,one_hot=False,shuffle=False)init = tf.global_variables_initializer()with tf.Session() as sess: # 開始一個(gè)會(huì)話sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會(huì)話中取出images和labelsimages, labels = sess.run([image_batch, label_batch])# 這里僅顯示每個(gè)batch里第一張圖片show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == '__main__':# 參數(shù)設(shè)置resize_height = 224 # 指定存儲(chǔ)圖片高度resize_width = 224 # 指定存儲(chǔ)圖片寬度shuffle=Truelog=5# 產(chǎn)生train.record文件image_dir='dataset/train'train_labels = 'dataset/train.txt' # 圖片路徑train_record_output = 'dataset/record/train.tfrecords'create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)train_nums=get_example_nums(train_record_output)print("save train example nums={}".format(train_nums))# 產(chǎn)生val.record文件image_dir='dataset/val'val_labels = 'dataset/val.txt' # 圖片路徑val_record_output = 'dataset/record/val.tfrecords'create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)val_nums=get_example_nums(val_record_output)print("save val example nums={}".format(val_nums))# 測(cè)試顯示函數(shù)# disp_records(train_record_output,resize_height, resize_width)batch_test(train_record_output,resize_height, resize_width)(2)生成單個(gè)record文件 (多l(xiāng)abel)
? ? 對(duì)于多l(xiāng)abel的情況,你可以在單label的基礎(chǔ)上增加多個(gè)“l(fā)abel': tf.FixedLenFeature([], tf.int64)“,但每次label個(gè)數(shù)不一樣時(shí),都需要修改,挺麻煩的。這里提供一個(gè)方法:label數(shù)據(jù)也可以像圖像數(shù)據(jù)那樣,轉(zhuǎn)為string類型來保存:labels_raw = np.asanyarray(labels,dtype=np.float32).tostring() ,解析時(shí)也跟圖像數(shù)據(jù)一樣進(jìn)行解析:tf_label = tf.decode_raw(features['labels'],tf.float32) ,這樣,不管多少個(gè)label,我們都可以保存為record文件了:
? ?多l(xiāng)abel的TXT文件:
0.jpg 0.33 0.55 1.jpg 0.42 0.73 2.jpg 0.16 0.75 3.jpg 0.78 0.66 4.jpg 0.46 0.59 5.jpg 0.46 0.09 6.jpg 0.89 0.93 7.jpg 0.42 0.82 8.jpg 0.39 0.76 9.jpg 0.46 0.40 # -*-coding: utf-8 -*- """@Project: create_tfrecord@File : create_tf_record_multi_label.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-07-27 17:19:54@desc : 將圖片數(shù)據(jù),多l(xiāng)abel,保存為單個(gè)tfrecord文件 """##########################################################################import tensorflow as tf import numpy as np import os import cv2 import matplotlib.pyplot as plt import random from PIL import Image########################################################################## def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _float_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))# 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 生成實(shí)數(shù)型的屬性 def float_list_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=value))def get_example_nums(tf_records_filenames):'''統(tǒng)計(jì)tf_records圖像的個(gè)數(shù)(example)個(gè)數(shù):param tf_records_filenames: tf_records文件路徑:return:'''nums= 0for record in tf.python_io.tf_record_iterator(tf_records_filenames):nums += 1return numsdef show_image(title,image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def load_labels_file(filename,labels_num=1,shuffle=False):'''載圖txt文件,文件中每行為一個(gè)圖片信息,且以空格隔開:圖像路徑 標(biāo)簽1 標(biāo)簽2,如:test_image/1.jpg 0 2:param filename::param labels_num :labels個(gè)數(shù):param shuffle :是否打亂順序:return:images type->list:return:labels type->list'''images=[]labels=[]with open(filename) as f:lines_list=f.readlines()if shuffle:random.shuffle(lines_list)for lines in lines_list:line=lines.rstrip().split(' ')label=[]for i in range(labels_num):label.append(float(line[i+1]))images.append(line[0])labels.append(label)return images,labelsdef read_image(filename, resize_height, resize_width,normalization=False):'''讀取圖片數(shù)據(jù),默認(rèn)返回的是uint8,[0,255]:param filename::param resize_height::param resize_width::param normalization:是否歸一化到[0.,1.0]:return: 返回的圖片數(shù)據(jù)'''bgr_image = cv2.imread(filename)if len(bgr_image.shape)==2:#若是灰度圖則轉(zhuǎn)為三通道print("Warning:gray image",filename)bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉(zhuǎn)為RGB# show_image(filename,rgb_image)# rgb_image=Image.open(filename)if resize_height>0 and resize_width>0:rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))rgb_image=np.asanyarray(rgb_image)if normalization:# 不能寫成:rgb_image=rgb_image/255rgb_image=rgb_image/255.0# show_image("src resize image",image)return rgb_imagedef get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):''':param images:圖像:param labels:標(biāo)簽:param batch_size::param labels_nums:標(biāo)簽個(gè)數(shù):param one_hot:是否將labels轉(zhuǎn)為one_hot的形式:param shuffle:是否打亂順序,一般train時(shí)shuffle=True,驗(yàn)證時(shí)shuffle=False:return:返回batch的images和labels'''min_after_dequeue = 200capacity = min_after_dequeue + 3 * batch_size # 保證capacity必須大于min_after_dequeue參數(shù)值if shuffle:images_batch, labels_batch = tf.train.shuffle_batch([images,labels],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue,num_threads=num_threads)else:images_batch, labels_batch = tf.train.batch([images,labels],batch_size=batch_size,capacity=capacity,num_threads=num_threads)if one_hot:labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)return images_batch,labels_batchdef read_records(filename,resize_height, resize_width,type=None):'''解析record文件:源文件的圖像數(shù)據(jù)是RGB,uint8,[0,255],一般作為訓(xùn)練數(shù)據(jù)時(shí),需要?dú)w一化到[0,1]:param filename::param resize_height::param resize_width::param type:選擇圖像數(shù)據(jù)的返回類型None:默認(rèn)將uint8-[0,255]轉(zhuǎn)為float32-[0,255]normalization:歸一化float32-[0,1]standardization:歸一化float32-[0,1],再減均值中心化:return:'''# 創(chuàng)建文件隊(duì)列,不限讀取的數(shù)量filename_queue = tf.train.string_input_producer([filename])# create a reader from file queuereader = tf.TFRecordReader()# reader從文件隊(duì)列中讀入一個(gè)序列化的樣本_, serialized_example = reader.read(filename_queue)# get feature from serialized example# 解析符號(hào)化的樣本features = tf.parse_single_example(serialized_example,features={'image_raw': tf.FixedLenFeature([], tf.string),'height': tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'depth': tf.FixedLenFeature([], tf.int64),'labels': tf.FixedLenFeature([], tf.string)})tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數(shù)據(jù)tf_height = features['height']tf_width = features['width']tf_depth = features['depth']# tf_label = tf.cast(features['labels'], tf.float32)tf_label = tf.decode_raw(features['labels'],tf.float32)# PS:恢復(fù)原始圖像數(shù)據(jù),reshape的大小必須與保存之前的圖像shape一致,否則出錯(cuò)# tf_image=tf.reshape(tf_image, [-1]) # 轉(zhuǎn)換為行向量tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設(shè)置圖像的維度tf_label=tf.reshape(tf_label, [2]) # 設(shè)置圖像的維度# 恢復(fù)數(shù)據(jù)后,才可以對(duì)圖像進(jìn)行resize_images:輸入uint->輸出float32# tf_image=tf.image.resize_images(tf_image,[224, 224])# [3]數(shù)據(jù)類型處理# 存儲(chǔ)的圖像類型為uint8,tensorflow訓(xùn)練時(shí)數(shù)據(jù)必須是tf.float32if type is None:tf_image = tf.cast(tf_image, tf.float32)elif type == 'normalization': # [1]若需要?dú)w一化請(qǐng)使用:# 僅當(dāng)輸入數(shù)據(jù)是uint8,才會(huì)歸一化[0,255]# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) # 歸一化elif type == 'standardization': # 標(biāo)準(zhǔn)化# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.per_image_standardization(tf_image) # 標(biāo)準(zhǔn)化(減均值除方差)# 若需要?dú)w一化,且中心化,假設(shè)均值為0.5,請(qǐng)使用:tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 # 中心化# 這里僅僅返回圖像和標(biāo)簽# return tf_image, tf_height,tf_width,tf_depth,tf_labelreturn tf_image,tf_labeldef create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):'''實(shí)現(xiàn)將圖像原始數(shù)據(jù),label,長,寬等信息保存為record文件注意:讀取的圖像數(shù)據(jù)默認(rèn)是uint8,再轉(zhuǎn)為tf的字符串型BytesList保存,解析請(qǐng)需要根據(jù)需要轉(zhuǎn)換類型:param image_dir:原始圖像的目錄:param file:輸入保存圖片信息的txt文件(image_dir+file構(gòu)成圖片的路徑):param output_record_dir:保存record文件的路徑:param resize_height::param resize_width:PS:當(dāng)resize_height或者resize_width=0是,不執(zhí)行resize:param shuffle:是否打亂順序:param log:log信息打印間隔'''# 加載文件,僅獲取一個(gè)labellabels_num=2images_list, labels_list=load_labels_file(file,labels_num,shuffle)writer = tf.python_io.TFRecordWriter(output_record_dir)for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):image_path=os.path.join(image_dir,images_list[i])if not os.path.exists(image_path):print('Err:no image',image_path)continueimage = read_image(image_path, resize_height, resize_width)image_raw = image.tostring()if i%log==0 or i==len(images_list)-1:print('------------processing:%d-th------------' % (i))print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))# 這里僅保存一個(gè)label,多l(xiāng)abel適當(dāng)增加"'label': _int64_feature(label)"項(xiàng)# label=labels[0]# labels_raw="0.12,0,15"labels_raw = np.asanyarray(labels,dtype=np.float32).tostring()example = tf.train.Example(features=tf.train.Features(feature={'image_raw': _bytes_feature(image_raw),'height': _int64_feature(image.shape[0]),'width': _int64_feature(image.shape[1]),'depth': _int64_feature(image.shape[2]),'labels': _bytes_feature(labels_raw),}))writer.write(example.SerializeToString())writer.close()def disp_records(record_file,resize_height, resize_width,show_nums=4):'''解析record文件,并顯示show_nums張圖片,主要用于驗(yàn)證生成record文件是否成功:param tfrecord_file: record文件路徑:return:'''# 讀取record函數(shù)tf_image, tf_label = read_records(record_file,resize_height,resize_width,type='normalization')# 顯示前4個(gè)圖片init_op = tf.initialize_all_variables()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(show_nums):image,label = sess.run([tf_image,tf_label]) # 在會(huì)話中取出image和label# image = tf_image.eval()# 直接從record解析的image是一個(gè)向量,需要reshape顯示# image = image.reshape([height,width,depth])print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))# pilimg = Image.fromarray(np.asarray(image_eval_reshape))# pilimg.show()show_image("image:{}".format(label),image)coord.request_stop()coord.join(threads)def batch_test(record_file,resize_height, resize_width):''':param record_file: record文件路徑:param resize_height::param resize_width::return::PS:image_batch, label_batch一般作為網(wǎng)絡(luò)的輸入'''# 讀取record函數(shù)tf_image,tf_label = read_records(record_file,resize_height,resize_width,type='normalization')image_batch, label_batch= get_batch_images(tf_image,tf_label,batch_size=4,labels_nums=2,one_hot=False,shuffle=True)init = tf.global_variables_initializer()with tf.Session() as sess: # 開始一個(gè)會(huì)話sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會(huì)話中取出images和labelsimages, labels = sess.run([image_batch, label_batch])# 這里僅顯示每個(gè)batch里第一張圖片show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == '__main__':# 參數(shù)設(shè)置resize_height = 224 # 指定存儲(chǔ)圖片高度resize_width = 224 # 指定存儲(chǔ)圖片寬度shuffle=Truelog=1000# 產(chǎn)生train.record文件image_dir='dataset_regression/images'train_labels = 'dataset_regression/train.txt' # 圖片路徑train_record_output = 'dataset_regression/record/train.tfrecords'create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)train_nums=get_example_nums(train_record_output)print("save train example nums={}".format(train_nums))# 測(cè)試顯示函數(shù)# disp_records(train_record_output,resize_height, resize_width)# 產(chǎn)生val.record文件image_dir='dataset_regression/images'val_labels = 'dataset_regression/val.txt' # 圖片路徑val_record_output = 'dataset_regression/record/val.tfrecords'create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)val_nums=get_example_nums(val_record_output)print("save val example nums={}".format(val_nums))## # 測(cè)試顯示函數(shù)# # disp_records(train_record_output,resize_height, resize_width)# batch_test(val_record_output,resize_height, resize_width)(3)生成多個(gè)record文件的方法
? ? ? 上述該代碼只能保存為單個(gè)record文件,當(dāng)圖片數(shù)據(jù)很多時(shí)候,會(huì)導(dǎo)致單個(gè)record文件超級(jí)巨大的情況,解決方法就是,將數(shù)據(jù)分成多個(gè)record文件保存,讀取時(shí),只需要將多個(gè)record文件的路徑列表交給“tf.train.string_input_producer”??梢栽O(shè)置參數(shù)batchSize的大小,比如batchSize=2000,表示每2000張圖片保存為一個(gè)*.tfrecords,這樣可以避免單個(gè)record文件過大的情況。
? ? ? 完整代碼如下:
# -*-coding: utf-8 -*- """@Project: tf_record_demo@File : tf_record_batchSize.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-07-27 17:19:54@desc : 將圖片數(shù)據(jù)保存為多個(gè)record文件 """##########################################################################import tensorflow as tf import numpy as np import os import cv2 import math import matplotlib.pyplot as plt import random from PIL import Image########################################################################## def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 生成實(shí)數(shù)型的屬性 def float_list_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=value))def show_image(title,image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def load_labels_file(filename,labels_num=1):'''載圖txt文件,文件中每行為一個(gè)圖片信息,且以空格隔開:圖像路徑 標(biāo)簽1 標(biāo)簽2,如:test_image/1.jpg 0 2:param filename::param labels_num :labels個(gè)數(shù):return:images type->list:return:labels type->list'''images=[]labels=[]with open(filename) as f:for lines in f.readlines():line=lines.rstrip().split(' ')label=[]for i in range(labels_num):label.append(int(line[i+1]))images.append(line[0])labels.append(label)return images,labelsdef read_image(filename, resize_height, resize_width):'''讀取圖片數(shù)據(jù),默認(rèn)返回的是uint8,[0,255]:param filename::param resize_height::param resize_width::return: 返回的圖片數(shù)據(jù)是uint8,[0,255]'''bgr_image = cv2.imread(filename)if len(bgr_image.shape)==2:#若是灰度圖則轉(zhuǎn)為三通道print("Warning:gray image",filename)bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉(zhuǎn)為RGB# show_image(filename,rgb_image)# rgb_image=Image.open(filename)if resize_height>0 and resize_width>0:rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))rgb_image=np.asanyarray(rgb_image)# show_image("src resize image",image)return rgb_imagedef create_records(image_dir,file, record_txt_path, batchSize,resize_height, resize_width):'''實(shí)現(xiàn)將圖像原始數(shù)據(jù),label,長,寬等信息保存為record文件注意:讀取的圖像數(shù)據(jù)默認(rèn)是uint8,再轉(zhuǎn)為tf的字符串型BytesList保存,解析請(qǐng)需要根據(jù)需要轉(zhuǎn)換類型:param image_dir:原始圖像的目錄:param file:輸入保存圖片信息的txt文件(image_dir+file構(gòu)成圖片的路徑):param output_record_txt_dir:保存record文件的路徑:param batchSize: 每batchSize個(gè)圖片保存一個(gè)*.tfrecords,避免單個(gè)文件過大:param resize_height::param resize_width:PS:當(dāng)resize_height或者resize_width=0是,不執(zhí)行resize'''if os.path.exists(record_txt_path):os.remove(record_txt_path)setname, ext = record_txt_path.split('.')# 加載文件,僅獲取一個(gè)labelimages_list, labels_list=load_labels_file(file,1)sample_num = len(images_list)# 打亂樣本的數(shù)據(jù)# random.shuffle(labels_list)batchNum = int(math.ceil(1.0 * sample_num / batchSize))for i in range(batchNum):start = i * batchSizeend = min((i + 1) * batchSize, sample_num)batch_images = images_list[start:end]batch_labels = labels_list[start:end]# 逐個(gè)保存*.tfrecords文件filename = setname + '{0}.tfrecords'.format(i)print('save:%s' % (filename))writer = tf.python_io.TFRecordWriter(filename)for i, [image_name, labels] in enumerate(zip(batch_images, batch_labels)):image_path=os.path.join(image_dir,batch_images[i])if not os.path.exists(image_path):print('Err:no image',image_path)continueimage = read_image(image_path, resize_height, resize_width)image_raw = image.tostring()print('image_path=%s,shape:( %d, %d, %d)' % (image_path,image.shape[0], image.shape[1], image.shape[2]),'labels:',labels)# 這里僅保存一個(gè)label,多l(xiāng)abel適當(dāng)增加"'label': _int64_feature(label)"項(xiàng)label=labels[0]example = tf.train.Example(features=tf.train.Features(feature={'image_raw': _bytes_feature(image_raw),'height': _int64_feature(image.shape[0]),'width': _int64_feature(image.shape[1]),'depth': _int64_feature(image.shape[2]),'label': _int64_feature(label)}))writer.write(example.SerializeToString())writer.close()# 用txt保存*.tfrecords文件列表# record_list='{}.txt'.format(setname)with open(record_txt_path, 'a') as f:f.write(filename + '\n')def read_records(filename,resize_height, resize_width):'''解析record文件:param filename:保存*.tfrecords文件的txt文件路徑:return:'''# 讀取txt中所有*.tfrecords文件with open(filename, 'r') as f:lines = f.readlines()files_list=[]for line in lines:files_list.append(line.rstrip())# 創(chuàng)建文件隊(duì)列,不限讀取的數(shù)量filename_queue = tf.train.string_input_producer(files_list,shuffle=False)# create a reader from file queuereader = tf.TFRecordReader()# reader從文件隊(duì)列中讀入一個(gè)序列化的樣本_, serialized_example = reader.read(filename_queue)# get feature from serialized example# 解析符號(hào)化的樣本features = tf.parse_single_example(serialized_example,features={'image_raw': tf.FixedLenFeature([], tf.string),'height': tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'depth': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)})tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數(shù)據(jù)tf_height = features['height']tf_width = features['width']tf_depth = features['depth']tf_label = tf.cast(features['label'], tf.int32)# tf_image=tf.reshape(tf_image, [-1]) # 轉(zhuǎn)換為行向量tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設(shè)置圖像的維度# 存儲(chǔ)的圖像類型為uint8,這里需要將類型轉(zhuǎn)為tf.float32# tf_image = tf.cast(tf_image, tf.float32)# [1]若需要?dú)w一化請(qǐng)使用:tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)# 歸一化# tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) # 歸一化# [2]若需要?dú)w一化,且中心化,假設(shè)均值為0.5,請(qǐng)使用:# tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 #中心化return tf_image, tf_height,tf_width,tf_depth,tf_labeldef disp_records(record_file,resize_height, resize_width,show_nums=4):'''解析record文件,并顯示show_nums張圖片,主要用于驗(yàn)證生成record文件是否成功:param tfrecord_file: record文件路徑:param resize_height::param resize_width::param show_nums: 默認(rèn)顯示前四張照片:return:'''tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file,resize_height, resize_width) # 讀取函數(shù)# 顯示前show_nums個(gè)圖片init_op = tf.initialize_all_variables()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(show_nums):image,height,width,depth,label = sess.run([tf_image,tf_height,tf_width,tf_depth,tf_label]) # 在會(huì)話中取出image和label# image = tf_image.eval()# 直接從record解析的image是一個(gè)向量,需要reshape顯示# image = image.reshape([height,width,depth])print('shape:',image.shape,'label:',label)# pilimg = Image.fromarray(np.asarray(image_eval_reshape))# pilimg.show()show_image("image:%d"%(label),image)coord.request_stop()coord.join(threads)def batch_test(record_file,resize_height, resize_width):''':param record_file: record文件路徑:param resize_height::param resize_width::return::PS:image_batch, label_batch一般作為網(wǎng)絡(luò)的輸入'''tf_image,tf_height,tf_width,tf_depth,tf_label = read_records(record_file,resize_height, resize_width) # 讀取函數(shù)# 使用shuffle_batch可以隨機(jī)打亂輸入:# shuffle_batch用法:https://blog.csdn.net/u013555719/article/details/77679964min_after_dequeue = 100#該值越大,數(shù)據(jù)越亂,必須小于capacitybatch_size = 4# capacity = (min_after_dequeue + (num_threads + a small safety margin?batchsize)capacity = min_after_dequeue + 3 * batch_size#容量:一個(gè)整數(shù),隊(duì)列中的最大的元素?cái)?shù)image_batch, label_batch = tf.train.shuffle_batch([tf_image, tf_label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)init = tf.global_variables_initializer()with tf.Session() as sess: # 開始一個(gè)會(huì)話sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會(huì)話中取出images和labelsimages, labels = sess.run([image_batch, label_batch])# 這里僅顯示每個(gè)batch里第一張圖片show_image("image", images[0, :, :, :])print(images.shape, labels)# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == '__main__':# 參數(shù)設(shè)置image_dir='dataset/train'train_file = 'dataset/train.txt' # 圖片路徑output_record_txt = 'dataset/record/record.txt'#指定保存record的文件列表resize_height = 224 # 指定存儲(chǔ)圖片高度resize_width = 224 # 指定存儲(chǔ)圖片寬度batchSize=8000 #batchSize一般設(shè)置為8000,即每batchSize張照片保存為一個(gè)record文件# 產(chǎn)生record文件create_records(image_dir=image_dir,file=train_file,record_txt_path=output_record_txt,batchSize=batchSize,resize_height=resize_height,resize_width=resize_width)# 測(cè)試顯示函數(shù)disp_records(output_record_txt,resize_height, resize_width)# batch_test(output_record_txt,resize_height, resize_width)1.2. 直接文件讀取方式?
? ? 上面介紹的是如何將數(shù)據(jù)轉(zhuǎn)存為TFrecord文件,訓(xùn)練時(shí)再解析TFrecord。這種轉(zhuǎn)存為TFrecord數(shù)據(jù)格式的方法,雖然高效,但也喪失了靈活性,特別是新增數(shù)據(jù)或者刪除相關(guān)數(shù)據(jù)時(shí),這時(shí)就不得不重新制作TFrecord數(shù)據(jù)了。這就挺麻煩啦,如果不想轉(zhuǎn)為TFrecord文件,可以直接讀取圖像文件進(jìn)行訓(xùn)練。
? ? 這種方法比較簡(jiǎn)單,靈活性很強(qiáng),但效率很低,因?yàn)槊看蔚?xùn)練,GPU/CPU都要等待數(shù)據(jù)讀取I/O操作,圖像文件讀取以及預(yù)處理過程本身就很耗時(shí),甚至比你迭代一次網(wǎng)絡(luò)還耗時(shí)。解決的方法,就是采用tf.data.Dataset數(shù)據(jù)讀取機(jī)制。
? ? 直接文件讀取方式的完整代碼可以參考如下:
? ? 假設(shè)我們有train.txt的文件數(shù)據(jù)如下:
0.jpg 0
1.jpg 0
2.jpg 0
3.jpg 0
4.jpg 0
5.jpg 1
6.jpg 1
7.jpg 1
8.jpg 1
9.jpg 1
? ? 可以使用下面的方法直接讀取圖像數(shù)據(jù),并產(chǎn)生一個(gè)batch的訓(xùn)練數(shù)據(jù):
# -*-coding: utf-8 -*- """@Project: tf_record_demo@File : tf_read_files.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-10-14 10:44:06 """ import tensorflow as tf import glob import numpy as np import os import matplotlib.pyplot as pltimport cv2 def show_image(title, image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.imshow(image, cmap='gray')plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def tf_read_image(filename, resize_height, resize_width):'''讀取圖片:param filename::param resize_height::param resize_width::return:'''image_string = tf.read_file(filename)image_decoded = tf.image.decode_jpeg(image_string, channels=3)# tf_image = tf.cast(image_decoded, tf.float32)tf_image = tf.cast(image_decoded, tf.float32) * (1. / 255.0) # 歸一化if resize_width>0 and resize_height>0:tf_image = tf.image.resize_images(tf_image, [resize_height, resize_width])# tf_image = tf.image.per_image_standardization(tf_image) # 標(biāo)準(zhǔn)化[0,1](減均值除方差)return tf_imagedef get_batch_images(image_list, label_list, batch_size, labels_nums, resize_height, resize_width, one_hot=False, shuffle=False):''':param image_list:圖像:param label_list:標(biāo)簽:param batch_size::param labels_nums:標(biāo)簽個(gè)數(shù):param one_hot:是否將labels轉(zhuǎn)為one_hot的形式:param shuffle:是否打亂順序,一般train時(shí)shuffle=True,驗(yàn)證時(shí)shuffle=False:return:返回batch的images和labels'''# 生成隊(duì)列image_que, tf_label = tf.train.slice_input_producer([image_list, label_list], shuffle=shuffle)tf_image = tf_read_image(image_que, resize_height, resize_width)min_after_dequeue = 200capacity = min_after_dequeue + 3 * batch_size # 保證capacity必須大于min_after_dequeue參數(shù)值if shuffle:images_batch, labels_batch = tf.train.shuffle_batch([tf_image, tf_label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)else:images_batch, labels_batch = tf.train.batch([tf_image, tf_label],batch_size=batch_size,capacity=capacity)if one_hot:labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)return images_batch, labels_batchdef load_image_labels(filename):'''載圖txt文件,文件中每行為一個(gè)圖片信息,且以空格隔開:圖像路徑 標(biāo)簽1,如:test_image/1.jpg 0:param filename::return:'''images_list = []labels_list = []with open(filename) as f:lines = f.readlines()for line in lines:# rstrip:用來去除結(jié)尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、制表符、空格)content = line.rstrip().split(' ')name = content[0]labels = []for value in content[1:]:labels.append(int(value))images_list.append(name)labels_list.append(labels)return images_list, labels_listdef batch_test(filename, image_dir):labels_nums = 2batch_size = 4resize_height = 200resize_width = 200image_list, label_list = load_image_labels(filename)image_list=[os.path.join(image_dir,image_name) for image_name in image_list]image_batch, labels_batch = get_batch_images(image_list=image_list,label_list=label_list,batch_size=batch_size,labels_nums=labels_nums,resize_height=resize_height, resize_width=resize_width,one_hot=False, shuffle=True)with tf.Session() as sess: # 開始一個(gè)會(huì)話sess.run(tf.global_variables_initializer())coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會(huì)話中取出images和labelsimages, labels = sess.run([image_batch, labels_batch])# 這里僅顯示每個(gè)batch里第一張圖片show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == "__main__":image_dir = "./dataset/train"filename = "./dataset/train.txt"batch_test(filename, image_dir)2.tf.data.Dataset數(shù)據(jù)讀取機(jī)制:Pipeline機(jī)制
? ? 要執(zhí)行訓(xùn)練步驟,您必須首先提取并轉(zhuǎn)換訓(xùn)練數(shù)據(jù),然后將其提供給在加速器上運(yùn)行的模型。然而,在一個(gè)簡(jiǎn)單的同步執(zhí)行中,當(dāng) CPU 正在準(zhǔn)備數(shù)據(jù)時(shí),加速器則處于空閑狀態(tài)。相反,當(dāng)加速器正在訓(xùn)練模型時(shí),CPU 則處于空閑狀態(tài)。因此,訓(xùn)練步驟時(shí)間是 CPU 預(yù)處理時(shí)間和加速器訓(xùn)練時(shí)間的總和。
prefetch(必須放在最后)
? ? ?TensorFlow引入了tf.data.Dataset模塊,使其數(shù)據(jù)讀入的操作變得更為方便,而支持多線程(進(jìn)程)的操作,也在效率上獲得了一定程度的提高。使用tf.data.Dataset模塊的pipline機(jī)制,可實(shí)現(xiàn)CPU多線程處理輸入的數(shù)據(jù),如讀取圖片和圖片的一些的預(yù)處理,這樣GPU可以專注于訓(xùn)練過程,而CPU去準(zhǔn)備數(shù)據(jù)。
? ? 參考資料:
https://blog.csdn.net/u014061630/article/details/80776975
(五星推薦)TensorFlow全新的數(shù)據(jù)讀取方式:Dataset API入門教程:http://baijiahao.baidu.com/s?id=1583657817436843385&wfr=spider&for=pc
? ? Pipelining 將一個(gè)訓(xùn)練步驟的預(yù)處理和模型執(zhí)行重疊。當(dāng)加速器正在執(zhí)行訓(xùn)練步驟 N 時(shí),CPU 正在準(zhǔn)備步驟 N + 1 的數(shù)據(jù)。這樣做的目的是可以將步驟時(shí)間縮短到極致,包含訓(xùn)練以及提取和轉(zhuǎn)換數(shù)據(jù)所需時(shí)間(而不是總和)。
? ? 如果沒有使用 pipelining,則 CPU 和 GPU / TPU 在大部分時(shí)間處于閑置狀態(tài):
? ? 而使用 pipelining 技術(shù)后,空閑時(shí)間顯著減少:
? ? tf.data?API 通過 tf.data.Dataset.prefetch 轉(zhuǎn)換提供了一個(gè)軟件 pipelining 操作機(jī)制,該轉(zhuǎn)換可用于將數(shù)據(jù)生成的時(shí)間與所消耗時(shí)間分離。特別是,轉(zhuǎn)換使用后臺(tái)線程和內(nèi)部緩沖區(qū),以便在請(qǐng)求輸入數(shù)據(jù)集之前從輸入數(shù)據(jù)集中預(yù)提取元素。因此,為了實(shí)現(xiàn)上面說明的 pipelining 效果,您可以將 prefetch(1) 添加為數(shù)據(jù)集管道的最終轉(zhuǎn)換(如果單個(gè)訓(xùn)練步驟消耗 n 個(gè)元素,則添加 prefetch(n))。
? ? tf.data.Dataset.prefetch 提供了 software pipelining 機(jī)制。該函數(shù)解耦了 數(shù)據(jù)產(chǎn)生的時(shí)間 和 數(shù)據(jù)消耗的時(shí)間。具體來說,該函數(shù)有一個(gè)后臺(tái)線程和一個(gè)內(nèi)部緩存區(qū),在數(shù)據(jù)被請(qǐng)求前,就從 dataset 中預(yù)加載一些數(shù)據(jù)(進(jìn)一步提高性能)。prefech(n) 一般作為最后一個(gè) transformation,其中 n 為 batch_size。?prefetch 的使用方法如下:? ?
? ? 要將此更改應(yīng)用于我們的運(yùn)行示例,請(qǐng)將:
dataset = dataset.batch(batch_size=FLAGS.batch_size) return?dataset? ? 更改為:
dataset = dataset.batch(batch_size=FLAGS.batch_size) dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size) return?dataset? ? 請(qǐng)注意,在任何時(shí)候只要有機(jī)會(huì)將 “制造者” 的工作與 “消費(fèi)者” 的工作重疊,預(yù)取轉(zhuǎn)換就會(huì)產(chǎn)生效益。前面的建議只是最常見的應(yīng)用程序。
map
? ? 使用?tf.data.Dataset.map,我們可以很方便地對(duì)數(shù)據(jù)集中的各個(gè)元素進(jìn)行預(yù)處理。因?yàn)檩斎朐刂g時(shí)獨(dú)立的,所以可以在多個(gè) CPU 核心上并行地進(jìn)行預(yù)處理。map?變換提供了一個(gè)?num_parallel_calls參數(shù)去指定并行的級(jí)別。
? ? 準(zhǔn)備批處理時(shí),可能需要預(yù)處理輸入元素。為此,tf.data?API 提供了 tf.data.Dataset.map 轉(zhuǎn)換,它將用戶定義的函數(shù)(例如,運(yùn)行示例中的 parse_fn)應(yīng)用于輸入數(shù)據(jù)集的每個(gè)元素。由于輸入元素彼此獨(dú)立,因此可以跨多個(gè) CPU 內(nèi)核并行化預(yù)處理。為了實(shí)現(xiàn)這一點(diǎn),map 轉(zhuǎn)換提供了 thenum_parallel_calls 參數(shù)來指定并行度。例如,下圖說明了將 num_parallel_calls = 2 設(shè)置為 map 轉(zhuǎn)換的效果:
dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)repeat
? ? repeat的功能就是將整個(gè)序列重復(fù)多次,主要用來處理機(jī)器學(xué)習(xí)中的epoch,假設(shè)原先的數(shù)據(jù)是一個(gè)epoch,使用repeat(5)就可以將之變成5個(gè)epoch:
? ? 如果直接調(diào)用repeat()的話,生成的序列就會(huì)無限重復(fù)下去,沒有結(jié)束,因此也不會(huì)拋出tf.errors.OutOfRangeError異常
實(shí)例代碼1:dataset.make_initializable_iterator()
# -*-coding: utf-8 -*- """@Project: fine tuning@File : pipeline.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-11-17 20:18:54 """ import tensorflow as tf import numpy as np import glob import matplotlib.pyplot as pltwidth=0 height=0 def show_image(title, image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def tf_read_image(filename, label):image_string = tf.read_file(filename)image_decoded = tf.image.decode_jpeg(image_string, channels=3)image = tf.cast(image_decoded, tf.float32)if width>0 and height>0:image = tf.image.resize_images(image, [height, width])image = tf.cast(image, tf.float32) * (1. / 255.0) # 歸一化return image, labeldef input_fun(files_list, labels_list, batch_size, shuffle=True):''':param files_list::param labels_list::param batch_size::param shuffle::return:'''# 構(gòu)建數(shù)據(jù)集dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))if shuffle:dataset = dataset.shuffle(100)dataset = dataset.repeat() # 空為無限循環(huán)dataset = dataset.map(tf_read_image, num_parallel_calls=4) # num_parallel_calls一般設(shè)置為cpu內(nèi)核數(shù)量dataset = dataset.batch(batch_size)dataset = dataset.prefetch(2) # software pipelining 機(jī)制return datasetif __name__ == '__main__':data_dir = 'dataset/image/*.jpg'# labels_list = tf.constant([0,1,2,3,4])# labels_list = [1, 2, 3, 4, 5]files_list = glob.glob(data_dir)labels_list = np.arange(len(files_list))num_sample = len(files_list)batch_size = 1dataset = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)# 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界max_iterate = 3with tf.Session() as sess:iterator = dataset.make_initializable_iterator()init_op = iterator.make_initializer(dataset)sess.run(init_op)iterator = iterator.get_next()for i in range(max_iterate):images, labels = sess.run(iterator)show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))實(shí)例代碼2:dataset.make_one_shot_iterator()
? ? ?上面的迭代器是使用dataset.make_initializable_iterator(),當(dāng)然一個(gè)更簡(jiǎn)單的方法是使用dataset.make_one_shot_iterator(),下面的代碼,可把dataset.make_one_shot_iterator()放在input_fun函數(shù)中,直接返回一個(gè)迭代器iterator:
# -*-coding: utf-8 -*- """@Project: fine tuning@File : pipeline.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-11-17 20:18:54 """ import tensorflow as tf import numpy as np import glob import matplotlib.pyplot as pltwidth = 224 height = 224def show_image(title, image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def tf_read_image(filename, label):image_string = tf.read_file(filename)image_decoded = tf.image.decode_jpeg(image_string, channels=3)image = tf.cast(image_decoded, tf.float32)if width > 0 and height > 0:image = tf.image.resize_images(image, [height, width])image = tf.cast(image, tf.float32) * (1. / 255.0) # 歸一化return image, labeldef input_fun(files_list, labels_list, batch_size, shuffle=True):''':param files_list::param labels_list::param batch_size::param shuffle::return:'''# 構(gòu)建數(shù)據(jù)集dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))if shuffle:dataset = dataset.shuffle(100)dataset = dataset.repeat() # 空為無限循環(huán)dataset = dataset.map(tf_read_image, num_parallel_calls=4) # num_parallel_calls一般設(shè)置為cpu內(nèi)核數(shù)量dataset = dataset.batch(batch_size)dataset = dataset.prefetch(2) # software pipelining 機(jī)制iterator = dataset.make_one_shot_iterator()return iteratorif __name__ == '__main__':data_dir = './data/demo_data/*.jpg'# labels_list = tf.constant([0,1,2,3,4])# labels_list = [1, 2, 3, 4, 5]files_list = glob.glob(data_dir)labels_list = np.arange(len(files_list))num_sample = len(files_list)batch_size = 4iterator = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)# 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界max_iterate = 3with tf.Session() as sess:# iterator = dataset.make_initializable_iterator()# init_op = iterator.make_initializer(dataset)# sess.run(init_op)iterator = iterator.get_next()for i in range(max_iterate):images, labels = sess.run(iterator)show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))實(shí)例代碼3:產(chǎn)生用于訓(xùn)練的圖和label
假設(shè)train.txt的數(shù)據(jù)如下:
0_8354.jpg 8 3 5 4 1_3621.jpg 3 6 2 1 2_4326.jpg 4 3 2 6 3_7711.jpg 7 7 1 1 # -*-coding: utf-8 -*- """@Project: verification_code@File : dataset.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2019-03-03 18:45:13 """ import tensorflow as tf import numpy as np import glob import os import matplotlib.pyplot as plt from utils import file_processing,image_processingprint("TF Version:{}".format(tf.__version__))resize_height = 0 # 指定存儲(chǔ)圖片高度 resize_width = 0 # 指定存儲(chǔ)圖片寬度def load_image_labels(filename):'''載圖txt文件,文件中每行為一個(gè)圖片信息,且以空格隔開:圖像路徑 標(biāo)簽1 標(biāo)簽1,如:test_image/1.jpg 0 2:param filename::return:'''images_list=[]labels_list=[]with open(filename) as f:lines = f.readlines()for line in lines:#rstrip:用來去除結(jié)尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、制表符、空格)content=line.rstrip().split(' ')name=content[0]labels=[]for value in content[1:]:labels.append(int(value))images_list.append(name)labels_list.append(labels)return images_list,labels_listdef show_image(title, image):'''顯示圖片:param title: 圖像標(biāo)題:param image: 圖像的數(shù)據(jù):return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關(guān)掉坐標(biāo)軸為 offplt.title(title) # 圖像題目plt.show()def tf_resize_image(image, width=0, height=0):if (width is None) or (height is None): # 錯(cuò)誤寫法:resize_height and resize_width is Nonereturn imageimage = tf.image.resize_images(image, [height, width])return imagedef tf_read_image(file, width, height):image_string = tf.read_file(file)image_decoded = tf.image.decode_jpeg(image_string, channels=3)image = tf.cast(image_decoded, tf.float32)image=tf_resize_image(image, width, height)image = tf.cast(image, tf.float32) * (1. / 255.0) # 歸一化return imagedef map_read_image(files_list, labels_list):tf_image=tf_read_image(files_list,resize_width,resize_height)return tf_image,labels_listdef input_fun(files_list, labels_list, batch_size, shuffle=True):''':param orig_image::param dest_image::param batch_size::param num_epoch::param shuffle::return:'''# 構(gòu)建數(shù)據(jù)集dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))#TF version>=1.4# dataset = tf.contrib.data.Dataset.from_tensor_slices((files_list, labels_list))#TF version<1.4if shuffle:dataset = dataset.shuffle(100)dataset = dataset.repeat() # 空為無限循環(huán)# dataset = dataset.map(map_read_image, num_parallel_calls=4) # num_parallel_calls一般設(shè)置為cpu內(nèi)核數(shù)量dataset = dataset.map(map_read_image) # num_parallel_calls一般設(shè)置為cpu內(nèi)核數(shù)量dataset = dataset.batch(batch_size)dataset = dataset.prefetch(2) # software pipelining 機(jī)制dataset = dataset.make_one_shot_iterator()return datasetdef get_image_data(images_list, image_dir,labels_list, batch_size, re_height, re_width, shuffle=False):global resize_heightglobal resize_widthresize_height = re_height # 指定存儲(chǔ)圖片高度resize_width = re_width # 指定存儲(chǔ)圖片寬度image_list = [os.path.join(image_dir, name) for name in images_list]dataset = input_fun(image_list, labels_list, batch_size, shuffle)return datasetif __name__ == '__main__':filename='../dataset/train.txt'image_dir="E:/TensoFlow/verification_code/dataset/train"images_list, labels_list=load_image_labels(filename)batch_size = 4dataset=get_image_data(images_list, image_dir,labels_list, batch_size, re_height=None, re_width=None, shuffle=False)# 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界max_iterate = 3with tf.Session() as sess:# dataset = dataset.make_initializable_iterator()# init_op = dataset.make_initializer(dataset)# sess.run(init_op)dataset = dataset.get_next()for i in range(max_iterate):images, labels = sess.run(dataset)print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))show_image("image", images[0, :, :, :])實(shí)例代碼4:產(chǎn)生用于訓(xùn)練的原始圖和target目標(biāo)圖
# -*-coding: utf-8 -*- """@Project: triple_path_networks@File : load_data.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-11-29 11:40:37 """import tensorflow as tfimport glob import numpy as np import utils.image_processing as image_processing import os print("TF Version:{}".format(tf.__version__))resize_height = 0 # 指定存儲(chǔ)圖片高度 resize_width = 0 # 指定存儲(chǔ)圖片寬度def write_data(file, content_list, model):with open(file, mode=model) as f:for line in content_list:f.write(line + "\n")def read_data(file):with open(file, mode="r") as f:content_list = f.readlines()content_list = [content.rstrip() for content in content_list]return content_listdef read_train_val_data(filename,factor=0.8):image_list = read_data(filename)trian_num=int(len(image_list)*factor)train_list = image_list[:trian_num]val_list = image_list[trian_num:]print("data info***************************")print("--train nums:{}".format(len(train_list)))print("--val nums:{}".format(len(val_list)))print("************************************")return train_list,val_listdef tf_resize_image(image,width=0,height=0):if height>0 and width>0:image = tf.image.resize_images(image, [height, width])return imagedef tf_read_image(file,width=224,height=224):image_string = tf.read_file(file)image_decoded = tf.image.decode_jpeg(image_string, channels=3)image = tf.cast(image_decoded, tf.float32)if height>0 and width>0:image = tf.image.resize_images(image, [height, width])image = tf.cast(image, tf.float32) * (1. / 255.0) # 歸一化return imagedef map_read_image(orig_file, dest_file):orig_image=tf_read_image(orig_file,resize_width,resize_height)dest_image=tf_read_image(dest_file,resize_width,resize_height)return orig_image,dest_imagedef input_fun(orig_image, dest_image, batch_size, shuffle=True):''':param orig_image::param dest_image::param batch_size::param num_epoch::param shuffle::return:'''# 構(gòu)建數(shù)據(jù)集# dataset = tf.data.Dataset.from_tensor_slices((orig_image, dest_image))dataset = tf.contrib.data.Dataset.from_tensor_slices((orig_image, dest_image))if shuffle:dataset = dataset.shuffle(100)dataset = dataset.repeat() #空為無限循環(huán)# dataset = dataset.map(map_read_image, num_parallel_calls=4) # num_parallel_calls一般設(shè)置為cpu內(nèi)核數(shù)量dataset = dataset.map(map_read_image) # num_parallel_calls一般設(shè)置為cpu內(nèi)核數(shù)量dataset = dataset.batch(batch_size)# dataset = dataset.prefetch(2) # software pipelining 機(jī)制return datasetdef get_image_data(file_list,orig_dir,dest_dir,batch_size,re_height,re_width,shuffle=False):global resize_heightglobal resize_widthresize_height = re_height # 指定存儲(chǔ)圖片高度resize_width = re_width # 指定存儲(chǔ)圖片寬度orig_image_list=[os.path.join(orig_dir,name) for name in file_list]dest_image_list=[os.path.join(dest_dir,name) for name in file_list]dataset = input_fun(orig_image_list, dest_image_list, batch_size=batch_size,shuffle=shuffle)return datasetif __name__ == '__main__':orig_dir="../dataset/blackberry/blackberry"dest_dir="../dataset/blackberry/canon"filename="../dataset/blackberry/filelist.txt"batch_size = 1file_list=read_data(filename)dataset = get_image_data(file_list,orig_dir, dest_dir, batch_size=batch_size,shuffle=False)# 迭代次數(shù):max_iterate=10# 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界max_iterate = 5with tf.Session() as sess:iterator = dataset.make_initializable_iterator()init_op = iterator.make_initializer(dataset)sess.run(init_op)iterator = iterator.get_next()for i in range(max_iterate):orig_image, dest_image = sess.run(iterator)image_processing.show_image("orig_image", orig_image[0, :, :, :])image_processing.show_image("dest_image", dest_image[0, :, :, :])print('orig_image:{},dest_image:{}'.format(orig_image.shape, dest_image.shape))實(shí)例代碼5: tf.data.Dataset.from_generator
? ??tf.data.Dataset.from_tensor_slices并不支持輸入長度不同list,比如以下代碼
t = [[4,2], [3,4,5]] dataset = tf.data.Dataset.from_tensor_slices(t)? ? 將會(huì)報(bào)錯(cuò):
ValueError: Argument must be a dense tensor: [[4, 2], [3, 4, 5]] - got shape [2], but wanted [2, 2].? ? 一種決解的方法,采用? tf.data.Dataset.from_generator生成器:
import tensorflow as tf import numpy as np data1 = np.array([[1], [2, 3], [3, 4]]) data2 = np.array([[10], [20, 30], [30, 40]])def data_generator():for el, e2 in zip(data1, data2):yield el, e2dataset = tf.data.Dataset.from_generator(data_generator,output_types=(tf.int32, tf.int32),output_shapes=(None, None)) #或者output_shapes=(tf.TensorShape([None]), tf.TensorShape([None]))iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() max_iter = 3 with tf.Session() as sess:for i in range(max_iter):d1, d2 = sess.run(next_element)print("d1:{}".format(d1))print("d2:{}".format(d2))print("******************************")? 輸出:
d1:[1]
d2:[10]
******************************
d1:[2 3]
d2:[20 30]
******************************
d1:[3 4]
d2:[30 40]
******************************
參考資料:
https://stackoverflow.com/questions/47580716/how-to-input-a-list-of-lists-with-different-sizes-in-tf-data-dataset
https://blog.csdn.net/foreseerwang/article/details/80572182
https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369(五星推薦)
3.?用Python循環(huán)產(chǎn)生批量數(shù)據(jù)batch
? ? ?這部分請(qǐng)參考本人的博客《Python循環(huán)產(chǎn)生批量數(shù)據(jù)batch》??https://blog.csdn.net/guyuealian/article/details/83473298
? ? 上面提到的方法都是在TensorFlow提高API接口完成的,數(shù)據(jù)預(yù)處理也必須依賴TensorFlow的API接口。當(dāng)遇到一些特殊處理,而TensorFlow沒有相應(yīng)的接口時(shí),就比較尷尬。比如要對(duì)輸入的圖像進(jìn)行邊緣檢測(cè)處理時(shí),這時(shí)能想到就是用OpenCV的Canny算法,一種簡(jiǎn)單的方法就是,每次sess.run()獲得圖像數(shù)據(jù)后,再調(diào)用OpenCV的Canny算法……是的,有的麻煩!
? ? ?這里提供一個(gè)我自己設(shè)計(jì)方法,不依賴TensorFlow,靈活性很強(qiáng),你可以對(duì)數(shù)據(jù)進(jìn)行任意的操作,可以使用OpenCV,numpy等任意的庫函數(shù)。
? ?TXT文本如下,格式:圖片名 label1 label2 ,注意label可以多個(gè)
1.jpg 1 11 2.jpg 2 12 3.jpg 3 13 4.jpg 4 14 5.jpg 5 15 6.jpg 6 16 7.jpg 7 17 8.jpg 8 18? ? 要想產(chǎn)生batch數(shù)據(jù),關(guān)鍵是要用到Python的關(guān)鍵字yield,實(shí)現(xiàn)一個(gè)batch一個(gè)batch的返回?cái)?shù)據(jù),代碼實(shí)現(xiàn)主要有兩個(gè)方法:
def get_data_batch(inputs, batch_size=None, shuffle=False):'''循環(huán)產(chǎn)生批量數(shù)據(jù)batch:param inputs: list數(shù)據(jù):param batch_size: batch大小:param shuffle: 是否打亂inputs數(shù)據(jù):return: 返回一個(gè)batch數(shù)據(jù)''' def get_next_batch(batch):return batch.__next__()? ? 使用時(shí),將數(shù)據(jù)傳到?get_data_batch( )方法,然后使用get_next_batch( )獲得一個(gè)batch數(shù)據(jù),完整的Python代碼如下:
# -*-coding: utf-8 -*- """@Project: create_batch_data@File : create_batch_data.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2017-10-27 18:20:15 """ import math import random import os import glob import numpy as npdef get_data_batch(inputs, batch_size=None, shuffle=False):'''循環(huán)產(chǎn)生批量數(shù)據(jù)batch:param inputs: list類型數(shù)據(jù),多個(gè)list,請(qǐng)[list0,list1,...]:param batch_size: batch大小:param shuffle: 是否打亂inputs數(shù)據(jù):return: 返回一個(gè)batch數(shù)據(jù)'''rows = len(inputs[0])indices = list(range(rows))# 如果輸入是list,則需要轉(zhuǎn)為listif shuffle:random.seed(100)random.shuffle(indices)while True:batch_indices = np.asarray(indices[0:batch_size]) # 產(chǎn)生一個(gè)batch的indexindices = indices[batch_size:] + indices[:batch_size] # 循環(huán)移位,以便產(chǎn)生下一個(gè)batchbatch_data = []for data in inputs:data = np.asarray(data)temp_data=data[batch_indices] #使用下標(biāo)查找,必須是ndarray類型類型batch_data.append(temp_data.tolist())yield batch_datadef get_data_batch2(inputs, batch_size=None, shuffle=False):'''循環(huán)產(chǎn)生批量數(shù)據(jù)batch:param inputs: list類型數(shù)據(jù),多個(gè)list,請(qǐng)[list0,list1,...]:param batch_size: batch大小:param shuffle: 是否打亂inputs數(shù)據(jù):return: 返回一個(gè)batch數(shù)據(jù)'''# rows,cols=inputs.shaperows = len(inputs[0])indices = list(range(rows))if shuffle:random.seed(100)random.shuffle(indices)while True:batch_indices = indices[0:batch_size] # 產(chǎn)生一個(gè)batch的indexindices = indices[batch_size:] + indices[:batch_size] # 循環(huán)移位,以便產(chǎn)生下一個(gè)batchbatch_data = []for data in inputs:temp_data = find_list(batch_indices, data)batch_data.append(temp_data)yield batch_datadef find_list(indices, data):out = []for i in indices:out = out + [data[i]]return outdef get_list_batch(inputs, batch_size=None, shuffle=False):'''循環(huán)產(chǎn)生batch數(shù)據(jù):param inputs: list數(shù)據(jù):param batch_size: batch大小:param shuffle: 是否打亂inputs數(shù)據(jù):return: 返回一個(gè)batch數(shù)據(jù)'''if shuffle:random.shuffle(inputs)while True:batch_inouts = inputs[0:batch_size]inputs = inputs[batch_size:] + inputs[:batch_size] # 循環(huán)移位,以便產(chǎn)生下一個(gè)batchyield batch_inoutsdef load_file_list(text_dir):text_dir = os.path.join(text_dir, '*.txt')text_list = glob.glob(text_dir)return text_listdef get_next_batch(batch):return batch.__next__()def load_image_labels(finename):'''載圖txt文件,文件中每行為一個(gè)圖片信息,且以空格隔開:圖像路徑 標(biāo)簽1 標(biāo)簽1,如:test_image/1.jpg 0 2:param test_files::return:'''images_list = []labels_list = []with open(finename) as f:lines = f.readlines()for line in lines:# rstrip:用來去除結(jié)尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、制表符、空格)content = line.rstrip().split(' ')name = content[0]labels = []for value in content[1:]:labels.append(float(value))images_list.append(name)labels_list.append(labels)return images_list, labels_listif __name__ == '__main__':filename = './training_data/test.txt'images_list, labels_list = load_image_labels(filename)# 若輸入為np.arange數(shù)組,則需要tolist()為list類型,如:# images_list = np.reshape(np.arange(8*3), (8,3))# labels_list = np.reshape(np.arange(8*3), (8,3))# images_list=images_list.tolist()# labels_list=labels_list.tolist()iter = 5 # 迭代3次,每次輸出一個(gè)batch個(gè)# batch = get_data_batch([images_list, labels_list], batch_size=3, shuffle=False)batch = get_data_batch2(inputs=[images_list,labels_list], batch_size=5, shuffle=True)for i in range(iter):print('**************************')batch_images, batch_labels = get_next_batch(batch)print('batch_images:{}'.format(batch_images))print('batch_labels:{}'.format(batch_labels))? ?運(yùn)行輸出結(jié)果為:
**************************
batch_images:['1.jpg', '2.jpg', '3.jpg']
batch_labels:[[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
**************************
batch_images:['4.jpg', '5.jpg', '6.jpg']
batch_labels:[[4.0, 14.0], [5.0, 15.0], [6.0, 16.0]]
**************************
batch_images:['7.jpg', '8.jpg', '1.jpg']
batch_labels:[[7.0, 17.0], [8.0, 18.0], [1.0, 11.0]]
**************************
batch_images:['2.jpg', '3.jpg', '4.jpg']
batch_labels:[[2.0, 12.0], [3.0, 13.0], [4.0, 14.0]]
**************************
batch_images:['5.jpg', '6.jpg', '7.jpg']
batch_labels:[[5.0, 15.0], [6.0, 16.0], [7.0, 17.0]]
Process finished with exit code 0
?
4.參考資料:
[1]https://blog.csdn.net/happyhorizion/article/details/77894055? (五星推薦)
[2]https://blog.csdn.net/ywx1832990/article/details/78462582
[3]https://blog.csdn.net/csuzhaoqinghui/article/details/51377941
[4]《tf.data API,構(gòu)建高性能 TensorFlow 輸入管道》
?
?
總結(jié)
以上是生活随笔為你收集整理的TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: OpenCV图像缩放resize各种插值
- 下一篇: Dilated/Atrous conv