从零开始AlignedReID_05
生活随笔
收集整理的這篇文章主要介紹了
从零开始AlignedReID_05
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
模型訓練
前言
經過前面的內容,所有需要的模塊都介紹過了,那么接下來就是模型的訓練與測試。
模型的訓練可以大致分為以下幾個過程:
- 數據預處理
- 初始化數據集
- 數據增強
- 導入數據
- 模型加載
- 損失函數加載
- 優化器加載
- 模型訓練
- 模型測試
- 保存模型
以上就是整個實現過程,下面所有代碼均在AlignedReID目錄下的train_models.py文件中。
加載頭文件
#-*-coding:utf-8-*- from __future__ import absolute_importimport os import sys import os.path as osp import time import datetime import argparse import numpy as npimport torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader import torchvision.transforms as T from torch.optim import lr_scheduler# 導入自己創建的工具 import models from losses.tripletloss import TripletAlignedReIDloss,DeepSupervision from data_process import dataset_manager from data_process.data_loader import ImageDataset from utils.util import AverageMeter,Logger,save_checkpoint from utils.optimizers import init_optim from utils.samplers import RandomIdentitySampler from utils import re_ranking from utils.eval_metrics import evaluate設置相關配置項
# ---------------------------------------------------------------------------
# 0. Command-line options.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Train AlignedReID with cross entropy loss and triplet hard loss')

# Datasets
# Root path of the data directory.
parser.add_argument('--root', type=str, default='/home/user/桌面/code/data', help="root path to data directory")
# Dataset name.
parser.add_argument('-d', '--dataset', type=str, default='market1501', choices=dataset_manager.get_names())
# Number of data-loading worker processes (4 by default).
parser.add_argument('-j', '--workers', default=4, type=int, help="number of data loading workers (default: 4)")
# Input image height.
parser.add_argument('--height', type=int, default=256, help="height of an image (default: 256)")
# Input image width.
parser.add_argument('--width', type=int, default=128, help="width of an image (default: 128)")
# Train/test split index (0 by default).
parser.add_argument('--split-id', type=int, default=0, help="split index")

# Optimization options
parser.add_argument('--labelsmooth', action='store_true', help="label smooth")
# Adam is the default optimizer.
parser.add_argument('--optim', type=str, default='adam', help="optimization algorithm (see optimizers.py)")
# Train for 300 epochs in total.
parser.add_argument('--max-epoch', default=300, type=int, help="maximum epochs to run")
parser.add_argument('--start-epoch', default=0, type=int, help="manual epoch number (useful on restarts)")
# Batch sizes.
parser.add_argument('--train-batch', default=32, type=int, help="train batch size")
parser.add_argument('--test-batch', default=32, type=int, help="test batch size")
# Initial learning rate.
parser.add_argument('--lr', '--learning-rate', default=0.0002, type=float, help="initial learning rate")
# Step size for learning-rate decay.
parser.add_argument('--stepsize', default=150, type=int, help="stepsize to decay learning rate (>0 means this is enabled)")
# Learning-rate decay factor.
parser.add_argument('--gamma', default=0.1, type=float, help="learning rate decay")
parser.add_argument('--weight-decay', default=5e-04, type=float, help="weight decay (default: 5e-04)")

# Triplet hard loss
parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
parser.add_argument('--num-instances', type=int, default=4, help="number of instances per identity")
parser.add_argument('--htri-only', action='store_true', default=False, help="if this is True, only htri loss is used in training")

# Architecture
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())

