从零开始AlignedReID_05
生活随笔
收集整理的这篇文章主要介绍了
从零开始AlignedReID_05
小编觉得挺不错的,现在分享给大家,帮大家做个参考。
模型训练
前言
经过前面的内容,所有需要的模块都介绍过了,那么接下来就是模型的训练与测试。
模型的训练可以大致分为以下几个过程:
- 数据预处理
- 初始化数据集
- 数据增强
- 导入数据
- 模型加载
- 损失函数加载
- 优化器加载
- 模型训练
- 模型测试
- 保存模型
下面是整个实现过程,所有代码均在AlignedReID目录下的train_models.py文件中。
加载头文件
#-*-coding:utf-8-*- from __future__ import absolute_importimport os import sys import os.path as osp import time import datetime import argparse import numpy as npimport torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader import torchvision.transforms as T from torch.optim import lr_scheduler# 導(dǎo)入自己創(chuàng)建的工具 import models from losses.tripletloss import TripletAlignedReIDloss,DeepSupervision from data_process import dataset_manager from data_process.data_loader import ImageDataset from utils.util import AverageMeter,Logger,save_checkpoint from utils.optimizers import init_optim from utils.samplers import RandomIdentitySampler from utils import re_ranking from utils.eval_metrics import evaluate設(shè)置相關(guān)配置項
# 0.設(shè)置一些常見的選項 parser = argparse.ArgumentParser(description='Train AlignedReID with cross entropy loss and triplet hard loss') # Datasets # path parser.add_argument('--root', type=str, default='/home/user/桌面/code/data', help="root path to data directory") # dataset name parser.add_argument('-d', '--dataset', type=str, default='market1501',choices=dataset_manager.get_names()) # 多線程 4個 parser.add_argument('-j', '--workers', default=4, type=int,help="number of data loading workers (default: 4)") # image height parser.add_argument('--height', type=int, default=256,help="height of an image (default: 256)") # image weight parser.add_argument('--width', type=int, default=128,help="width of an image (default: 128)") # split-id 默認(rèn)為0 parser.add_argument('--split-id', type=int, default=0, help="split index")# Optimization options parser.add_argument('--labelsmooth', action='store_true', help="label smooth") # 默認(rèn)使用adam優(yōu)化 parser.add_argument('--optim', type=str, default='adam', help="optimization algorithm (see optimizers.py)") # 總共300epoch parser.add_argument('--max-epoch', default=300, type=int,help="maximum epochs to run") parser.add_argument('--start-epoch', default=0, type=int,help="manual epoch number (useful on restarts)") # batch size parser.add_argument('--train-batch', default=32, type=int,help="train batch size") parser.add_argument('--test-batch', default=32, type=int, help="test batch size") # 學(xué)習(xí)率初始值 parser.add_argument('--lr', '--learning-rate', default=0.0002, type=float,help="initial learning rate") # 步長 parser.add_argument('--stepsize', default=150, type=int,help="stepsize to decay learning rate (>0 means this is enabled)") # 學(xué)習(xí)率衰減系數(shù) parser.add_argument('--gamma', default=0.1, type=float,help="learning rate decay") parser.add_argument('--weight-decay', default=5e-04, type=float,help="weight decay (default: 5e-04)")# triplet hard loss parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet 
loss") parser.add_argument('--num-instances', type=int, default=4,help="number of instances per identity") parser.add_argument('--htri-only', action='store_true', default=False,help="if this is True, only htri loss is used in training") # Architecture parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())# Miscs parser.add_argument('--print-freq', type=int, default=10, help="print frequency") parser.add_argument('--seed', type=int, default=1, help="manual seed") parser.add_argument('--resume', type=str, default='', metavar='PATH') parser.add_argument('--evaluate', action='store_true', help="evaluation only") # 每20個epoch執(zhí)行一次測試 parser.add_argument('--eval-step', type=int, default=20,help="run evaluation for every N epochs (set to -1 to test after training)") parser.add_argument('--start-eval', type=int, default=0, help="start to evaluate after specific epoch") parser.add_argument('--save-dir', type=str, default='log') parser.add_argument('--use_cpu', action='store_true', help="use cpu") # 默認(rèn)使用gup-0 parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') parser.add_argument('--reranking',action= 'store_true', help= 'result re_ranking')parser.add_argument('--test_distance',type = str, default='global', help= 'test distance type') parser.add_argument('--unaligned',action= 'store_true', help= 'test local feature with unalignment')args = parser.parse_args()主函數(shù)main()
# 主函數(shù) def main():# 判斷是否有GPUuse_gpu = torch.cuda.is_available()# 使用cpu則Gpu關(guān)if args.use_cpu:use_gpu = False# 節(jié)省內(nèi)存pin_memory =True if use_gpu else False# 日志輸出if not args.evaluate:sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))else:sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))print("==========\nArgs:{}\n==========".format(args))# 使用GPU的一些設(shè)置if use_gpu:print("Currently using GPU {}".format(args.gpu_devices))os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devicescudnn.benchmark = True# 確定隨機(jī)初始化seedtorch.cuda.manual_seed_all(args.seed)else:print("Currently using CPU (GPU is highly recommended)")# 1. 數(shù)據(jù)預(yù)處理# 1.1 初始化數(shù)據(jù)集print("Initializing dataset {}".format(args.dataset))dataset = dataset_manager.init_img_dataset(root=args.root, name=args.dataset)# 1.2 data augmentation# 訓(xùn)練集和測試集采用的方式不同# 測試集不需要數(shù)據(jù)增強 僅需要修改圖片格式# 這里把圖片轉(zhuǎn)成的256,128 且為tensortransform_train = T.Compose([T.Resize((args.height,args.width)),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),])transform_test = T.Compose([T.Resize((args.height,args.width)),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),])# 1.3讀入圖片數(shù)據(jù)trainloader = DataLoader(ImageDataset(dataset.train,transform=transform_train),sampler=RandomIdentitySampler(dataset.train,num_instances=args.num_instances),batch_size=args.train_batch,num_workers=args.workers,pin_memory=pin_memory,# 丟掉不滿足一個batch的數(shù)據(jù)drop_last=True,)queryloader = DataLoader(ImageDataset(dataset.query,transform=transform_test),# shuffle =False 不打亂順序batch_size=args.test_batch,shuffle=False,num_workers=args.workers,pin_memory=pin_memory,drop_last=False,)galleryloader = DataLoader(ImageDataset(dataset.gallery, transform=transform_test),batch_size=args.test_batch, shuffle=False, num_workers=args.workers,pin_memory=pin_memory, drop_last=False,)# 2.加載模型print("Initializing model: {}".format(args.arch))model 
= models.init_model(name=args.arch,num_classes = dataset.num_train_pids,loss={'softmax','metric'},aligned =True,use_gpu=use_gpu)print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters()) / 1000000.0))# 3.加載損失函數(shù)if args.labelsmooth:# overfitcriterion_class = nn.CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)else:# 交叉熵?fù)p失(分類)criterion_class = nn.CrossEntropyLoss()# 三元損失(度量)criterion_metric = TripletAlignedReIDloss(margin=args.margin)# 4.加載模型優(yōu)化器# 這里對模型的優(yōu)化器進(jìn)行了重構(gòu) 可以根據(jù)參數(shù)調(diào)用不同的優(yōu)化器# args.optim 決定使用的優(yōu)化器# model.parameters()對所有的參數(shù)進(jìn)行更新 model.conv1對模型的第一層進(jìn)行更新# args.lr 初始學(xué)習(xí)率# args.wight_decay模型正則化參數(shù)optimizer = init_optim(args.optim,model.parameters(),args.lr,args.weight_decay)# 根據(jù)需求選擇不同調(diào)整學(xué)習(xí)率方法# 學(xué)習(xí)率的衰減 避免模型震蕩if args.stepsize > 0:scheduler = lr_scheduler.StepLR(optimizer,step_size=args.stepsize,gamma=args.gamma)start_epoch = args.start_epoch# 是否需要恢復(fù)模型if args.resume:print("Loading checkpoint from '{}'".format(args.resume))checkpoint = torch.load(args.resume)model.load_state_dict(checkpoint['state_dict'])start_epoch = checkpoint['epoch']# 使用并行庫if use_gpu:model = nn.DataParallel(model).cuda()# 測試if args.evaluate:print("Evaluate only")# resnet50,query, testtest(model, queryloader, galleryloader, use_gpu)return 0# 5.模型訓(xùn)練start_time = time.time()train_time = 0best_rank1 = -np.infbest_epoch = 0print("==> Start training")# 5.1 從開始的epoch,到結(jié)束的epoch開始循環(huán)for epoch in range(start_epoch, args.max_epoch):start_train_time = time.time()# 5.2 調(diào)用訓(xùn)練函數(shù)進(jìn)行訓(xùn)練train(epoch, model, criterion_class, criterion_metric, optimizer, trainloader, use_gpu)# 計算了一下訓(xùn)練的時間train_time += round(time.time() - start_train_time)# 學(xué)習(xí)率衰減if args.stepsize > 0: scheduler.step()# 測試if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:print("==> 
Test")rank1 = test(model, queryloader, galleryloader, use_gpu)is_best = rank1 > best_rank1if is_best:best_rank1 = rank1best_epoch = epoch + 1if use_gpu:state_dict = model.module.state_dict()else:state_dict = model.state_dict()# 6.保存模型save_checkpoint({'state_dict': state_dict,'rank1': rank1,'epoch': epoch,}, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))elapsed = round(time.time() - start_time)elapsed = str(datetime.timedelta(seconds=elapsed))train_time = str(datetime.timedelta(seconds=train_time))print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))訓(xùn)練函數(shù)train()
def train(epoch, model, criterion_class, criterion_metric, optimizer, trainloader, use_gpu):# 確定模型實在訓(xùn)練模式model.train()losses = AverageMeter()batch_time = AverageMeter()data_time = AverageMeter()xent_losses = AverageMeter()global_losses = AverageMeter()local_losses = AverageMeter()end = time.time()# 使用trainloader迭代器吐數(shù)據(jù)for batch_idx, (imgs, pids, _) in enumerate(trainloader):if use_gpu:imgs, pids = imgs.cuda(), pids.cuda()# measure data loading timedata_time.update(time.time() - end)# [32,751],[32,128,8],[32,2048]outputs, global_features, local_features = model(imgs)# `htri`: triplet loss with hard positive/negative mining [4]if args.htri_only:# isinstance() 函數(shù)來判斷一個對象是否是一個已知的類型,類似 type() tuple 元組 數(shù)組if isinstance(global_features, tuple):global_loss, local_loss = DeepSupervision(criterion_metric, global_features, pids, local_features)else:global_loss, local_loss = criterion_metric(global_features, pids, local_features)else:if isinstance(outputs, tuple):# `xent`: cross entropy + label smoothing regularizerxent_loss = DeepSupervision(criterion_class, outputs, pids)else:xent_loss = criterion_class(outputs, pids)if isinstance(global_features, tuple):global_loss, local_loss = DeepSupervision(criterion_metric, global_features, pids, local_features)else:global_loss, local_loss = criterion_metric(global_features, pids, local_features)# 計算損失loss = xent_loss + global_loss + local_loss# 清空優(yōu)化器梯度optimizer.zero_grad()# 反向傳播loss.backward()# 更新模型參數(shù)optimizer.step()batch_time.update(time.time() - end)end = time.time()losses.update(loss.item(), pids.size(0))xent_losses.update(xent_loss.item(), pids.size(0))global_losses.update(global_loss.item(), pids.size(0))local_losses.update(local_loss.item(), pids.size(0))if (batch_idx+1) % args.print_freq == 0:print('Epoch: [{0}][{1}/{2}]\t''Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Data {data_time.val:.3f} ({data_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''CLoss {xent_loss.val:.4f} 
({xent_loss.avg:.4f})\t''GLoss {global_loss.val:.4f} ({global_loss.avg:.4f})\t''LLoss {local_loss.val:.4f} ({local_loss.avg:.4f})\t'.format(epoch+1, batch_idx+1, len(trainloader), batch_time=batch_time,data_time=data_time,loss=losses,xent_loss=xent_losses, global_loss=global_losses, local_loss = local_losses))測試函數(shù)test()
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):batch_time = AverageMeter()# 和model.train相對 表示不使用BatchNormalization和Dropout,保證BN和dropout不變化model.eval()# 反向傳播時都不會自動求導(dǎo)。volatile可以實現(xiàn)一定速度的提升,并節(jié)省一半的顯存,因為其不需要保存梯度with torch.no_grad():# 1.queryloader處理# list列表數(shù)據(jù)類型,列表是一種可變序列qf, q_pids, q_camids, lqf = [], [], [], []# 從迭代器中取數(shù)據(jù)for batch_idx, (imgs, pids, camids) in enumerate(queryloader):# 轉(zhuǎn)化為cuda模式if use_gpu: imgs = imgs.cuda()end = time.time()global_features, local_features = model(imgs)batch_time.update(time.time() - end)# 將GPU上的tensor轉(zhuǎn)化為cpu上從而進(jìn)行一些只能在cpu上進(jìn)行的運算global_features = global_features.data.cpu()local_features = local_features.data.cpu()# 添加到列表qf.append(global_features)lqf.append(local_features)q_pids.extend(pids)q_camids.extend(camids)# torch.cat是將兩個張量(tensor)拼接在一起 #按維數(shù)0(行)拼接qf = torch.cat(qf, 0)lqf = torch.cat(lqf,0)q_pids = np.asarray(q_pids)q_camids = np.asarray(q_camids)print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))# 2.galleryloader處理gf, g_pids, g_camids, lgf = [], [], [], []end = time.time()for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):if use_gpu: imgs = imgs.cuda()end = time.time()features, local_features = model(imgs)batch_time.update(time.time() - end)features = features.data.cpu()local_features = local_features.data.cpu()gf.append(features)lgf.append(local_features)g_pids.extend(pids)g_camids.extend(camids)gf = torch.cat(gf, 0)lgf = torch.cat(lgf,0)g_pids = np.asarray(g_pids)g_camids = np.asarray(g_camids)print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch))# feature normlizationqf = 1. * qf / (torch.norm(qf, 2, dim = -1, keepdim=True).expand_as(qf) + 1e-12)gf = 1. 
* gf / (torch.norm(gf, 2, dim = -1, keepdim=True).expand_as(gf) + 1e-12)m, n = qf.size(0), gf.size(0)# 求距離a^2+b^2-2*a*bdistmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()distmat.addmm_(1, -2, qf, gf.t())distmat = distmat.numpy()if not args.test_distance== 'global':print("Only using global branch")from utils.distance import low_memory_local_distlqf = lqf.permute(0,2,1)lgf = lgf.permute(0,2,1)local_distmat = low_memory_local_dist(lqf.numpy(),lgf.numpy(),aligned= not args.unaligned)if args.test_distance== 'local':print("Only using local branch")distmat = local_distmatif args.test_distance == 'global_local':print("Using global and local branches")distmat = local_distmat+distmatprint("Computing CMC and mAP")cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)print("Results ----------")print("mAP: {:.1%}".format(mAP))print("CMC curve")for r in ranks:print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))print("------------------")if args.reranking:if args.test_distance == 'global':print("Only using global branch for reranking")distmat = re_ranking(qf,gf,k1=20, k2=6, lambda_value=0.3)else:local_qq_distmat = low_memory_local_dist(lqf.numpy(), lqf.numpy(),aligned= not args.unaligned)local_gg_distmat = low_memory_local_dist(lgf.numpy(), lgf.numpy(),aligned= not args.unaligned)local_dist = np.concatenate([np.concatenate([local_qq_distmat, local_distmat], axis=1),np.concatenate([local_distmat.T, local_gg_distmat], axis=1)],axis=0)if args.test_distance == 'local':print("Only using local branch for reranking")distmat = re_ranking(qf,gf,k1=20,k2=6,lambda_value=0.3,local_distmat=local_dist,only_local=True)elif args.test_distance == 'global_local':print("Using global and local branches for reranking")distmat = re_ranking(qf,gf,k1=20,k2=6,lambda_value=0.3,local_distmat=local_dist,only_local=False)print("Computing CMC and mAP for re_ranking")cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, 
g_camids, use_metric_cuhk03=args.use_metric_cuhk03)print("Results ----------")print("mAP(RK): {:.1%}".format(mAP))print("CMC curve(RK)")for r in ranks:print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))print("------------------")return cmc[0]開啟訓(xùn)練
if __name__ == "__main__":main()訓(xùn)練結(jié)果
训练结果
截取的某一次的训练过程如下图:
保存的模型:
最终识别率:
可以看到在第240个epoch的时候,rank-1可以达到89.4%。
总结
以上是生活随笔为你收集整理的从零开始AlignedReID_05的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: 美妙的 CSS3 动画!一组梦幻般的按钮
- 下一篇: C语言求解迷宫问题