Keras ImageDataGenerator用于数据扩充/增强的原理及方法
摘要
在這篇博客中,您將學習如何使用Keras的ImageDataGenerator類執行數據擴充/增強。另外將介紹什么是數據增強,數據增強的類型,為什么使用數據增強以及它能做什么/不能做什么。
有三種數據增強類型,默認情況下,Keras的ImageDataGenerator該類執行就地/即時數據擴充。
檢測到過度擬合的倆種解決方案是(1)減少模型容量或(2)執行正則化。
數據增強是正則化的一種形式,使我們的網絡可以更好地將其推廣到我們的測試/驗證集。
在訓練中不應用數據增強會導致過度擬合。應用數據增強,可以進行平滑的訓練,避免過度擬合以及擁有更高的準確性/更低的損失。
強烈建議在所有的訓練中都使用數據增強。
1. Keras ImageDataGenerator是什么
Keras的ImageDataGenerator在訓練卷積神經網絡中很常見,是對待訓練的數據集執行一系列隨機變換后進行訓練模型,提高模型的通用性,使得模型具有更好的泛化能力。
在修改后的擴充數據上訓練的模型更有可能概括為訓練集中未包含的示例數據點。
也可以通過一些簡單的幾何變換得到增強后的圖像,如平移,旋轉,放大/縮小,剪切,水平/垂直翻轉等;
對輸入圖像應用少量的轉換將稍微改變其外觀,但不會更改類標簽,從而使數據增強成為適用于計算機視覺任務的非常自然,簡便的方法。
2. Keras ImageDataGenerator工作原理
ImageDataGenerator接受原始數據,對其進行隨機轉換,并僅返回轉換后的新數據。
- 接受一批用于訓練的圖像;
- 進行此批處理并對批處理中的每個圖像應用一系列隨機變換(包括隨機旋轉,調整大小,剪切等);
- 用新的,隨機轉換的批次替換原始批次;
- 在此隨機轉換的批次上訓練CNN(即原始數據本身不用于訓練)。
3. Keras ImageDataGenerator的三種類型
- (1)通過數據增強生成數據集和數據擴展(較少見)
這種方法存在一個問題——尚未完全提高模型的泛化能力。
想象一下通過一張圖生成100張圖然后進行訓練;由于所有這些數據均基于超小型數據集。
我們不能期望在少量數據上訓練NN,然后期望將其推廣到從未訓練過且從未見過的數據。
- (2)就地/即時數據增強(最常見)
這種加強方式是使用最普遍的,有倆個地方需要注意:
- ImageDataGenerator 是不是原始數據和變換后的數據都返回——只返回隨機變換的數據。
- 因為這種擴充是在訓練時完成的,因此稱其為“就地”和“即時”數據擴充(即不會在訓練之前生成這些示例);
由于訓練的時候用的是經過隨機平移、旋轉、剪切等變換后的數據進行的,因此模型具有了比較好的泛化能力,其在測試集上表現良好,而在訓練集上將差一些,由于我們并沒有拿原始的訓練數據訓練,因此具有一定的偏差。
- (3)將數據集生成和就地擴充相結合
在訓練數據很少,并且真實的場景數據比較難以收集的情況下,可以用將類型2數據擴充(即就地/即時數據擴充)應用于通過模擬收集的數據。
類似于行為克隆,在自動駕駛應用中有運用。
4. 項目結構
5. 實現generate_images.py, train.py并訓練CNN
- generate_images.py 生成數據增強后的數據集
- train.py 并進行不同的數據增強后,進行模型訓練
(1)通用1張圖像生成100張訓練數據,并訓練CNN; 50%的準確率
(2)使用 Kaggle狗與貓的數據 集的一個子集,并在不進行數據擴充的情況下訓練CNN; 64%的準確率
(3)使用 Kaggle狗與貓的數據 集的一個子集,并在進行數據擴充的情況下訓練CNN; 69%的準確率
運用(1)生成的訓練精確度/損失圖
運用(2)生成的訓練精確度/損失圖
運用(3)生成的訓練精確度/損失圖【收斂的比較好,不會有精確度身高,損失也跟著升高的情況,可以完美的避開過度擬合,并且具有比較好的泛化能力】
得出結論:
- 數據增強可以減少過度擬合,并提高模型進行泛化的能力;
- 數據增強是一種正則化形式,保證驗證和訓練損失如何在幾乎沒有分歧的情況下下降。同樣,訓練和驗證拆分的分類準確性也一起提高;
- 通過使用數據增強,可以克服過度擬合!
# 測試三種數據增強類型后訓練的模型情況# 第一種試驗:通用1張圖像生成100張訓練數據,進行訓練 50%的準確率
# python train.py --dataset generated_dataset --plot plot_generated_dataset.png
# 探討數據擴充如何通過兩次實驗來減少過度擬合并提高模型進行泛化的能力,獲取到了 64%的準確率,檢測到過度擬合的倆種解決方案是(1)減少模型容量或(2)執行正則化。
# 第二種試驗:不使用數據擴充
# python train.py --dataset dogs_vs_cats_small --plot plot_dogs_vs_cats_no_aug.png
# 第三種試驗:運用數據擴充 研究數據增強如何充當正則化形式 69%的準確率 【注意驗證和訓練損失如何在幾乎沒有分歧的情況下下降。同樣,訓練和驗證拆分的分類準確性也一起提高。】
# 通過使用數據增強,我們可以克服過度擬合!
# 強烈建議在任何情況下訓練神經網絡時都使用 數據增強;
# python train.py --dataset dogs_vs_cats_small --augment 1 --plot plot_dogs_vs_cats_with_aug.png# 導入必要的包
# 設置matplot為Agg以保存模型訓練的plot圖到磁盤
import matplotlib
matplotlib.use("Agg")from pyimagesearch.resnet import ResNet
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# 導入ImageDataGenerator
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import cv2
import os# 構建命令行參數
# --dataset 數據集的路徑
# --augment 是否使用數據增強方式2(1.通過數據增強生成數據集和數據擴展(較少見) 2.就地/即時數據增強(最常見) 3.將數據集生成和就地擴充相結合 )
# --plot 保存 loss/accuracy 圖的路徑
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,help="path to input dataset")
ap.add_argument("-a", "--augment", type=int, default=-1,help="whether or not 'on the fly' data augmentation should be used")
ap.add_argument("-p", "--plot", type=str, default="plot.png",help="path to output loss/accuracy plot")
args = vars(ap.parse_args())# 初始化初始學習率,批處理大小batchsize,訓練的期數epochs
INIT_LR = 1e-1
BS = 8
EPOCHS = 50# 獲取數據集,并把數據,標簽按順序存儲在list中
print("[INFO] loading images...")
imagePaths = list(paths.list_images(args["dataset"]))
data = []
labels = []# 循環遍歷圖片路徑
for imagePath in imagePaths:# 從文件名中提取分類標簽名稱,加載圖片,忽略寬高比縮放為 64*64label = imagePath.split(os.path.sep)[-2]image = cv2.imread(imagePath)image = cv2.resize(image, (64, 64))# 更新數據、標簽listdata.append(image)labels.append(label)# 轉換數據、標簽list為Numpy array,并將數據的像素強度轉換為[0,255]
data = np.array(data, dtype="float") / 255.0# 編碼類標簽,由字符串轉為integer轉為 一鍵熱編碼數組(echc:[1,0]代表cats,[0,1]代表dogs)
le = LabelEncoder()
labels = le.fit_transform(labels)
labels = to_categorical(labels, 2)#分組數據為75%的訓練數據,25%的測試數據
(trainX, testX, trainY, testY) = train_test_split(data, labels,test_size=0.25, random_state=42)# 初始化數據擴充對象(初始化一個空對象)
aug = ImageDataGenerator()# 檢查是否需要進行數據擴充 --augment參數的值
if args["augment"] > 0:print("[INFO] performing 'on the fly' data augmentation")# 隨機旋轉,縮放,移動,剪切和翻轉。(random rotations, zooms, shifts, shears, and flips)aug = ImageDataGenerator(rotation_range=20,zoom_range=0.15,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.15,horizontal_flip=True,fill_mode="nearest")# 初始化優化器和模型
# 構建我們的ResNet,使用隨機梯度下降優化和學習率衰減的模型。我們使用“ binary_crossentropy” 2類問題的損失。如果您有兩個以上的類標簽,請確保使用“ categorial_crossentropy”
print("[INFO] compiling model...")
opt = SGD(lr=INIT_LR, momentum=0.9, decay=INIT_LR / EPOCHS)
model = ResNet.build(64, 64, 3, 2, (2, 3, 4),(32, 64, 128, 256), reg=0.0001)
model.compile(loss="binary_crossentropy", optimizer=opt,metrics=["accuracy"])# 訓練模型
# 對象分批處理數據擴充(僅當--augment命令行參數已設置,對象才會執行數據擴充)
print("[INFO] training network for {} epochs...".format(EPOCHS))
H = model.fit(x=aug.flow(trainX, trainY, batch_size=BS),validation_data=(testX, testY),steps_per_epoch=len(trainX) // BS,epochs=EPOCHS)# 評估模型
print("[INFO] evaluating network...")
predictions = model.predict(x=testX.astype("float32"), batch_size=BS)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=le.classes_))# 繪制訓練損失/精確度圖
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(args["plot"])
參考:
- https://www.pyimagesearch.com/2019/07/08/keras-imagedatagenerator-and-data-augmentation/
總結
以上是默认站点為你收集整理的Keras ImageDataGenerator用于数据扩充/增强的原理及方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Keras TensorFlow教程:使
- 下一篇: 使用Python,OpenCV构建透明的