# Miscs
parser.add_argument('--print-freq', type=int, default=10, help="print frequency")
parser.add_argument('--seed', type=int, default=1, help="manual seed")
parser.add_argument('--resume', type=str, default='', metavar='PATH')
parser.add_argument('--evaluate', action='store_true', help="evaluation only")
# Run evaluation every 20 epochs by default.
parser.add_argument('--eval-step', type=int, default=20, help="run evaluation for every N epochs (set to -1 to test after training)")
parser.add_argument('--start-eval', type=int, default=0, help="start to evaluate after specific epoch")
parser.add_argument('--save-dir', type=str, default='log')
parser.add_argument('--use_cpu', action='store_true', help="use cpu")
# GPU 0 is used by default.
parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--reranking', action='store_true', help='result re_ranking')
parser.add_argument('--test_distance', type=str, default='global', help='test distance type')
parser.add_argument('--unaligned', action='store_true', help='test local feature with unalignment')
# FIX: test() reads args.use_metric_cuhk03 on the re-ranking path, but the
# original parser never defined this option, so that path raised
# AttributeError. Defining it here (off by default) is backward compatible.
parser.add_argument('--use_metric_cuhk03', action='store_true', help="use cuhk03 metric for evaluation")

args = parser.parse_args()

主函數main()
# Main entry point.
def main():
    """Wire together data, model, losses, optimizer, the train/eval loop,
    and checkpointing, driven by the module-level ``args``."""
    # Detect whether a GPU is available.
    use_gpu = torch.cuda.is_available()
    # Force CPU mode when requested.
    if args.use_cpu:
        use_gpu = False
    # pin_memory speeds up host-to-GPU transfers; only useful with a GPU.
    pin_memory = True if use_gpu else False

    # Redirect stdout into a log file (separate train vs. test logs).
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # GPU-specific setup.
    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
        cudnn.benchmark = True
        # Fix the CUDA random seed for reproducible initialization.
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    # 1. Data preprocessing.
    # 1.1 Initialize the dataset.
    print("Initializing dataset {}".format(args.dataset))
    dataset = dataset_manager.init_img_dataset(root=args.root, name=args.dataset)

    # 1.2 Data augmentation.
    # Train and test pipelines differ: the test set gets no augmentation,
    # only resize to (height, width), tensor conversion, and ImageNet-style
    # normalization.
    transform_train = T.Compose([
        T.Resize((args.height, args.width)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 1.3 Build the data loaders.
    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        # Sampler guarantees num_instances images per identity in each batch.
        sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances),
        batch_size=args.train_batch, num_workers=args.workers,
        pin_memory=pin_memory,
        # Drop the final incomplete batch.
        drop_last=True,
    )
    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        # shuffle=False keeps evaluation order deterministic.
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )
    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    # 2. Build the model.
    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids,
                              loss={'softmax', 'metric'}, aligned=True, use_gpu=use_gpu)
    print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters()) / 1000000.0))

    # 3. Build the loss functions.
    if args.labelsmooth:
        # Label smoothing to reduce overfitting.
        # NOTE(review): torch.nn has no CrossEntropyLabelSmooth — this looks
        # like it should come from a project-local losses module; confirm.
        criterion_class = nn.CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    else:
        # Cross-entropy loss (classification branch).
        criterion_class = nn.CrossEntropyLoss()
    # Triplet loss (metric branch).
    criterion_metric = TripletAlignedReIDloss(margin=args.margin)

    # 4. Build the optimizer.
    # init_optim dispatches on args.optim; model.parameters() updates all
    # parameters; args.lr is the initial learning rate; args.weight_decay
    # is the L2 regularization coefficient.
    optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
    # Step-wise learning-rate decay to avoid oscillation late in training.
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    start_epoch = args.start_epoch
    # Optionally resume from a checkpoint.
    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    # Wrap for (multi-)GPU data parallelism.
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # Evaluation-only mode.
    if args.evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, use_gpu)
        return 0

    # 5. Training loop.
    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")
    # 5.1 Iterate epochs from start_epoch up to max_epoch.
    for epoch in range(start_epoch, args.max_epoch):
        start_train_time = time.time()
        # 5.2 One epoch of training.
        train(epoch, model, criterion_class, criterion_metric, optimizer, trainloader, use_gpu)
        # Accumulate pure training time.
        train_time += round(time.time() - start_train_time)
        # Learning-rate decay.
        if args.stepsize > 0: scheduler.step()
        # Periodic evaluation (every eval_step epochs, and at the last epoch).
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1
            # Unwrap DataParallel before saving.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            # 6. Save a checkpoint.
            # NOTE(review): 'epoch' is stored 0-based while the filename uses
            # epoch + 1 — confirm the intended resume semantics.
            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))
    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))

