模型微调技术
模型微調(diào)
- 一、遷移學(xué)習(xí)中的常見技巧:微調(diào)(fine-tuning)
- 1.1 概念
- 1.2 步驟
- 1.3 訓(xùn)練
- 1.4 實(shí)現(xiàn)
一、遷移學(xué)習(xí)中的常見技巧:微調(diào)(fine-tuning)
1.1 概念
1.微調(diào)所有層;
2.固定網(wǎng)絡(luò)前面幾層權(quán)重,只微調(diào)網(wǎng)絡(luò)的后面幾層,這樣做有兩個(gè)原因:A. 避免因數(shù)據(jù)量小造成過擬合現(xiàn)象;B.CNN前幾層的特征中包含更多的一般特征(比如,邊緣信息,色彩信息等),這對(duì)許多任務(wù)來(lái)說(shuō)是非常通用的,但是CNN后面幾層的特征學(xué)習(xí)注重高層特征,也就是語(yǔ)義特征,這是針對(duì)于數(shù)據(jù)集而言的,不同的數(shù)據(jù)集后面幾層學(xué)習(xí)的語(yǔ)義特征也是完全不同的;
1.2 步驟
1.3 訓(xùn)練
- 源數(shù)據(jù)集遠(yuǎn)復(fù)雜于目標(biāo)數(shù)據(jù),通常微調(diào)效果更好;
- 通常使用更小的學(xué)習(xí)率和更少的數(shù)據(jù)迭代;
1.4 實(shí)現(xiàn)
#熱狗識(shí)別 #導(dǎo)入所需包 from d2l import torch as d2l from torch import nn import torchvision import torch import os %matplotlib inline #獲取數(shù)據(jù)集 """ 我們使用的熱狗數(shù)據(jù)集來(lái)源于網(wǎng)絡(luò)。 該數(shù)據(jù)集包含1400張熱狗的“正類”圖像,以及包含盡可能多的其他食物的“負(fù)類”圖像。 含著兩個(gè)類別的1000張圖片用于訓(xùn)練,其余的則用于測(cè)試。 """ d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog') print(data_dir) #輸出..\data\hotdog train_imgs=torchvision.datasets.ImageFolder(os.path.join(data_dir,'train')) test_imgs=torchvision.datasets.ImageFolder(os.path.join(data_dir,'test')) hotdogs=[train_imgs[i][0] for i in range(8)] not_hotdogs=[train_imgs[-i-1][0] for i in range(8)] d2l.show_images(hotdogs+not_hotdogs,2,8,scale=1.4) # 使用RGB通道的均值和標(biāo)準(zhǔn)差,以標(biāo)準(zhǔn)化每個(gè)通道 """ 在訓(xùn)練期間,我們首先從圖像中裁切隨機(jī)大小和隨機(jī)長(zhǎng)寬比的區(qū)域,然后將該區(qū)域縮放為\(224*224\)輸入圖像。 在測(cè)試過程中,我們將圖像的高度和寬度都縮放到256像素,然后裁剪中央\(224*224\)區(qū)域作為輸入。 此外,對(duì)于RGB(紅、綠和藍(lán))顏色通道,我們分別標(biāo)準(zhǔn)化每個(gè)通道。 具體而言,該通道的每個(gè)值減去該通道的平均值,然后將結(jié)果除以該通道的標(biāo)準(zhǔn)差。 """ normalize=torchvision.transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) train_augs=torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),#隨機(jī)裁剪,并resize成224torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize]) test_augs=torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),#將圖片從中心裁剪成224*224torchvision.transforms.ToTensor(),normalize]) #我們使用在ImageNet數(shù)據(jù)集上預(yù)訓(xùn)練的ResNet-18作為源模型。 在這里,我們指定pretrained=True以自動(dòng)下載預(yù)訓(xùn)練的模型參數(shù)。 #如果你首次使用此模型,則需要連接互聯(lián)網(wǎng)才能下載。 pretrained_net=torchvision.models.resnet18(pretrained=True) """ 預(yù)訓(xùn)練的源模型實(shí)例包含許多特征層和一個(gè)輸出層fc(全連接層)。 此劃分的主要目的是促進(jìn)對(duì)除輸出層以外所有層的模型參數(shù)進(jìn)行微調(diào)。 下面給出了源模型的成員變量fc。 """ pretrained_net.fc #輸出 #Linear(in_features=512, out_features=1000, bias=True) finetune_net=torchvision.models.resnet18(pretrained=True) finetune_net.fc=nn.Linear(finetune_net.fc.in_features,2)#全連接層的輸入神經(jīng)元數(shù)量是特征數(shù)量,因?yàn)槭?分類,所以輸出是2 nn.init.xavier_uniform_(finetune_net.fc.weight)#隨機(jī)初始化全連接層權(quán)重 #Parameter containing: tensor([[ 0.0378, 0.0630, -0.0080, ..., -0.0220, -0.0511, 0.0959],[ 0.0556, 0.0227, -0.0262, ..., -0.1059, -0.0171, 0.0051]],requires_grad=True) #微調(diào)模型 # 如果param_group=True,輸出層中的模型參數(shù)將使用十倍的學(xué)習(xí)率 def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices) train_fine_tuning(finetune_net, 5e-5) #為了進(jìn)行比較,我們定義了一個(gè)相同的模型,但是將其所有模型參數(shù)初始化為隨機(jī)值。 #由于整個(gè)模型需要從頭開始訓(xùn)練,因此我們需要使用更大的學(xué)習(xí)率。 scratch_net = torchvision.models.resnet18() scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2) train_fine_tuning(scratch_net, 5e-4, param_group=False)- 參考網(wǎng)址:https://zh-v2.d2l.ai/chapter_computer-vision/fine-tuning.html
總結(jié)
- 上一篇: panabit之MAC管控
- 下一篇: java聊天室类图怎么画,UML课程设计