TensorFlow2 Quick Start: Image Loading and Preprocessing
Downloading the data
```python
import tensorflow as tf
import pathlib

# Download and extract the flower photos archive (cached under ~/.keras/datasets by default)
data_root_orig = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    fname='flower_photos',
    untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
# Output:
# C:\Users\Administrator\.keras\datasets\flower_photos
```

The downloaded files can be found under C:\Users\Administrator\.keras\datasets\flower_photos.
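The path printed above is not arbitrary. As a minimal sketch, assuming get_file's defaults (cache_dir=None, which resolves to ~/.keras, and cache_subdir='datasets') and that ~/.keras is writable, the extracted folder lands at the location computed below. default_keras_dataset_path is a hypothetical helper for illustration, not part of Keras.

```python
from pathlib import Path


def default_keras_dataset_path(fname: str) -> Path:
    # Assumes get_file's defaults: cache_dir -> ~/.keras, cache_subdir -> 'datasets'.
    # With untar=True, the extracted folder is named after fname.
    return Path.home() / ".keras" / "datasets" / fname


print(default_keras_dataset_path("flower_photos"))
```

On the Windows machine used in this article, the home directory is C:\Users\Administrator, which gives the path shown in the output.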
```python
# List the contents of the dataset directory
for item in data_root.iterdir():
    print(item)
# Output:
# C:\Users\Administrator\.keras\datasets\flower_photos\daisy
# C:\Users\Administrator\.keras\datasets\flower_photos\dandelion
# C:\Users\Administrator\.keras\datasets\flower_photos\LICENSE.txt
# C:\Users\Administrator\.keras\datasets\flower_photos\roses
# C:\Users\Administrator\.keras\datasets\flower_photos\sunflowers
# C:\Users\Administrator\.keras\datasets\flower_photos\tulips
```

The flower_photos folder contains five subfolders and a license file. Each subfolder holds the images of one class, so the five folders correspond to five different labels.
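Because the folder also contains LICENSE.txt, any code that turns subfolders into classes has to skip plain files. Below is a minimal pathlib sketch, run against a hypothetical miniature copy of the layout in a temporary directory; list_class_dirs is an illustrative name. Sorting the names gives alphanumeric order, which is the order image_dataset_from_directory uses for inferred class names when none are passed explicitly.

```python
import pathlib
import tempfile
from typing import List


def list_class_dirs(root: pathlib.Path) -> List[str]:
    # Keep only subdirectories (this skips LICENSE.txt) and sort them alphanumerically
    return sorted(p.name for p in root.iterdir() if p.is_dir())


# Build a hypothetical miniature copy of the flower_photos layout
with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    for name in ["tulips", "daisy", "roses", "dandelion", "sunflowers"]:
        (root / name).mkdir()
    (root / "LICENSE.txt").write_text("license")
    classes = list_class_dirs(root)

print(classes)  # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
```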
```python
import random

# Collect the paths of all images (files one level below each class folder)
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
# Shuffle all paths
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
print(image_count)
# Output: 3670

print(all_image_paths[:3])
# Output (order varies between runs because of the shuffle):
# ['C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\daisy\\11870378973_2ec1919f12.jpg',
#  'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\roses\\8442304572_2fdc9c7547_n.jpg',
#  'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\dandelion\\17574213074_f5416afd84.jpg']
```

Inspecting the images
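Shuffling the paths does not break the link between an image and its label, because the label is read from the path itself (its parent folder). A minimal sketch of that idea, using a short hypothetical list of paths in the same folder layout and mapping each class name to an integer index in sorted order:

```python
import pathlib
import random

# Hypothetical paths in the same layout as all_image_paths
paths = [
    "flower_photos/daisy/a.jpg",
    "flower_photos/roses/b.jpg",
    "flower_photos/daisy/c.jpg",
    "flower_photos/tulips/d.jpg",
]
random.seed(0)
random.shuffle(paths)

# Sorted class names, then a name -> integer index lookup
label_names = sorted({pathlib.Path(p).parent.name for p in paths})
label_to_index = {name: i for i, name in enumerate(label_names)}

# Each label is derived from its own path, so the shuffle cannot misalign them
labels = [label_to_index[pathlib.Path(p).parent.name] for p in paths]
print(list(zip(paths, labels)))
```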
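```python
from PIL import Image

# Open the images to display (only the 20 plotted below are needed)
train_images = []
for image in all_image_paths[:20]:
    train_images.append(Image.open(image))
```

The labels are read from the same list of paths, so images and labels are taken from the local files in the same order.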
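Image.open only reads the file lazily; it does not resize or normalize anything. A minimal sketch of doing that preprocessing by hand with PIL and NumPy, roughly what image_dataset_from_directory does for each file: convert to 3-channel RGB and resize to 192x192 with bilinear resampling (the Keras default interpolation is 'bilinear', but PIL's and TensorFlow's resizers are not guaranteed to produce identical pixels). load_and_resize and the grayscale test image are hypothetical, for illustration only.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def load_and_resize(path, size=(192, 192)):
    # Open the file, force 3 RGB channels, and resize to a fixed (width, height)
    with Image.open(path) as img:
        img = img.convert("RGB").resize(size, Image.BILINEAR)
    # Height x width x channels float32 array, values still in [0, 255]
    return np.asarray(img, dtype=np.float32)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gray.png")
    Image.new("L", (300, 200), color=128).save(path)  # hypothetical grayscale image
    arr = load_and_resize(path)

print(arr.shape, arr.dtype)  # (192, 192, 3) float32
```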
```python
import matplotlib.pyplot as plt

# The label of each image is the name of its parent folder
train_labels = [pathlib.Path(path).parent.name for path in all_image_paths]

plt.figure(figsize=(20, 10))
for i in range(20):
    plt.subplot(5, 10, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    plt.xlabel(train_labels[i])
plt.show()
```

Building a tf.data.Dataset
```python
# Read images from the class subfolders; keep the 80% training subset
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_root,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(192, 192),
    batch_size=20)
class_names = train_ds.class_names
print("\n", class_names)
print(train_ds)
# Output:
# Found 3670 files belonging to 5 classes.
# Using 2936 files for training.
#
#  ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# <BatchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
```

tf.keras.preprocessing.image_dataset_from_directory() creates a dataset that reads image data from a local directory, inferring each image's label from the subfolder it sits in. The resulting dataset object can be passed directly to fit(), or iterated over in a custom low-level training loop.
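The numbers in the output follow from the arguments. A minimal arithmetic sketch, assuming the validation subset receives int(validation_split * N) files and the training subset the rest (this reproduces the "Using 2936 files for training" line); split_counts is a hypothetical helper. The matching validation set would be created with subset="validation" and the same seed, so the two subsets do not overlap.

```python
import math


def split_counts(num_files, validation_split, batch_size):
    # Validation gets int(split * N) files; training gets the remainder
    num_val = int(validation_split * num_files)
    num_train = num_files - num_val
    # Batches per epoch; the final batch holds whatever is left over
    num_batches = math.ceil(num_train / batch_size)
    last_batch = num_train - (num_batches - 1) * batch_size
    return num_train, num_val, num_batches, last_batch


print(split_counts(3670, 0.2, 20))  # (2936, 734, 147, 16)
```

So one pass over train_ds yields 146 full batches of 20 images plus a final batch of 16, which is why the batch dimension in the dataset's shape is reported as None.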
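```python
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 10))
# take(1) yields the first element of train_ds, i.e. one batch of 20 images
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        # Pixels are float32 in [0, 255]; cast to uint8 so imshow displays them correctly
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()
```

- dataset.take(1): builds a dataset from the first element (the first element, not a random one). Because train_ds was created with batch_size=20, each element is a batch of 20 images, so take(1) here yields the first 20 images. Note that image_dataset_from_directory shuffles by default, so which 20 images make up that first batch can change from one iteration to the next.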
- dataset.skip(2): builds a dataset that skips the first 2 elements (for train_ds, the first 2 batches).
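For a batched dataset, take and skip count batches, not individual images. The sketch below is a plain-Python analogy of that behavior, not tf.data itself: a hypothetical batch helper groups 50 items into batches of 20, and itertools.islice plays the role of take(1) and skip(2).

```python
from itertools import islice


def batch(items, size):
    # Group a sequence into consecutive batches, like Dataset.batch()
    return [items[i:i + size] for i in range(0, len(items), size)]


batches = batch(list(range(50)), 20)   # 3 batches of sizes 20, 20, 10

first = list(islice(batches, 1))        # analogue of take(1): only the first batch
rest = list(islice(batches, 2, None))   # analogue of skip(2): everything after 2 batches

print([len(b) for b in first], [len(b) for b in rest])  # [20] [10]
```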