日韩av黄I国产麻豆传媒I国产91av视频在线观看I日韩一区二区三区在线看I美女国产在线I麻豆视频国产在线观看I成人黄色短片

歡迎訪問 生活随笔!

生活随笔

當(dāng)前位置: 首頁 >

图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...

發(fā)布時(shí)間:2025/4/5 45 豆豆
生活随笔 收集整理的這篇文章主要介紹了 图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)... 小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.

版權(quán)聲明:本文為博主原創(chuàng)文章,歡迎轉(zhuǎn)載,并請注明出處。聯(lián)系方式:460356155@qq.com

全連接神經(jīng)網(wǎng)絡(luò)是深度學(xué)習(xí)的基礎(chǔ),理解它就可以掌握深度學(xué)習(xí)的核心概念:前向傳播、反向誤差傳遞、權(quán)重、學(xué)習(xí)率等。這里先用python創(chuàng)建模型,用minist作為數(shù)據(jù)集進(jìn)行訓(xùn)練。

定義3層神經(jīng)網(wǎng)絡(luò):輸入層節(jié)點(diǎn)28*28(對應(yīng)minist圖片像素?cái)?shù))、隱藏層節(jié)點(diǎn)300、輸出層節(jié)點(diǎn)10(對應(yīng)0-9個(gè)數(shù)字)。

網(wǎng)絡(luò)的激活函數(shù)采用sigmoid,網(wǎng)絡(luò)權(quán)重的初始化采用正態(tài)分布。

完整代碼如下:

1 #-*- coding:utf-8 -*-

2

3 u"""全連接神經(jīng)網(wǎng)絡(luò)訓(xùn)練學(xué)習(xí)MINIST"""

4

5 __author__ = 'zhengbiqing 460356155@qq.com'

6

7

8 importnumpy9 importscipy.special10 importscipy.misc11 from PIL importImage12 importmatplotlib.pyplot13 importpylab14 importdatetime15 from random importshuffle16

17

18 #是否訓(xùn)練網(wǎng)絡(luò)

19 LEARN =True20

21 #是否保存網(wǎng)絡(luò)

22 SAVE_PARA =False23

24 #網(wǎng)絡(luò)節(jié)點(diǎn)數(shù)

25 INPUT = 784

26 HIDDEN = 300

27 OUTPUT = 10

28

29 #學(xué)習(xí)率和訓(xùn)練次數(shù)

30 LR = 0.05

31 EPOCH = 10

32

33 #訓(xùn)練數(shù)據(jù)集文件

34 TRAIN_FILE = 'mnist_train.csv'

35 TEST_FILE = 'mnist_test.csv'

36

37 #網(wǎng)絡(luò)保存文件名

38 WEIGHT_IH = "minist_fc_wih.npy"

39 WEIGHT_HO = "minist_fc_who.npy"

40

41

42 #神經(jīng)網(wǎng)絡(luò)定義

43 classNeuralNetwork:44 def __init__(self, inport_nodes, hidden_nodes, output_nodes, learnning_rate):45 #神經(jīng)網(wǎng)絡(luò)輸入層、隱藏層、輸出層節(jié)點(diǎn)數(shù)

46 self.inodes =inport_nodes47 self.hnodes =hidden_nodes48 self.onodes =output_nodes49

50 #神經(jīng)網(wǎng)絡(luò)訓(xùn)練學(xué)習(xí)率

51 self.learnning_rate =learnning_rate52

53 #用均值為0,標(biāo)準(zhǔn)方差為連接數(shù)的-0.5次方的正態(tài)分布初始化權(quán)重

54 #權(quán)重矩陣行列分別為hidden * input、 output * hidden,和ih、ho相反

55 self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))56 self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))57

58 #sigmoid函數(shù)為激活函數(shù)

59 self.active_fun = lambdax: scipy.special.expit(x)60

61 #設(shè)置神經(jīng)網(wǎng)絡(luò)權(quán)重,在加載已訓(xùn)練的權(quán)重時(shí)調(diào)用

62 defset_weight(self, wih, who):63 self.wih =wih64 self.who =who65

66 #前向傳播,根據(jù)輸入得到輸出

