【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图
生活随笔
收集整理的這篇文章主要介紹了
【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
main函數(shù)載入模型,加載圖片,輸出結(jié)果:
if __name__ == '__main__':image = Image.open(r"C:\Users\pic\test\he_5.jpg")image =transform(image).unsqueeze(0)modelme = torch.load('modefresnet.pkl')modelme.eval() #表示將模型轉(zhuǎn)變?yōu)閑valuation(測試)模式,這樣就可以排除BN和Dropout對測試的干擾。visualize_model(modelme)outputs = modelme(image)_, predict = torch.max(outputs.data, 1)for j in range(image.size()[0]):print('predicted: {}'.format(class_names[predict[j]]))對圖片的統(tǒng)一處理transform:
transform=transforms.Compose([transforms.Resize(224),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])對于預(yù)測結(jié)果進(jìn)行可視化的函數(shù):
def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():#for i, (inputs, labels) in enumerate(dataloaders['val']):for i, (inputs, labels) in enumerate(testloder):outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images // 2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)plt.show()returnmodel.train(mode=was_training)載入一新的圖片數(shù)據(jù)集:
data_dir =os.getcwd() + '\\data\\' dataloadertest =datasets.ImageFolder(os.path.join(data_dir, "tt"),transform=transforms.Compose([transforms.Resize(224),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) ) testloder = torch.utils.data.DataLoader(dataloadertest,batch_size = 4,shuffle = True)目錄結(jié)構(gòu):
其中要注意傳入的圖片的預(yù)處理:
image = Image.open(r"C:\Users\pic\test\he_5.jpg")
image =transform(image).unsqueeze(0)
需為PIL格式,且需先進(jìn)行轉(zhuǎn)化才能傳入模型。
結(jié)果:
經(jīng)測試之后不論是傳入單張圖片還是一個(gè)新數(shù)據(jù)集結(jié)果均符合預(yù)期。
總結(jié)
以上是生活随笔為你收集整理的【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【pytorch】pytorch自定义训
- 下一篇: 【后端过程记录】用flask搭建服务器作