當前位置:
首頁 >
PyTorch基础-猫狗分类实战-10
發布時間:2024/9/15
43
豆豆
生活随笔
收集整理的這篇文章主要介紹了
PyTorch基础-猫狗分类实战-10
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
訓練模型并保存
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets,transforms,models from torch.utils.data import Dataset import sys # 數據預處理 transform = transforms.Compose([transforms.RandomResizedCrop(224),# 對圖像進行隨機裁剪transforms.RandomRotation(20),# 隨機旋轉角度transforms.RandomHorizontalFlip(p=0.5),# 隨機水平翻轉transforms.ToTensor()# 變成tensor格式 ]) # 數據增強# 讀取數據 root = "image" train_dataset = datasets.ImageFolder(root + "/train",transform) test_dataset = datasets.ImageFolder(root + "/test",transform)# 導入數據 train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=8,shuffle=True) classes = train_dataset.classes classes_index = train_dataset.class_to_idx print(classes) print(classes_index) model = models.vgg16(pretrained=True)# 載入vgg16預訓練模型 print(model) for param in model.parameters():param.requires_grad = False # 構建新的全連接層 model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2)) LR = 0.0003 # 定義代價函數 entropy_loss = nn.CrossEntropyLoss() # 定義優化器 optimizer = optim.Adam(model.parameters(),LR) def train():model.train()for i,data in enumerate(train_loader):# 獲得數據和對應的標簽inputs,labels = data# 獲得模型預測結果(64,10)out = model(inputs)# 交叉熵代價函數out(batch.C),labels(batch)loss = entropy_loss(out,labels)# 梯度清零optimizer.zero_grad()# 計算梯度loss.backward()# 修改權值optimizer.step()def test():model.eval()correct = 0for i,data in enumerate(test_loader):# 獲得數據和對應的標簽inputs,labels = data# 獲得模型預測結果out = model(inputs)# 獲得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 預測正確的數量correct += (predicted == labels).sum()print("test acc:{0}".format(correct.item()/len(test_dataset)))correct = 0for i,data in enumerate(train_loader):# 獲得數據和對應的標簽inputs,labels = data# 獲得模型預測結果out = model(inputs)# 獲得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 預測正確的數量correct += (predicted == labels).sum()print("train acc:{0}".format(correct.item()/len(train_dataset))) for epoch in range(5):print("epoch:",epoch)train()test() torch.save(model.state_dict(),"cat_dog.pth") # 保存模型加載模型進行預測
import torch import numpy as np from PIL import Image from torchvision import transforms,models model = models.vgg16(pretrained=True) # 構建新的全連接層 model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2)) model.load_state_dict(torch.load("cat_dog.pth")) # 加載模型 model.eval() # 預測模式 label = np.array(["cat","dog"]) # 數據預處理 transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor() ]) # 預測函數 def predict(image_path):# 打開圖片img = Image.open(image_path)# 數據處理,增加一個維度img = transform(img).unsqueeze(0)# 預測得到的結果outputs = model(img)# 獲得最大值所在位置_,predicted = torch.max(outputs,1)# 轉換為類別名稱print(label[predicted.item()]) predict("image/test/cat/cat.1490.jpg")總結
以上是生活随笔為你收集整理的PyTorch基础-猫狗分类实战-10的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch基础-模型的保存和加载-0
- 下一篇: 机器学习基础-一元线性回归-01