深度学习之基于Xception实现四种动物识别
本次實驗類似于貓狗大戰,只不過將兩種動物識別變為了四種動物識別。
本文的重點是卷積神經網絡Xception的實踐,在之前的學習中,我們已經實驗過其他幾種比較常用的網絡模型,但是Xception網絡并未實踐過。在弄本科畢設的時候,一個好朋友的畢設就是基于Xception實現海洋垃圾的識別,最終的實驗效果達到了99%左右,由此可見Xception的模型性能還是不錯的。
本次實驗基于Xception實現動物識別,最終的模型準確率在95%左右。
1.導入庫
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt import os,pathlib,PIL# 支持中文 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用來正常顯示中文標簽 plt.rcParams['axes.unicode_minus'] = False # 用來正常顯示負號2.數據加載
data_dir = "E:/tmp/.keras/datasets/animal_photos" data_dir = pathlib.Path(data_dir) img_count = len(list(data_dir.glob('*/*')))共4000張圖片
all_images_paths = list(data_dir.glob('*')) all_images_paths = [str(path) for path in all_images_paths] all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths] 分為四類: ['cat', 'chook', 'dog', 'horse']超參數的設置:
height = 224 width = 224 epochs =10 batch_size = 128圖像增強:
一共分為4類,每一類有1000張圖片,數據并不是很多,因此對原數據進行數據加強。并按照8:2的比例劃分訓練集與測試集。
顯示圖像:
plt.figure(figsize=(15, 10)) # 圖形的寬為15高為10for images, labels in train_ds:for i in range(8):ax = plt.subplot(5, 8, i + 1)plt.imshow(images[i])plt.title(all_label_names[np.argmax(labels[i])])plt.axis("off")break plt.show()3.Xception模型
Xception是Inception的改進版本,創新點便是 深度可分離卷積。
深度可分離卷積 = 深度卷積+逐點卷積。具體步驟如下所示:
第一步:Depthwise 卷積,對輸入的每個channel,分別進行 3 × 3 卷積操作,并將結果 concat。
第二步:Pointwise 卷積,對 Depthwise 卷積中的 concat 結果,進行 1 × 1 卷積操作。
標準卷積與深度可分離卷積的對比如下所示:圖片來源
既然最終的結果是一樣的,那為什么深度可分離卷積方式更優呢?
原因就是利用深度可分離卷積,參數更少,從而在迭代更新的過程中,計算量就更小。
本次實驗利用遷移學習采用官方模型進行訓練
base_model = tf.keras.applications.Xception(weights = 'imagenet',include_top = False,pooling = 'max',input_shape = (height,width,3)) base_model.trainable = False#前面的參數設置為不可訓練 input = base_model.input x = tf.keras.layers.Dense(256,activation='relu')(base_model.output) x = tf.keras.layers.Dense(128,activation='relu')(x) output = tf.keras.layers.Dense(4,activation='sigmoid')(x) model = tf.keras.models.Model(inputs = input,outputs = output)優化器的設置:
# 設置初始學習率 initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=300,decay_rate=0.96,staircase=True)# 將指數衰減學習率送入優化器 optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)網絡編譯&&訓練
model.compile(optimizer = optimizer,loss = "categorical_crossentropy",metrics = ['accuracy'] )history = model.fit(train_ds,validation_data = test_ds,epochs = epochs )Accuracy與Loss圖如下所示:
模型準確率比較高,在95%左右。
4.預測&&混淆矩陣
模型保存:
model.save("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")模型加載:
model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")預測:
plt.figure(figsize=(50,50)) num = 0 for images,labels in test_ds:for i in range(64):ax = plt.subplot(8,8,i+1)plt.imshow(images[i])img_array = tf.expand_dims(images[i],0)pre = model.predict(img_array)if np.argmax(pre) == np.argmax(labels[i]):plt.title(all_label_names[np.argmax(pre)])else:plt.title("False :"+str(all_label_names[np.argmax(pre)]))if np.argmax(pre) == np.argmax(labels[i]):num += 1plt.axis("off")break plt.suptitle("The Acc rating is:{}".format(num / 64)) plt.show()
混淆矩陣的繪制:
努力加油a啊
總結
以上是生活随笔為你收集整理的深度学习之基于Xception实现四种动物识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 七个合法学习黑客技术的网站,让你从萌新成
- 下一篇: 梳理百年深度学习发展史-七月在线机器学习