keras框架实现手写数字识别
生活随笔
收集整理的這篇文章主要介紹了
keras框架实现手写数字识别
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
詳細細節可學習從零開始神經網絡:keras框架實現數字圖像識別詳解!
代碼實現:
[1] ''' 將訓練數據和檢測數據加載到內存中(第一次運行需要下載數據,會比較慢): (mnist是手寫數據集) train_images是用于訓練系統的手寫數字圖片; train_labels是用于標注圖片的信息; test_images是用于檢測系統訓練效果的圖片; test_labels是test_images圖片對應的數字標簽。 ''' from tensorflow.keras.datasets import mnist (train_images, train_labels), (test_images, test_labels) = mnist.load_data() print('train_images.shape = ',train_images.shape) print('tran_labels = ', train_labels) print('test_images.shape = ', test_images.shape) print('test_labels', test_labels) ''' 1.train_images.shape打印結果表明,train_images是一個含有60000個元素的數組. 數組中的元素是一個二維數組,二維數組的行和列都是28. 也就是說,一個數字圖片的大小是28*28. 2.train_lables打印結果表明,第一張手寫數字圖片的內容是數字5,第二種圖片是數字0,以此類推. 3.test_images.shape的打印結果表示,用于檢驗訓練效果的圖片有10000張。 4.test_labels輸出結果表明,用于檢測的第一張圖片內容是數字7,第二張是數字2,依次類推。 '''[2] ''' 把用于測試的第一張圖片打印出來看看 ''' digit = test_images[0] import matplotlib.pyplot as plt plt.imshow(digit, cmap=plt.cm.binary) plt.show()[3] ''' 使用tensorflow.Keras搭建一個有效識別圖案的神經網絡, 1.layers:表示神經網絡中的一個數據處理層。(dense:全連接層) 2.models.Sequential():表示把每一個數據處理層串聯起來. 3.layers.Dense(…):構造一個數據處理層。 4.input_shape(28*28,):表示當前處理層接收的數據格式必須是長和寬都是28的二維數組, 后面的“,“表示數組里面的每一個元素到底包含多少個數字都沒有關系. ''' from tensorflow.keras import models from tensorflow.keras import layersnetwork = models.Sequential() #add往網絡里添加層 network.add(layers.Dense(512, activation='relu', input_shape=(28*28,))) #512是輸出結點個數 network.add(layers.Dense(10, activation='softmax')) #因為只能輸出0-9,是10種情況,所以輸出只能是10network.compile(optimizer='rmsprop', loss='categorical_crossentropy',metrics=['accuracy'])[4] ''' 在把數據輸入到網絡模型之前,把數據做歸一化處理: 1.reshape(60000, 28*28):train_images數組原來含有60000個元素,每個元素是一個28行,28列的二維數組, 現在把每個二維數組轉變為一個含有28*28個元素的一維數組. 2.由于數字圖案是一個灰度圖,圖片中每個像素點值的大小范圍在0到255之間. 3.train_images.astype(“float32”)/255 把每個像素點的值從范圍0-255轉變為范圍在0-1之間的浮點值。(歸一化) ''' train_images = train_images.reshape((60000, 28*28)) train_images = train_images.astype('float32') / 255test_images = test_images.reshape((10000, 28*28)) test_images = test_images.astype('float32') / 255''' 把圖片對應的標記也做一個更改: 目前所有圖片的數字圖案對應的是0到9。 例如test_images[0]對應的是數字7的手寫圖案,那么其對應的標記test_labels[0]的值就是7。 我們需要把數值7變成一個含有10個元素的數組,然后在第7個元素設置為1,其他元素設置為0。 例如test_lables[0] 的值由7轉變為數組[0,0,0,0,0,0,0,1,0,0,] ''' from tensorflow.keras.utils import to_categorical print("before change:" ,test_labels[0]) train_labels = to_categorical(train_labels) test_labels = to_categorical(test_labels) print("after change: ", test_labels[0])[5] ''' 把數據輸入網絡進行訓練: train_images:用于訓練的手寫數字圖片; train_labels:對應的是圖片的標記; batch_size:每次網絡從輸入的圖片數組中隨機選取128個作為一組進行計算。 epochs:每次計算的循環是五次 ''' network.fit(train_images, train_labels, epochs=5, batch_size = 128)[6] ''' 測試數據輸入,檢驗網絡學習后的圖片識別效果. 識別效果與硬件有關(CPU/GPU). ''' test_loss, test_acc = network.evaluate(test_images, test_labels, verbose=1) print(test_loss) print('test_acc', test_acc)[7] ''' 輸入一張手寫數字圖片到網絡中,看看它的識別效果 ''' (train_images, train_labels), (test_images, test_labels) = mnist.load_data() digit = test_images[1] plt.imshow(digit, cmap=plt.cm.binary) plt.show() test_images = test_images.reshape((10000, 28*28)) res = network.predict(test_images)for i in range(res[1].shape[0]):if (res[1][i] == 1):print("the number for the picture is : ", i)break實現結果:
總結
以上是生活随笔為你收集整理的keras框架实现手写数字识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 运行keras出现 FutureWarn
- 下一篇: Tensorflow入门神经网络代码框架