【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图
生活随笔
收集整理的這篇文章主要介紹了
【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
main函數載入模型,加載圖片,輸出結果:
if __name__ == '__main__':image = Image.open(r"C:\Users\pic\test\he_5.jpg")image =transform(image).unsqueeze(0)modelme = torch.load('modefresnet.pkl')modelme.eval() #表示將模型轉變為evaluation(測試)模式,這樣就可以排除BN和Dropout對測試的干擾。visualize_model(modelme)outputs = modelme(image)_, predict = torch.max(outputs.data, 1)for j in range(image.size()[0]):print('predicted: {}'.format(class_names[predict[j]]))對圖片的統一處理transform:
transform=transforms.Compose([transforms.Resize(224),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])對于預測結果進行可視化的函數:
def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():#for i, (inputs, labels) in enumerate(dataloaders['val']):for i, (inputs, labels) in enumerate(testloder):outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images // 2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)plt.show()returnmodel.train(mode=was_training)載入一新的圖片數據集:
data_dir =os.getcwd() + '\\data\\' dataloadertest =datasets.ImageFolder(os.path.join(data_dir, "tt"),transform=transforms.Compose([transforms.Resize(224),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]) ) testloder = torch.utils.data.DataLoader(dataloadertest,batch_size = 4,shuffle = True)目錄結構:
其中要注意傳入的圖片的預處理:
image = Image.open(r"C:\Users\pic\test\he_5.jpg")
image =transform(image).unsqueeze(0)
需為PIL格式,且需先進行轉化才能傳入模型。
結果:
經測試之后不論是傳入單張圖片還是一個新數據集結果均符合預期。
總結
以上是生活随笔為你收集整理的【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【pytorch】pytorch自定义训
- 下一篇: 【后端过程记录】用flask搭建服务器作