CV Algorithm Reproduction (Classification 4/6): GoogLeNet (2014, Google)
Acknowledgment: 霹靂吧啦Wz: https://space.bilibili.com/18161609
Contents
1 Key Points
1.1 PyTorch Syntax
2 Network Overview
2.1 Historical Significance
2.2 Highlights
2.3 Aside
2.4 Network Structure
3 Code Structure
3.1 model.py
3.2 train.py
3.3 predict.py
1 Key Points
1.1 PyTorch Syntax
- In nn.MaxPool2d(), the argument ceil_mode=True rounds the output size up instead of down.
- To ignore some of the network parameters stored in a .pth weight file (i.e. parameters for structures that the current network no longer has or needs, but that were saved during training), set the strict argument to False when calling load_state_dict().
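A minimal sketch of both points (the tensor shape is illustrative; googleNet.pth is the checkpoint that train.py below produces):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 112, 112)
# default (floor) mode: floor((112 - 3) / 2) + 1 = 55
print(nn.MaxPool2d(3, stride=2)(x).shape)                  # [1, 64, 55, 55]
# ceil_mode=True rounds up instead: ceil((112 - 3) / 2) + 1 = 56
print(nn.MaxPool2d(3, stride=2, ceil_mode=True)(x).shape)  # [1, 64, 56, 56]

# strict=False skips checkpoint keys the model lacks (and vice versa);
# both key lists are returned for inspection:
# missing, unexpected = model.load_state_dict(torch.load("googleNet.pth"), strict=False)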
2 Network Overview
2.1 Historical Significance
- GoogLeNet was proposed by a Google team in 2014 and won first place in the classification task of that year's ImageNet competition (VGG took second place).
2.2 Highlights
- Introduces the Inception structure (fuses feature information at different scales).
- Uses 1x1 convolution kernels for dimensionality reduction and mapping (see the parameter-count sketch after this list).
- Adds two auxiliary classifiers to help training (at inference time the auxiliary classifiers are removed!). What the two auxiliary classifiers contribute:
  - strengthen the classification ability of the lower layers;
  - counteract vanishing gradients in the middle of the network;
  - add regularization (the network's total loss is influenced by its middle layers, which acts as a mild regularizer).
    - Purpose of regularization: improve the model's generalization ability and avoid overfitting.
    - Regularization methods: adding a regularization term to the loss function (like attaching a "weight" that keeps the parameters from swinging wildly; label smoothing, for example, keeps the network from becoming overconfident in any one class), dropout, early stopping, and data augmentation.
- At inference only one fully connected layer is used (an average-pooling layer replaces the others, greatly reducing the model's parameter count).
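To make the 1x1 dimensionality-reduction point concrete, here is a minimal parameter-count comparison (a sketch; the channel numbers are borrowed from the 5x5 branch of inception4a in model.py below):

import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

# direct 5x5 convolution: 480 -> 48 channels
direct = nn.Conv2d(480, 48, kernel_size=5, padding=2)
# 1x1 reduction to 16 channels first, then the 5x5 convolution
reduced = nn.Sequential(nn.Conv2d(480, 16, kernel_size=1),
                        nn.Conv2d(16, 48, kernel_size=5, padding=2))

print(n_params(direct))   # 576048
print(n_params(reduced))  # 26944 -- roughly 21x fewer parameters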
2.3 Aside
- GoogLeNet has only about 7 million parameters, roughly 1/20 of VGG's (VGG-16 has about 138 million).
- GoogLeNet spans four generations: Inception v1, Inception v2, Inception v3, and Inception v4; later versions mainly revolve around improvements to the Inception module.
- That said, GoogLeNet's structure is complex, and its two auxiliary classifiers make it troublesome to build and train, which is why VGG ended up seeing wider use in later work.
2.4 Network Structure
- The depth value in the figure above (the architecture table from the paper) indicates how many times that structure is repeated consecutively.
- reduce means dimensionality reduction: #3x3 reduce refers to the 1x1 convolution layer before the 3x3 convolution, and #5x5 reduce to the 1x1 convolution layer before the 5x5 convolution. The mapping from a table row to code is sketched below.
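For instance, the table row for inception3a (input 192 channels, #1x1 = 64, #3x3 reduce = 96, #3x3 = 128, #5x5 reduce = 16, #5x5 = 32, pool proj = 32) becomes the following call to the Inception class defined in model.py below:

inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
# output channels = 64 + 128 + 32 + 32 = 256 (branches concatenated along dim 1),
# which matches the input channel count expected by inception3b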
3 Code Structure
- train.py
- model.py
- predict.py
3.1 model.py
import torch.nn as nn
import torch
import torch.nn.functional as F


# aux_logits: whether to use the auxiliary classifiers
class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)  # ceil_mode=True: round up
        # The original paper has a LocalRespNorm layer here, but it does not
        # help much and can be omitted.
        # nn.LocalResponseNorm()
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Auxiliary classifiers
        if aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # Adaptive average pooling yields a 1x1 feature map here no matter
        # what the spatial size of the input image is.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # n x 3 x 224 x 224
        x = self.conv1(x)
        # n x 64 x 112 x 112
        x = self.maxpool1(x)
        # n x 64 x 56 x 56
        x = self.conv2(x)
        # n x 64 x 56 x 56
        x = self.conv3(x)
        # n x 192 x 56 x 56
        x = self.maxpool2(x)
        # n x 192 x 28 x 28
        x = self.inception3a(x)
        # n x 256 x 28 x 28
        x = self.inception3b(x)
        # n x 480 x 28 x 28
        x = self.maxpool3(x)
        # n x 480 x 14 x 14
        x = self.inception4a(x)
        # n x 512 x 14 x 14
        # Use the auxiliary classifiers only in training mode, since at test
        # time they are not needed (their accuracy is lower than the main
        # classifier's).
        # Note: self.training is toggled automatically by net.train() / net.eval().
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # n x 512 x 14 x 14
        x = self.inception4c(x)
        # n x 512 x 14 x 14
        x = self.inception4d(x)
        # n x 528 x 14 x 14
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # n x 832 x 14 x 14
        x = self.maxpool4(x)
        # n x 832 x 7 x 7
        x = self.inception5a(x)
        # n x 832 x 7 x 7
        x = self.inception5b(x)
        # n x 1024 x 7 x 7

        x = self.avgpool(x)
        # n x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # n x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # n x 1000 (num_classes)
        # When the auxiliary classifiers are in use, return all three outputs.
        if self.training and self.aux_logits:
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


# Every branch must keep the same output height and width so the branches can
# be concatenated along the channel dimension.
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            # kernel_size=3 needs padding=1 so the output size equals the input size
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1))

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # kernel_size=5 needs padding=2 so the output size equals the input size
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2))

        self.branch4 = nn.Sequential(
            # stride=1, padding=1 keep the output size equal to the input size
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1))

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        # Concatenate along the channel dimension, i.e. dim 1 of (B, C, H, W)
        return torch.cat(outputs, 1)


# Auxiliary classifier
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: n x 512 x 14 x 14, aux2: n x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: n x 512 x 4 x 4, aux2: n x 528 x 4 x 4
        x = self.conv(x)
        # n x 128 x 4 x 4
        x = torch.flatten(x, 1)  # flatten everything from dim 1 (the channel dimension) onward
        x = F.dropout(x, 0.5, training=self.training)
        # n x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # n x 1024
        x = self.fc2(x)
        # n x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x
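A quick smoke test of the model above (a hypothetical check, not part of the original scripts): in training mode with aux_logits=True the forward pass returns three outputs, while eval mode returns only the main one.

import torch

net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
x = torch.randn(1, 3, 224, 224)

net.train()
logits, aux2, aux1 = net(x)
print(logits.shape, aux2.shape, aux1.shape)  # each torch.Size([1, 5])

net.eval()
print(net(x).shape)  # torch.Size([1, 5]) -- auxiliary heads are skipped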
3.2 train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torchvision
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)

    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    best_acc = 0.0
    save_path = './googleNet.pth'
    for epoch in range(30):
        # train
        net.train()  # enables dropout and the auxiliary classifiers
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            # the auxiliary losses are weighted by 0.3, as in the paper
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print train process
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
        print()

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval mode only returns the main output
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
                  (epoch + 1, running_loss / step, val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
3.3 predict.py
import torch
from model import GoogLeNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json


data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = GoogLeNet(num_classes=5, aux_logits=False)
# load model weights
model_weight_path = "./googleNet.pth"  # path of the trained weights
# The auxiliary classifiers' weights are also stored in the .pth file, but at
# prediction time the auxiliary branches are disabled, so those weights do not
# need to be loaded. This is done by setting strict=False in
# model.load_state_dict(), i.e. not requiring an exact match of parameters.
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path),
                                                      strict=False)
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)])
plt.show()