Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解
Torchvision更新到0.3.0后支持了更多的功能,其中新增模塊detection中實(shí)現(xiàn)了整個(gè)faster-rcnn的功能。本博客主要講述如何通過torchvision和pytorch使用faster-rcnn,并提供一個(gè)demo和對(duì)應(yīng)代碼及解析注釋。
目錄
如果你不想深入了解原理和訓(xùn)練,只想用Faster-rcnn做目標(biāo)檢測(cè),請(qǐng)看這里
torchvision中Faster-rcnn接口
一個(gè)demo
使用方法
如果你想深入了解原理,并訓(xùn)練自己的模型
環(huán)境搭建
準(zhǔn)備訓(xùn)練數(shù)據(jù)
模型訓(xùn)練
單張圖片檢測(cè)
效果
如果你不想深入了解原理和訓(xùn)練,只想用Faster-rcnn做目標(biāo)檢測(cè),請(qǐng)看這里
torchvision中Faster-rcnn接口
torchvision內(nèi)部集成了Faster-rcnn的模型,其接口和調(diào)用方式野非常簡(jiǎn)潔,目前官方提供resnet50+rpn在coco上訓(xùn)練的模型,調(diào)用該模型只需要幾行代碼:
>>> import torch
>>> import torchvision
?
// 創(chuàng)建模型,pretrained=True將下載官方提供的coco2017模型
>>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
?
?
?
注意網(wǎng)絡(luò)的輸入x是一個(gè)Tensor構(gòu)成的list,而輸出prediction則是一個(gè)由dict構(gòu)成list。prediction的長(zhǎng)度和網(wǎng)絡(luò)輸入的list中Tensor個(gè)數(shù)相同。prediction中的每個(gè)dict包含輸出的結(jié)果:
其中boxes是檢測(cè)框坐標(biāo),labels是類別,scores則是置信度。
>>> predictions[0]
?
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward>)}
一個(gè)demo
如果你不想自己寫讀取圖片/預(yù)處理/后處理,我這里有個(gè)寫好的demo.py,可以跑在任何安裝了pytorch1.1+和torchvision0.3+的環(huán)境下,不需要其他依賴,可以用來完成目標(biāo)檢測(cè)的任務(wù)。
為了能夠顯示類別標(biāo)簽,我們將coco的所有類別寫入coco_names.py
names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat', '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat', '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant', '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack', '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase', '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball', '38': 'kite', '39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard', '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl', '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange', '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza', '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', '64': 'potted plant', '65': 'bed', '67': 'dining table', '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', '75': 'remote', '76': 'keyboard', '77': 'cell phone', '78': 'microwave', '79': 'oven', '80': 'toaster', '81': 'sink', '82': 'refrigerator', '84': 'book', '85': 'clock', '86': 'vase', '87': 'scissors', '88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}
然后構(gòu)建一個(gè)可以讀取圖片并檢測(cè)的demo.py
import torch
import torchvision
import argparse
import cv2
import numpy as np
import sys
sys.path.append('./')
import coco_names
import random
?
def get_args():
? ? parser = argparse.ArgumentParser(description='Pytorch Faster-rcnn Detection')
?
? ? parser.add_argument('image_path', type=str, help='image path')
? ? parser.add_argument('--model', default='fasterrcnn_resnet50_fpn', help='model')
? ? parser.add_argument('--dataset', default='coco', help='model')
? ? parser.add_argument('--score', type=float, default=0.8, help='objectness score threshold')
? ? args = parser.parse_args()
?
? ? return args
?
def random_color():
? ? b = random.randint(0,255)
? ? g = random.randint(0,255)
? ? r = random.randint(0,255)
?
? ? return (b,g,r)
?
def main():
? ? args = get_args()
? ? input = []
? ? num_classes = 91
? ? names = coco_names.names
? ? ? ??
? ? # Model creating
? ? print("Creating model")
? ? model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=True) ?
? ? model = model.cuda()
?
? ? model.eval()
?
? ? src_img = cv2.imread(args.image_path)
? ? img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
? ? img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().cuda()
? ? input.append(img_tensor)
? ? out = model(input)
? ? boxes = out[0]['boxes']
? ? labels = out[0]['labels']
? ? scores = out[0]['scores']
?
? ? for idx in range(boxes.shape[0]):
? ? ? ? if scores[idx] >= args.score:
? ? ? ? ? ? x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
? ? ? ? ? ? name = names.get(str(labels[idx].item()))
? ? ? ? ? ? cv2.rectangle(src_img,(x1,y1),(x2,y2),random_color(),thickness=2)
? ? ? ? ? ? cv2.putText(src_img, text=name, org=(x1, y1+10), fontFace=cv2.FONT_HERSHEY_SIMPLEX,?
? ? ? ? ? ? ? ? fontScale=0.5, thickness=1, lineType=cv2.LINE_AA, color=(0, 0, 255))
?
? ? cv2.imshow('result',src_img)
? ? cv2.waitKey()
? ? cv2.destroyAllWindows()
?
? ??
?
if __name__ == "__main__":
? ? main()
運(yùn)行命令
$ python demo.py [image path]
就能完成檢測(cè),并且不需要任何其他依賴,只需要Pytorch1.1+和torchvision0.3+。看下效果:
使用方法
我發(fā)現(xiàn)好像很多人對(duì)上面這個(gè)demo怎么用不太清楚,照著下面的流程做就好了:
下載代碼:https://github.com/supernotman/Faster-RCNN-with-torchvision
下載模型:Baidu Cloud
運(yùn)行命令:
$ python detect.py --model_path [模型路徑] --image_path [圖片路徑]
其實(shí)非常簡(jiǎn)單。
如果你想深入了解原理,并訓(xùn)練自己的模型
這里提供一份我重構(gòu)過的代碼,把torchvision中的faster-rcnn部分提取出來,可以訓(xùn)練自己的模型(目前只支持coco),并有對(duì)應(yīng)博客講解。
代碼地址:https://github.com/supernotman/Faster-RCNN-with-torchvision
代碼解析博客:
Pytorch torchvision構(gòu)建Faster-rcnn(一)----coco數(shù)據(jù)讀取
Pytorch torchvision構(gòu)建Faster-rcnn(二)----基礎(chǔ)網(wǎng)絡(luò)
Pytorch torchvision構(gòu)建Faster-rcnn(三)----RPN
Pytorch torchvision構(gòu)建Faster-rcnn(四)----ROIHead
訓(xùn)練模型:Baidu Cloud
環(huán)境搭建
下載代碼:
$ git clone https://github.com/supernotman/Faster-RCNN-with-torchvision.git
安裝依賴:
$ pip install -r requirements.txt
注意:
代碼要求Pytorch版本大于1.1.0,torchvision版本大于0.3.0。
如果某個(gè)依賴項(xiàng)通過pip安裝過慢,推薦替換清華源:
$ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package
如果pytorch安裝過慢,可參考conda安裝Pytorch下載過慢解決辦法(7月23日更新ubuntu下pytorch1.1安裝方法)
準(zhǔn)備訓(xùn)練數(shù)據(jù)
下載coco2017數(shù)據(jù)集,下載地址:
http://images.cocodataset.org/zips/train2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
http://images.cocodataset.org/zips/val2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
http://images.cocodataset.org/zips/test2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip?
如果下載速度過慢,可參考博客COCO2017數(shù)據(jù)集國(guó)內(nèi)下載地址
數(shù)據(jù)下載后按照如下結(jié)構(gòu)放置:
? coco/
? ? 2017/
? ? ? annotations/
? ? ? test2017/
? ? ? train2017/
? ? ? val2017/
模型訓(xùn)練
$ python -m torch.distributed.launch --nproc_per_node=$gpus --use_env train.py --world-size $gpus --b 4
訓(xùn)練采用了Pytorch的distributedparallel方式,支持多gpu。
注意其中$gpus為指定使用的gpu數(shù)量,b為每個(gè)gpu上的batch_size,因此實(shí)際batch_size大小為$gpus × b。
實(shí)測(cè)當(dāng)b=4,1080ti下大概每張卡會(huì)占用11G顯存,請(qǐng)根據(jù)情況自行設(shè)定。
訓(xùn)練過程中每個(gè)epoch會(huì)給出一次評(píng)估結(jié)果,形式如下:
?Average Precision??(AP) @[ IoU=0.50:0.95 | area=???all | maxDets=100 ] = 0.352
?Average Precision??(AP) @[ IoU=0.50??????| area=???all | maxDets=100 ] = 0.573
?Average Precision??(AP) @[ IoU=0.75??????| area=???all | maxDets=100 ] = 0.375
?Average Precision??(AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.207
?Average Precision??(AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.387
?Average Precision??(AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.448
?Average Recall?????(AR) @[ IoU=0.50:0.95 | area=???all | maxDets=??1 ] = 0.296
?Average Recall?????(AR) @[ IoU=0.50:0.95 | area=???all | maxDets= 10 ] = 0.474
?Average Recall?????(AR) @[ IoU=0.50:0.95 | area=???all | maxDets=100 ] = 0.498
?Average Recall?????(AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.312
?Average Recall?????(AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.538
?Average Recall?????(AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631
其中AP為準(zhǔn)確率,AR為召回率,第一行為訓(xùn)練結(jié)果的mAP,第四、五、六行分別為小/中/大物體對(duì)應(yīng)的mAP
單張圖片檢測(cè)
$ python detect.py --model_path result/model_13.pth --image_path imgs/1.jpg
model_path為模型路徑,image_path為測(cè)試圖片路徑。
代碼文件夾中assets給出了從coco2017測(cè)試集中挑選的11張圖片測(cè)試結(jié)果。
效果
總結(jié)
以上是生活随笔為你收集整理的Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: C++或C 实现AES ECB模式加密解
- 下一篇: 程序设计竞赛资源索引