【图像分类】 基于Pytorch的细粒度图像分类实战
歡迎大家來到《圖像分類》專欄,今天講述基于pytorch的細粒度圖像分類實戰!
作者&編輯 | 郭冰洋
1 簡介
針對傳統的多類別圖像分類任務,經典的CNN網絡已經取得了非常優異的成績,但在處理細粒度圖像數據時,往往無法發揮自身的最大威力。
這是因為細粒度圖像間存在更加相似的外觀和特征,同時在采集中存在姿態、視角、光照、遮擋、背景干擾等影響,導致數據呈現類間差異性大、類內差異性小的現象,從而使分類更加具有難度。
為了改善經典CNN網絡在細粒度圖像分類中的表現,同時不借助其他標注信息,人們提出了雙線性網絡(Bilinear CNN)這一非常具有創意的結構,并在細粒度圖像分類中取得了相當可觀的進步。
本次實戰將通過CUB-200數據集進行訓練,對比經典CNN網絡結構和雙線性網絡結構間的差異性。
2 數據集
首先我們回顧一下在多類別圖像分類實戰中所提出的圖像分類任務的五個步驟。其中,在整個任務中最基礎的一環就是根據數據集的構成編寫相應的讀取代碼,這也是整個訓練的關鍵所在。
本次實戰選擇的數據集為CUB-200數據集,該數據集是細粒度圖像分類領域最經典,也是最常用的一個數據集。共包括annotations、attributes、attributes-yaml、images、lists五個文件夾。
此次實戰中,我們只利用數據集提供的類別標注信息。因此只需要關注lists文件夾下的train.txt和test.txt文件即可。
通過圖片我們可以看到,兩個txt文件中給出了不同圖片的相對路徑,而開頭數字則代表了對應的標記信息,但是pytorch中的標簽必須從0開始,因此我們只需要借助strip和split函數即可完成圖像和標簽信息的獲取。
# txt文件路徑
path = '/media/by/Udata/Datasets/bird/lists/train.txt'
txt = open(path,'r')
imgs = []
# 讀取每行信息
? ?line = line.strip('\n')
? ?# 將每行內容以'.'為標記劃分
? ?# 添加至列表
輸出結果示例如下圖所示:
此時我們只需要將上述模塊融合進pytorch的數據集讀取模塊即可,代碼如下:
class cub_dataset(Dataset):
? ?def __init__(self, transform):
????????'/media/by/Udata/Datasets/bird/lists/train.txt', 'r')
????? ? '/media/by/Udata/Datasets/bird/images/' + fn)
3 網絡搭建
本次實戰主要選取了經典Resnet 50網絡結構和基于Resnet 50的雙線性網絡結構。
Resnet 50作為經典的分類網絡,其結構不再贅述,在此詳細介紹一下雙線性網絡的構建。
如上圖所示,雙線性網絡包括兩個分支CNN結構,這兩個分支可以是相同的網絡,也可以是不同的網絡,本次實戰使用Resnet 50做為相同的分支網絡,以保證對比的客觀性。
在此網絡下將圖像送入兩個分支Resnet 50之后,把獲取到的兩個特征分支進行相應的融合操作。
具體代碼如下:
class Net(nn.Module):
??????????????????????????????????????????????????????resnet50().bn1,?
???? ???????????????????????????????????????????????? resnet50().relu,?
????????????????????????????????????????????????????? resnet50().maxpool,?
????????????????????????????????????????????????????? resnet50().layer1,
????????????????????????????????????????????????????? resnet50().layer2,
??????????????????????????????????????????????????????resnet50().layer3,
??????????????????????????????????????????????????????resnet50().layer4)
? ? ? ?torch.transpose(x, 1, 2)) / 28 ** 2).view(batch_size, -1)
??????????????torch.sqrt(torch.abs(x) + 1e-10))
? ? ? ?x = self.classifiers(x)
4 訓練及參數調試
損失函數選擇交叉熵損失函數,優化方式選擇SGD優化。初始學習率設置為0.01,batch size設置為8,衰減率設置為0.00001,迭代周期為20,采用top-5評價指標
最終的訓練結果如下圖所示:
Resnet 50最終取得的準確率約52%左右,而基于Resnet 50的雙線性網絡取得了近80%的準確率,由此可見不同的網絡在細粒度分類任務上的性能差異非常巨大。
項目代碼:發送“細粒度分類”到有三AI公眾號后臺可獲取。
總結
以上就是整個細粒度圖像分類實戰的過程,本次實戰并沒有進行精細的調參工作,因此雙線性網絡的性能與原文中具有一定的差異,同時也期待大家去發掘更有效、更精準的細粒度分類網絡哦!
有三AI夏季劃
有三AI夏季劃進行中,歡迎了解并加入,系統性成長為中級CV算法工程師。
轉載文章請后臺聯系
侵權必究
往期精選
【技術綜述】你真的了解圖像分類嗎?
【技術綜述】多標簽圖像分類綜述
【圖像分類】分類專欄正式上線啦!初入CV、AI你需要一份指南針!
【圖像分類】從數據集和經典網絡開始
【圖像分類】 基于Pytorch的多類別圖像分類實戰
【圖像分類】細粒度圖像分類是什么,有什么方法,發展的怎么樣
總結
以上是生活随笔為你收集整理的【图像分类】 基于Pytorch的细粒度图像分类实战的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【知识星球】ElementAI提出超复杂
- 下一篇: 【知识星球】softmax损失相关的小问