yolov5训练_YoloV5模型训练实战教程:Kaggle全球小麦检测竞赛
寫在前面
前段時間參加了Kaggle的一個目標檢測競賽,比賽后期因為工作較繁忙就擱置了,但仍然獲得了銅牌(前10%)。因此在這里想跟大家分享下自己的方案,希望能幫助大家更好的了解目標檢測這一經典的計算機視覺領域。
這篇教程的主要代碼來源于這個git倉庫(https://github.com/ultralytics/yolov5),是國外一個公司開源的。選擇這個項目一是因為性能好,最新mAP達到了50.8/25.5ms,太強大了;二是因為該項目是用pytorch實現的,使用門檻低,很適合初學者。下面開始我們的實戰教程。
YoloV5
數據分析
這里需要重點說明下,CV任務的第一步絕對不是搭模型,而是觀察數據,只有了解的數據的組成和分布,才能搭出性能更好好的模型。首先看下比賽數據,看下面9張圖片,可以看出小麥品種不一,風格差異很大,所以很明顯Domain Gap是這個比賽的難點。
小麥數據
再來看下標注框,每張圖有幾十個目標,分布非常密集,所以這個任務其實屬于密集小目標檢測問題,因此像FPN這種金字塔模型肯定是必不可少的。
數據標注展示
數據處理
目標檢測任務跟分類不同,數據的格式有很多種,比較常用的是COCO和VOC格式,YoloV5使用的是YOLO自己的格式。Kaggle的數據標簽是用csv格式保存的,需要轉換成YOLO的標注格式。格式轉換可以參考下面這段代碼。
import numpy as np # linear algebraimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)import osdf = pd.read_csv('../input/global-wheat-detection/train.csv')bboxs = np.stack(df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))for i, column in enumerate(['x', 'y', 'w', 'h']): df[column] = bboxs[:,i]df.drop(columns=['bbox'], inplace=True)df['x_center'] = df['x'] + df['w']/2df['y_center'] = df['y'] + df['h']/2df['classes'] = 0from tqdm.auto import tqdmimport shutil as shdf = df[['image_id','x', 'y', 'w', 'h','x_center','y_center','classes']]index = list(set(df.image_id))index = list(set(df.image_id))source = 'train'if True: for fold in [0]: val_index = index[len(index)*fold//5:len(index)*(fold+1)//5] for name,mini in tqdm(df.groupby('image_id')): if name in val_index: path2save = 'val2017/' else: path2save = 'train2017/' if not os.path.exists('convertor/fold{}/labels/'.format(fold)+path2save): os.makedirs('convertor/fold{}/labels/'.format(fold)+path2save) with open('convertor/fold{}/labels/'.format(fold)+path2save+name+".txt", 'w+') as f: row = mini[['classes','x_center','y_center','w','h']].astype(float).values row = row/1024 row = row.astype(str) for j in range(len(row)): text = ' '.join(row[j]) f.write(text) f.write("") if not os.path.exists('convertor/fold{}/images/{}'.format(fold,path2save)): os.makedirs('convertor/fold{}/images/{}'.format(fold,path2save)) sh.copy("../input/global-wheat-detection/{}/{}.jpg".format(source,name),'convertor/fold{}/images/{}/{}.jpg'.format(fold,path2save,name))模型訓練
先看下YoloV5長啥樣,畢竟從YoloV1發展到V5,模型也確實復雜了很多,不過其實核心結構沒有變,只是增加了大量的trick,每一個trick都需要大量的調參,這也是YoloV5作者的一個非常大的貢獻。不過我建議剛上手不用了解的這么深入,先跑起來再說。
YoloV5網絡結構圖
數據處理完畢后,下載YoloV5的源代碼,在data文件下配置自己的數據路徑,例如下面的wheat0.yaml,主要是訓練和驗證集的路徑,類別數量和類別名這4個參數。
# train and val datasets (image directory or *.txt file with image paths)train: ./convertor/fold0/images/train2017/val: ./convertor/fold0/images/val2017/# number of classesnc: 1# class namesnames: ['wheat']配置文件更改完成后就可以直接訓練了,需要訓練100個epoch才能得到一個較好的模型。使用如下的訓練腳本開始訓練。
python train.py --img 1024 --batch 2 --epochs 100 --data ../input/configyolo5/wheat0.yaml --cfg ../input/configyolo5/yolov5x.yaml --name yolov5x_fold0下面是我訓練模型的可視化結果,可以看出來效果還是不錯的,基本沒有漏標或誤判。
預測結果展示
寫在后面
以上就是YoloV5的實戰訓練教程,其實跑起來還是很簡單的。大家可以先試下,把訓練流程跑通,下篇文章我會把我的測試流程和kaggle的提交代碼也分享出來,歡迎大家關注和轉發。有任何問題可以在文章下面評論,我會及時回復。
總結
以上是生活随笔為你收集整理的yolov5训练_YoloV5模型训练实战教程:Kaggle全球小麦检测竞赛的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: group by rollup 变量名为
- 下一篇: automation服务器不能创建对象是