訓練函數train()
def train(epoch, model, criterion_class, criterion_metric, optimizer, trainloader, use_gpu):
    """Run one training epoch over ``trainloader``.

    Args:
        epoch: 0-based epoch index (used only for logging).
        model: network returning (class outputs, global features, local features).
        criterion_class: classification (cross-entropy) loss.
        criterion_metric: AlignedReID triplet loss returning (global_loss, local_loss).
        optimizer: optimizer over the model parameters.
        trainloader: yields (imgs, pids, _) batches.
        use_gpu: move each batch to CUDA when True.
    """
    # Put the model in training mode (BatchNorm updates, dropout active).
    model.train()
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    xent_losses = AverageMeter()
    global_losses = AverageMeter()
    local_losses = AverageMeter()
    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        # Measure data-loading time.
        data_time.update(time.time() - end)
        # Original comment suggests shapes like [32,751], [32,128,8], [32,2048]
        # — depends on the model; confirm against models.init_model.
        outputs, global_features, local_features = model(imgs)
        if args.htri_only:
            # Triplet (hard-mining) loss only.
            if isinstance(global_features, tuple):
                global_loss, local_loss = DeepSupervision(criterion_metric, global_features, pids, local_features)
            else:
                global_loss, local_loss = criterion_metric(global_features, pids, local_features)
            # FIX: the original unconditionally computed
            # `loss = xent_loss + global_loss + local_loss` after this branch,
            # but xent_loss is never assigned when --htri-only is set, which
            # raised NameError on the first batch.
            xent_loss = None
            loss = global_loss + local_loss
        else:
            # Cross-entropy, with deep supervision when multiple heads exist.
            if isinstance(outputs, tuple):
                xent_loss = DeepSupervision(criterion_class, outputs, pids)
            else:
                xent_loss = criterion_class(outputs, pids)
            if isinstance(global_features, tuple):
                global_loss, local_loss = DeepSupervision(criterion_metric, global_features, pids, local_features)
            else:
                global_loss, local_loss = criterion_metric(global_features, pids, local_features)
            # Total loss = classification + global metric + local metric.
            loss = xent_loss + global_loss + local_loss
        # Standard update step: clear gradients, backprop, apply update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        losses.update(loss.item(), pids.size(0))
        # No cross-entropy term exists in htri-only mode; skip its meter.
        if xent_loss is not None:
            xent_losses.update(xent_loss.item(), pids.size(0))
        global_losses.update(global_loss.item(), pids.size(0))
        local_losses.update(local_loss.item(), pids.size(0))
        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'CLoss {xent_loss.val:.4f} ({xent_loss.avg:.4f})\t'
                  'GLoss {global_loss.val:.4f} ({global_loss.avg:.4f})\t'
                  'LLoss {local_loss.val:.4f} ({local_loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
                      data_time=data_time, loss=losses, xent_loss=xent_losses,
                      global_loss=global_losses, local_loss=local_losses))

