CenterNet算法快速入门
目錄
1 簡介
2 網絡結構
3 損失函數
3.1 heatmap loss(改造的Focal Loss)
3.2 長寬預測loss(L1損失函數)
3.3 中心點偏移值loss(L1損失函數)
4 擴展:關節點預測和3D任務
4.1 人體關節點預測
4.2 3D目標檢測
1 簡介
- 時間:2019年論文《Objects as Points》
- 特點:
- 不需要anchor、也不需要NMS,模型結構簡單、速度快、精度高(比yolo3高4個點)!
- 只需少量修改head,就可以改造成3D目標檢測和人體關節點檢測。
2 網絡結構
輸入:3 x 512 x 512。
backbone:特征提取器(32倍下采樣) + 反卷積(8倍上采樣)。
head:3個分支進行預測,每個分支用2個卷積實現。
輸出:
- 80 x 128 x 128:目標分類信息和中心點位置信息,每個類單獨在一個熱圖中,熱圖中最亮的一些點就是坐標信息。
- 2 x 128 x 128:所有目標的w和h信息,一個預測w,另一個預測h。每個網格與熱圖中目標網格一一對應。
- 2 x 128 x 128:所有目標中心點的x和y偏移量信息。
3 損失函數
因為網絡輸出3個部分,所以損失函數也有3個部分:
- heatmap的loss(改造的Focal Loss)
- 目標長寬預測loss(L1損失)
- 目標中心點偏移值loss(L1損失)
以下參考:https://www.cnblogs.com/silence-cho/p/13955766.html
3.1 heatmap loss(改造的Focal Loss)
備注:一個目標正樣本就一個,負樣本指的是熱圖中心點附近的點。
關于熱圖,看一個官方源碼中生成的一個高斯分布:
每個點的范圍是0-1,而1則代表這個目標的中心點,也就是我們要預測要學習的點。
3.2 長寬預測loss(L1損失函數)
3.3 中心點偏移值loss(L1損失函數)
4 擴展:關節點預測和3D任務
4.1 人體關節點預測
這個問題,本質上,就相當于把人的每類肢體關節點,定義為一個類。
如下圖,假如要識別一張圖上,所有人的5個關節點,那么網絡輸出head定義如下:
輸入:一張2D圖像。
輸出:
- 5 x 128 x 128:5個熱圖,每類關節點單獨在一個熱圖中。
- 2 x 128 x 128:所有關節點的w和h信息。
- 2 x 128 x 128:所有關節點的x和y偏移量信息。
備注:也可以訓練centernet直接檢測85類目標(80個coco物體類+5個人體關節點類)。
4.2 3D目標檢測
3D目標檢測,需要在3D數據中,預測出目標(相對拍攝相機)的depth距離、目標的3D bbox框的長寬高信息、bbox的朝向信息。
輸入:2D圖(但標簽包含2D圖的3D信息,如自動駕駛KITTI數據集)
輸出:
- class x 128 x 128:每類目標單獨在一個熱圖中。
- 3 x 128 x 128:長、寬、高信息。
- 1 x 128 x 128:depth距離信息。
- 8 x 128 x 128:3D bbox的朝向信息。
效果類似如下:
?具體方法參考:https://zhuanlan.zhihu.com/p/350610859
總結
以上是生活随笔為你收集整理的CenterNet算法快速入门的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 度量学习:ArcFace算法和工程应用总
- 下一篇: 什么是self-attention、Mu