67 defget_outputs(self, input_list):68 #把list轉(zhuǎn)換為N * 1的矩陣,ndmin=2二維,T轉(zhuǎn)制

69 inputs = numpy.array(input_list, ndmin=2).T70

71 #隱藏層輸入 = W dot X,矩陣乘法

72 hidden_inputs =numpy.dot(self.wih, inputs)73 hidden_outputs =self.active_fun(hidden_inputs)74

75 final_inputs =numpy.dot(self.who, hidden_outputs)76 final_outputs =self.active_fun(final_inputs)77

78 returninputs, hidden_outputs, final_outputs79

80 #網(wǎng)絡(luò)訓(xùn)練,誤差計(jì)算,誤差反向分配更新網(wǎng)絡(luò)權(quán)重

81 deftrain(self, input_list, target_list):82 inputs, hidden_outputs, final_outputs =self.get_outputs(input_list)83

84 targets = numpy.array(target_list, ndmin=2).T85

86 #誤差計(jì)算

87 output_errors = targets -final_outputs88 hidden_errors =numpy.dot(self.who.T, output_errors)89

90 #連接權(quán)重更新

91 self.who += numpy.dot(self.learnning_rate * output_errors * final_outputs * (1 -final_outputs), hidden_outputs.T)92 self.wih += numpy.dot(self.learnning_rate * hidden_errors * hidden_outputs * (1 -hidden_outputs), inputs.T)93

94

95 #圖像像素值變換

96 defvals2input(vals):97 #[0,255]的圖像像素值轉(zhuǎn)換為i[0.01,1],以便sigmoid函數(shù)作非線性變換

98 return (numpy.asfarray(vals) / 255.0 * 0.99) + 0.01

99

100

101 '''

102 訓(xùn)練網(wǎng)絡(luò)103 train:是否訓(xùn)練網(wǎng)絡(luò),如果不訓(xùn)練則直接加載已訓(xùn)練得到的網(wǎng)絡(luò)權(quán)重104 epoch:訓(xùn)練次數(shù)105 save:是否保存訓(xùn)練結(jié)果,即網(wǎng)絡(luò)權(quán)重106 '''

107 defnet_train(train, epochs, save):108 iftrain:109 with open(TRAIN_FILE, 'r') as train_file:110 train_list =train_file.readlines()111

112 for epoch inrange(epochs):113 #打亂訓(xùn)練數(shù)據(jù)

114 shuffle(train_list)115

116 for data intrain_list:117 all_vals = data.split(',')118 #圖像數(shù)據(jù)為0~255,轉(zhuǎn)換到0.01~1區(qū)間,以便激活函數(shù)更有效

119 inputs = vals2input(all_vals[1:])120

121 #標(biāo)簽,正確的為0.99,其他為0.01

122 targets = numpy.zeros(OUTPUT) + 0.01

123 targets[int(all_vals[0])] = 0.99

124

125 net.train(inputs, targets)126

127 #每個(gè)epoch結(jié)束后用測試集檢查識別準(zhǔn)確度

128 net_test(epoch)129 print('')130

131 ifsave:132 #保存連接權(quán)重

133 numpy.save(WEIGHT_IH, net.wih)134 numpy.save(WEIGHT_HO, net.who)135 else:136 #不訓(xùn)練直接加載已保存的權(quán)重

137 wih =numpy.load(WEIGHT_IH)138 who =numpy.load(WEIGHT_HO)139 net.set_weight(wih, who)140

141

142 '''

143 用測試集檢查準(zhǔn)確率144 '''

145 defnet_test(epoch):146 with open(TEST_FILE, 'r') as test_file:147 test_list =test_file.readlines()148

149 ok =0150 errlist = [0] * 10

151

152 for data intest_list:153 all_vals = data.split(',')154 inputs = vals2input(all_vals[1:])155 _, _, net_out =net.get_outputs(inputs)156

157 max =numpy.argmax(net_out)158 if max ==int(all_vals[0]):159 ok += 1

160 else:161 #識別錯(cuò)誤統(tǒng)計(jì),每個(gè)數(shù)字識別錯(cuò)誤計(jì)數(shù)

162 #print('target:', all_vals[0], 'net_out:', max)

163 errlist[int(all_vals[0])] += 1

164