測試函數test()
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
    """Extract query/gallery features, compute distance matrices, and report
    CMC and mAP (optionally with re-ranking). Returns Rank-1 accuracy."""
    batch_time = AverageMeter()
    # Opposite of model.train(): freezes BatchNorm statistics, disables dropout.
    model.eval()
    # no_grad() skips autograd bookkeeping: faster and saves memory.
    with torch.no_grad():
        # 1. Query set: collect global (qf) and local (lqf) features plus ids.
        qf, q_pids, q_camids, lqf = [], [], [], []
        for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
            if use_gpu: imgs = imgs.cuda()
            end = time.time()
            global_features, local_features = model(imgs)
            batch_time.update(time.time() - end)
            # Move features to CPU for the numpy-based evaluation below.
            global_features = global_features.data.cpu()
            local_features = local_features.data.cpu()
            qf.append(global_features)
            lqf.append(local_features)
            q_pids.extend(pids)
            q_camids.extend(camids)
        # Concatenate the per-batch feature tensors along dim 0 (rows).
        qf = torch.cat(qf, 0)
        lqf = torch.cat(lqf, 0)
        q_pids = np.asarray(q_pids)
        q_camids = np.asarray(q_camids)
        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
        # 2. Gallery set: same procedure.
        gf, g_pids, g_camids, lgf = [], [], [], []
        end = time.time()
        for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
            if use_gpu: imgs = imgs.cuda()
            end = time.time()
            features, local_features = model(imgs)
            batch_time.update(time.time() - end)
            features = features.data.cpu()
            local_features = local_features.data.cpu()
            gf.append(features)
            lgf.append(local_features)
            g_pids.extend(pids)
            g_camids.extend(camids)
        gf = torch.cat(gf, 0)
        lgf = torch.cat(lgf, 0)
        g_pids = np.asarray(g_pids)
        g_camids = np.asarray(g_camids)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch))
    # L2-normalize the global features (1e-12 guards against division by zero).
    qf = 1. * qf / (torch.norm(qf, 2, dim=-1, keepdim=True).expand_as(qf) + 1e-12)
    gf = 1. * gf / (torch.norm(gf, 2, dim=-1, keepdim=True).expand_as(gf) + 1e-12)
    m, n = qf.size(0), gf.size(0)
    # Squared Euclidean distance via a^2 + b^2 - 2ab.
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # NOTE(review): the positional addmm_(beta, alpha, ...) form is deprecated
    # in recent PyTorch; modern form is addmm_(qf, gf.t(), beta=1, alpha=-2).
    distmat.addmm_(1, -2, qf, gf.t())
    distmat = distmat.numpy()
    if not args.test_distance == 'global':
        # NOTE(review): this message is misleading — this branch runs exactly
        # when the local distance IS needed (test_distance != 'global').
        print("Only using global branch")
        from utils.distance import low_memory_local_dist
        # Local features are distance-matched as (stripes, channels).
        lqf = lqf.permute(0, 2, 1)
        lgf = lgf.permute(0, 2, 1)
        local_distmat = low_memory_local_dist(lqf.numpy(), lgf.numpy(), aligned=not args.unaligned)
        if args.test_distance == 'local':
            print("Only using local branch")
            distmat = local_distmat
        if args.test_distance == 'global_local':
            print("Using global and local branches")
            distmat = local_distmat + distmat
    print("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
    print("------------------")
    # Optional re-ranking pass.
    if args.reranking:
        if args.test_distance == 'global':
            print("Only using global branch for reranking")
            distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
        else:
            local_qq_distmat = low_memory_local_dist(lqf.numpy(), lqf.numpy(), aligned=not args.unaligned)
            local_gg_distmat = low_memory_local_dist(lgf.numpy(), lgf.numpy(), aligned=not args.unaligned)
            # Assemble the full (query+gallery) x (query+gallery) local-distance matrix.
            local_dist = np.concatenate(
                [np.concatenate([local_qq_distmat, local_distmat], axis=1),
                 np.concatenate([local_distmat.T, local_gg_distmat], axis=1)],
                axis=0)
            if args.test_distance == 'local':
                print("Only using local branch for reranking")
                distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3, local_distmat=local_dist, only_local=True)
            elif args.test_distance == 'global_local':
                print("Using global and local branches for reranking")
                distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3, local_distmat=local_dist, only_local=False)
        print("Computing CMC and mAP for re_ranking")
        # NOTE(review): verify args.use_metric_cuhk03 is actually defined by
        # the CLI parser — it is read here but may be missing from the options.
        cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03)
        print("Results ----------")
        print("mAP(RK): {:.1%}".format(mAP))
        print("CMC curve(RK)")
        for r in ranks:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")
    return cmc[0]

開啟訓練
# Script entry point: kick off training / evaluation.
if __name__ == "__main__":
    main()

訓練結果
截取的某一次的訓練過程如下圖:
保存的模型:
最終識別率:
可以看到在第240個epoch的時候,rank-1可以達到89.4%。
總結
以上是生活随笔為你收集整理的从零开始AlignedReID_05的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 美妙的 CSS3 动画!一组梦幻般的按钮
- 下一篇: C语言求解迷宫问题