TensorFlow(2)-训练数据载入
tensorflow 訓練數據載入
- 1. tf.data.Dataset
- 2. dataset 創建數據集的方式
- 2.1 tf.data.Dataset.from_tensor_slices()
- 2.2 tf.data.TextLineDataset()
- 2.3 tf.data.FixedLengthRecordDataset()
- 2.4 tf.data.TFRecordDataset()
- 3. dateset 迭代操作iterator
- 3.1 make_one_shot_iterator()
- 3.2 make_initializable_iterator()
- 3.3 reinitializable iterator()
- 3.4 feedable iterator()
- 4. dataset的map、batch、shuffle、repeat操作
- 5. 非eager/eager 模式
- 5.1 非eager模式demo
- 5.2 eager模式demo
1. tf.data.Dataset
參考Google官方給出的Dataset API中的類圖,Dataset 務于數據讀取,構建輸入數據的pipeline。
Dataset可以看作是相同類型“元素”的有序列表,可使用Iterator迭代獲取Dataset中的元素。
2. dataset 創建數據集的方式
2.1 tf.data.Dataset.from_tensor_slices()
從tensor中創建數據集,數據集元素以tensor第一維度為劃分。
import tensorflow as tf import numpy as np # 切分傳入Tensor的第一個維度,生成相應的dataset。 dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) # 如果傳入字典,那切分結果就是字典按值切分,元素型如{"a":[1],"b":[x,x]} dataset2 = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), "b": np.random.uniform(size=(5, 2))} )2.2 tf.data.TextLineDataset()
讀取文件數據創建數據集,數據集元素為文件的每一行
2.3 tf.data.FixedLengthRecordDataset()
從一個文件列表和record_bytes中創建數據集,數據集元素是文件中固定字節數record_bytes的內容。
2.4 tf.data.TFRecordDataset()
讀TFRecord文件創建數據集,數據集中的一條數據是一個TFExample。
dataset = tf.data.TFRecordDataset(filenames = [tfrecord_file_name]) # [tfrecord_file_name] tfrecord 文件列表
frecord 文件中的特征一般都經過tf.train.Example 序列化,在使用前需要先解碼tf.train.Example.FromString()
raw_example = next(iter(dataset)) parsed = tf.train.Example.FromString(raw_example.numpy())3. dateset 迭代操作iterator
iterator是從Dataset對象中創建出來的,用于迭代取數據集中的元素。
3.1 make_one_shot_iterator()
dataset.make_one_shot_iterator()–只能從頭到尾讀取一次dataset。如果一個dataset中元素被讀取完了再sess.run()的話,會拋出tf.errors.OutOfRangeError異常。因此可以在外界捕捉這個異常以判斷數據是否讀取完。
import tensorflow as tf import numpy as np # 切分傳入Tensor的第一個維度,生成相應的dataset。如果傳入字典,那切分結果就是字典按值切分 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) iterator = dataset.make_one_shot_iterator() # 只能從頭到尾讀取一次 one_element = iterator.get_next() # 從iterator里取出一個元素。 # 處于非Eager模式,所以one_element只是一個Tensor,并不是一個實際的值。調用sess.run(one_element)后,才能真正地取出一個值。 with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")3.2 make_initializable_iterator()
dataset.make_initializable_iterator()–支持placeholder dataset 的迭代操作,這可以方便通過參數快速定義新的Iterator。
# limit相當于一個參數,它規定了Dataset中數的上限, 使用make_initializable_iterator limit = tf.placeholder(dtype=tf.int32, shape=[]) dataset = tf.data.Dataset.from_tensor_slices(tf.range(start=0, limit=limit)) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() with tf.Session() as sess:sess.run(iterator.initializer, feed_dict={limit: 10})for i in range(10):value = sess.run(next_element)assert i == valuesess.run(next_element) 每run一次, 數據迭代器指針就會往下移動一個。TF官網學習(9)–使用iterator注意事項
如果在dataset的構建時,一次性讀入了所有的數據,會導致計算圖變得很大,給傳輸、保存帶來不便。make_initializable_iterator()支持placeholder 操作,僅在需要傳輸數據時再取數據。
# 從硬盤中讀入兩個Numpy數組 with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]features_placeholder = tf.placeholder(features.dtype, features.shape) labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder)) iterator = dataset.make_initializable_iterator() sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})3.3 reinitializable iterator()
dataset.reinitializable iterator() --待補
3.4 feedable iterator()
dataset.feedable iterator()–待補
4. dataset的map、batch、shuffle、repeat操作
map–接收一個函數,Dataset中的每個元素都會被當作這個函數的輸入,并將函數返回值作為新的Dataset。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0batch–將多個元素組合成一個batch
dataset = dataset.batch(16) # 將數據集劃分為batch size為16的小批次shuffle– 打亂dataset中的元素,參數buffersize。打亂的實現機理:從buffer_size 大小的部buffer中隨機抽取元素,組成打亂后的數據集。buffer中被抽走的元素由原數據集中的后續元素補位置。 重復‘抽取-補充’這個過程,直至buffer為空。
會在batch之間打亂數據–疑問多tfrecord 文件是一次性構建數據集還是一條一條的構建
buffer_size 的大小詳見tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解
dataset = dataset.shuffle(buffer_size=10000)repeat– 將整個序列重復多次,用來處理機器學習中的epoch,假設原始數據是一個epoch,使用repeat(5)就可以將之變成5個epoch
dataset = dataset.repeat(5)5. 非eager/eager 模式
5.1 非eager模式demo
在非Eager模式下,Dataset中讀出的一個元素一般對應一個batch的Tensor,我們可以使用這個Tensor在計算圖中構建模型。
import tensorflow as tf import numpy as np # 切分傳入Tensor的第一個維度,生成相應的dataset。如果傳入字典,那切分結果就是字典按值切分 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) iterator = dataset.make_one_shot_iterator() # 只能從頭到尾讀取一次 one_element = iterator.get_next() # 從iterator里取出一個元素。 # 處于非Eager模式,所以one_element只是一個Tensor,并不是一個實際的值。調用sess.run(one_element)后,才能真正地取出一個值。 with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")5.2 eager模式demo
在Eager模式下,Dataset建立Iterator的方式有所不同,此時通過讀出的數據就是含有值的Tensor,方便調試。
import tensorflow.contrib.eager as tfe tfe.enable_eager_execution() dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) for one_element in tfe.Iterator(dataset):print(one_element) # 可直接讀取數據參考文獻:TensorFlow全新的數據讀取方式:Dataset API入門教程
總結
以上是生活随笔為你收集整理的TensorFlow(2)-训练数据载入的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ThinkPHP redirect 页面
- 下一篇: 在GCC和Visual Studio中使