165 print('EPOCH: {epoch} score: {score}'.format(epoch=epoch, score = ok / len(test_list) * 100))166 print('error list:', errlist, 'total:', sum(errlist))167

168

169 #變換圖片的尺寸,保存變換后的圖片

170 defresize_img(filein, fileout, width, height, type):171 img =Image.open(filein)172 out =img.resize((width, height), Image.ANTIALIAS)173 out.save(fileout, type)174

175

176 #用訓(xùn)練得到的網(wǎng)絡(luò)識別一個(gè)圖片文件

177 defimg_test(img_file):178 file_name_list = img_file.split('.')179 file_name, file_type = file_name_list[0], file_name_list[1]180 out_file = file_name + 'out' + '.' +file_type181 resize_img(img_file, out_file, 28, 28, file_type)182

183 img_array = scipy.misc.imread(out_file, flatten=True)184 img_data = 255.0 - img_array.reshape(784)185 img_data = (img_data / 255.0 * 0.99) + 0.01

186

187 _, _, net_out =net.get_outputs(img_data)188 max =numpy.argmax(net_out)189 print('pic recognized as:', max)190

191

192 #顯示數(shù)據(jù)集某個(gè)索引對應(yīng)的圖片

193 defimg_show(train, index):194 file = TRAIN_FILE if train elseTEST_FILE195 with open(file, 'r') as test_file:196 test_list =test_file.readlines()197

198 all_values = test_list[index].split(',')199 print('number is:', all_values[0])200

201 image_array = numpy.asfarray(all_values[1:]).reshape((28, 28))202 matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')203 pylab.show()204

205

206 start_time =datetime.datetime.now()207

208 net =NeuralNetwork(INPUT, HIDDEN, OUTPUT, LR)209 net_train(LEARN, EPOCH, SAVE_PARA)210

211 if notLEARN:212 net_test(0)213 else:214 print('MINIST FC Train:', INPUT, HIDDEN, OUTPUT, 'LR:', LR, 'EPOCH:', EPOCH)215 print('train spend time:', datetime.datetime.now() -start_time)216

217 #用畫圖軟件創(chuàng)建圖片文件,由得到的網(wǎng)絡(luò)進(jìn)行識別

218 #img_test('t9.png')

219

220 #顯示minist中的某個(gè)圖片

221 #img_show(True, 1)

784-300-10簡單的全連接神經(jīng)網(wǎng)絡(luò)訓(xùn)練結(jié)果準(zhǔn)確率基本在97.7%左右,運(yùn)行結(jié)果如下:

EPOCH: 0 score: 95.96000000000001

error list:? [13, 21, 31, 28, 51, 61, 33, 66, 44, 56]? total:? 404

EPOCH: 1 score: 96.77

error list:? [15, 19, 27, 63, 37, 37, 21, 40, 18, 46]? total:? 323

EPOCH: 2 score: 97.25

error list:? [9, 17, 26, 26, 24, 56, 21, 41, 22, 33]? total:? 275

EPOCH: 3 score: 97.82

error list:? [9, 16, 21, 18, 20, 18, 22, 21, 31, 42]? total:? 218

EPOCH: 4 score: 97.54

error list:? [12, 23, 17, 25, 15, 34, 19, 25, 22, 54]? total:? 246

EPOCH: 5 score: 97.78999999999999

error list:? [10, 16, 20, 23, 21, 32, 18, 31, 26, 24]? total:? 221

EPOCH: 6 score: 97.6

error list:? [9, 13, 26, 34, 27, 26, 20, 28, 22, 35]? total:? 240

EPOCH: 7 score: 97.74000000000001

error list:? [12, 8, 26, 29, 27, 26, 25, 20, 27, 26]? total:? 226

EPOCH: 8 score: 97.77

error list:? [7, 10, 27, 16, 29, 28, 23, 29, 26, 28]? total:? 223

EPOCH: 9 score: 97.99

error list:? [11, 10, 32, 17, 18, 24, 14, 22, 21, 32]? total:? 201

MINIST FC Train: 784 300 10 LR: 0.05 EPOCH: 10

train spend time:? 0:05:54.137925

Process finished with exit code 0

總結(jié)

以上是生活随笔為你收集整理的图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。