[Dissecting the Ox] Implementing RetinaNet from Scratch (7): Training RetinaNet with Distributed Methods
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.
How to convert nn.DataParallel training code to nn.parallel.DistributedDataParallel distributed training code
First, distributed training requires one extra variable: local_rank. It is initialized to 0; when training on multiple GPUs, each process (one per GPU) gets its own local_rank value: 0, 1, 2, and so on.
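As a minimal sketch of how this works with the launcher (the same pattern appears in the full train.py below): each process reads its local_rank from the command line and binds itself to the matching GPU.

```python
import argparse
import torch

# torch.distributed.launch starts one process per GPU and appends
# --local_rank=<index of this process on the node> to each copy of the script.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0,
                    help='LOCAL_PROCESS_RANK, filled in by the launcher')
args = parser.parse_args()

# bind this process to its own GPU before doing any CUDA work
torch.cuda.set_device(args.local_rank)
```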
Second, in distributed training the batch_size passed to the DataLoader is not the total batch size but the batch size assigned to each GPU.
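Concretely, each process builds its own DataLoader over a DistributedSampler shard, so with N GPUs the effective global batch size is N × batch_size. A minimal sketch, mirroring the DataLoader setup used later in train.py (train_dataset here is a placeholder name):

```python
import torch
from torch.utils.data import DataLoader

# each of the N processes sees a disjoint 1/N shard of the dataset
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, shuffle=True)

# batch_size is per GPU: with 2 GPUs and batch_size=15, one optimizer
# step consumes 2 * 15 = 30 images in total
train_loader = DataLoader(train_dataset,
                          batch_size=15,
                          shuffle=False,  # the sampler already shuffles
                          sampler=train_sampler)
```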
Next, we initialize the process group with dist.init_process_group. I won't go into detail here; I only give the simplest initialization for the single-machine, multi-GPU case:
```python
dist.init_process_group(backend='nccl', init_method='env://')
```

On a single multi-GPU server, if you run several distributed training jobs at the same time (say you have 4 GPUs, two running the first experiment and two running the second), the train.sh launch command of each job must use a distinct master_addr/master_port combination; on one machine this in practice means giving each job its own master_port. Otherwise, running multiple distributed training jobs on the same server will fail with an error.
```shell
python -m torch.distributed.launch --nproc_per_node=1 --master_addr 127.0.0.1 --master_port 20001 train.py
```

nproc_per_node is the number of GPUs to use (the launcher spawns one process per GPU).
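For example, two concurrent two-GPU jobs on a four-GPU machine could be launched as follows (the port numbers are arbitrary, and CUDA_VISIBLE_DEVICES is one common way to pin each job to its own GPUs; it is not part of the original train.sh):

```shell
# experiment 1 on GPUs 0 and 1, using port 20001
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
    --master_addr 127.0.0.1 --master_port 20001 train.py

# experiment 2 on GPUs 2 and 3, using a different port 20002
CUDA_VISIBLE_DEVICES=2,3 python -m torch.distributed.launch --nproc_per_node=2 \
    --master_addr 127.0.0.1 --master_port 20002 train.py
```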
After defining the model, wrap it with the nn.parallel.DistributedDataParallel API. If you use apex, there is an analogous API: apex.parallel.DistributedDataParallel.
```python
if args.sync_bn:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

if args.apex:
    amp.register_float_function(torch, 'sigmoid')
    amp.register_float_function(torch, 'softmax')
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = apex.parallel.DistributedDataParallel(model,
                                                  delay_allreduce=True)
    if args.sync_bn:
        model = apex.parallel.convert_syncbn_model(model)
else:
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[local_rank],
                                                output_device=local_rank)
```

Note that when using apex with sync BN, you must also use apex's own apex.parallel.convert_syncbn_model to convert the BN layers in the model into sync BN layers. Next, every logger.info call is made only when local_rank == 0; otherwise each message would be written once per GPU in use. For the same reason, validation is run only on the GPU with local_rank == 0.
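The guard itself is just an if on local_rank, applied to every logging call and to validation; the same pattern recurs throughout train.py below:

```python
# without the guard, N processes would write N copies of every message
if local_rank == 0:
    logger.info(f'use {gpus_num} gpus')

# run COCO evaluation only in the rank-0 process
if local_rank == 0:
    all_eval_result = validate(Config.val_dataset, model, decoder)
```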
Complete distributed training and testing code
config.py is as follows:
```python
import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import COCO2017_path
from public.detection.dataset.cocodataset import CocoDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = os.path.join(COCO2017_path, 'images/train2017')
    val_dataset_path = os.path.join(COCO2017_path, 'images/val2017')
    dataset_annotations_path = os.path.join(COCO2017_path, 'annotations')

    network = "resnet50_retinanet"
    pretrained = False
    num_classes = 80
    seed = 0
    input_image_size = 600

    train_dataset = CocoDetection(image_root_dir=train_dataset_path,
                                  annotation_root_dir=dataset_annotations_path,
                                  set="train2017",
                                  transform=transforms.Compose([
                                      RandomFlip(flip_prob=0.5),
                                      RandomCrop(crop_prob=0.5),
                                      RandomTranslate(translate_prob=0.5),
                                      Resize(resize=input_image_size),
                                  ]))
    val_dataset = CocoDetection(image_root_dir=val_dataset_path,
                                annotation_root_dir=dataset_annotations_path,
                                set="val2017",
                                transform=transforms.Compose([
                                    Resize(resize=input_image_size),
                                ]))

    epochs = 12
    per_node_batch_size = 15
    lr = 1e-4
    num_workers = 4
    print_interval = 100
    apex = True
    sync_bn = False
```

train.py is as follows:
```python
import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
import apex
from apex import amp
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from public.detection.models.retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Distributed Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--per_node_batch_size',
                        type=int,
                        default=Config.per_node_batch_size,
                        help='per_node batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')
    parser.add_argument('--sync_bn',
                        type=bool,
                        default=Config.sync_bn,
                        help='use sync bn or not')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='LOCAL_PROCESS_RANK')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(
            data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape:[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval, we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id': val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score': object_score,
                'bbox': object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True)
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collater,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        if local_rank == 0:
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
        if args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
    else:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.evaluate))
            sys.exit(-1)
        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            all_eval_result = validate(Config.val_dataset, model, decoder)
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {checkpoint['epoch']:0>5d}, "
                    f"IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, "
                    f"IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, "
                    f"IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, "
                    f"IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, "
                    f"IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, "
                    f"IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, "
                    f"IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, "
                    f"IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, "
                    f"IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, "
                    f"IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, "
                    f"IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, "
                    f"IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}")

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
                f"best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, "
                f"reg_loss: {checkpoint['reg_loss']:2f}")

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               args)
        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, "
                f"reg_loss: {reg_losses:.2f}, loss: {losses:.2f}")

        if epoch % 5 == 0 or epoch == args.epochs:
            if local_rank == 0:
                all_eval_result = validate(Config.val_dataset, model, decoder)
                logger.info(f"eval done.")
                if all_eval_result is not None:
                    logger.info(
                        f"val: epoch: {epoch:0>5d}, "
                        f"IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, "
                        f"IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, "
                        f"IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, "
                        f"IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, "
                        f"IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, "
                        f"IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, "
                        f"IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, "
                        f"IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, "
                        f"IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, "
                        f"IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, "
                        f"IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, "
                        f"IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}")
                    if all_eval_result[0] > best_map:
                        torch.save(model.module.state_dict(),
                                   os.path.join(args.checkpoints, "best.pth"))
                        best_map = all_eval_result[0]

        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'cls_loss': cls_losses,
                    'reg_loss': reg_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")


def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            optimizer.zero_grad()
            # still advance to the next batch, otherwise the loop would
            # recompute the same zero-loss batch forever
            images, annotations = prefetcher.next()
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], "
                f"cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, "
                f"loss_total: {loss.item():.2f}")

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


if __name__ == '__main__':
    main()
```

The train.sh used to launch training:
```shell
python -m torch.distributed.launch --nproc_per_node=1 --master_addr 127.0.0.1 --master_port 20001 train.py
```

Distributed training results
The model's performance on the COCO dataset is shown below (input resolution 600, roughly equivalent to the resolution-450 setting in the RetinaNet paper):
All experiments above were trained in DistributedDataParallel mode. With a single GPU, training with and without sync BN is exactly identical. All experiments use RandomFlip + Resize data augmentation during training and plain Resize at test time. The -aug suffix means RandomCrop and RandomTranslate augmentation were additionally used during training. All GPUs are RTX 2080 Ti. "0.255,0.59" means an mAP of 0.255 with a total loss of 0.59 at that point; "2h28min" means 2 hours 28 minutes.
According to the results, under the same data augmentation the RetinaNet trained by my code (0.279) is 3.2 points lower than the paper (extrapolating the paper's numbers to resolution 450 gives roughly 0.311). This gap is likely due to using the Adam optimizer instead of SGD, plus issues 1 and 3 raised in the previous article.
The iscrowd issue
COCO annotations include an iscrowd attribute. iscrowd=1 marks a group of objects (e.g., a crowd of people), while iscrowd=0 marks a single object. In all of the experiments above, the annotations loaded for training include both iscrowd=0 and iscrowd=1 objects (iscrowd=None in self.coco.getAnnIds).
I checked the COCO data-loading code in detectron and detectron2 and found that both filter out iscrowd=1 objects for detection and segmentation training, never using them. So I removed all iscrowd=1 annotations and retrained (passing iscrowd=False to self.coco.getAnnIds), as the sketch below illustrates.
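As a sketch of what this change amounts to at the pycocotools level (the annotation file path here is hypothetical):

```python
from pycocotools.coco import COCO

coco = COCO('annotations/instances_train2017.json')  # hypothetical path
image_id = coco.getImgIds()[0]

# iscrowd=None returns every annotation for the image, crowd or not;
# iscrowd=False keeps only iscrowd=0 (single-object) annotations,
# matching the filtering done in detectron/detectron2
all_ann_ids = coco.getAnnIds(imgIds=image_id, iscrowd=None)
single_ann_ids = coco.getAnnIds(imgIds=image_id, iscrowd=False)
```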
The training results are as follows:
ResNet50-RetinaNet-aug is the last entry in the distributed training results above; ResNet50-RetinaNet-aug-iscrowd is the same setup retrained with iscrowd=False in self.coco.getAnnIds. The difference between the two is very small, but to stay aligned with how other frameworks train, I use ResNet50-RetinaNet-aug-iscrowd as the baseline in all subsequent improvement experiments.
All of the code has been uploaded to my GitHub repository:
https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training

This article is also posted on my CSDN blog:
[Dissecting the Ox] Implementing RetinaNet from Scratch (7): Training RetinaNet with Distributed Methods (記憶碎片's blog, blog.csdn.net)