Tensorflow Object Detection API生成自己的tfrecord训练数据集
生活随笔
收集整理的這篇文章主要介紹了
Tensorflow Object Detection API生成自己的tfrecord训练数据集
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
Object Detection API谷歌
該文章部分參考了別的大佬的內容,由於忘了出處,所以沒有加轉載鏈接,請諒解。原創作者看到可以聯系我添加。
========轉載請注明出處==========
此 Python 文件放在 dataset_tools 目錄下面。
生成自己訓練的數據集,主要看個人 annotation 文件是什麼格式的。我這裡的每張圖都有自己的 annotation 文件,例如:
圖片 xxx.jpg,其 annotation 文件為 xxx.box
box 文件內容為:
Xmin Ymin Xmax Ymax label(各欄位以空格分隔),如下圖。如果有多個 label,可以繼續追加在下一行:
Xmin Ymin Xmax Ymax label \n
Xmin Ymin Xmax Ymax label
?
from __future__ import absolute_import from __future__ import division from __future__ import print_functionimport hashlib import io import os import PIL.Image import tensorflow as tf import pandas as pd import cv2 from functools import reduce import operator from object_detection.utils import dataset_utilflags = tf.app.flags flags.DEFINE_string('train_imgs_dir', '/home/ai/Downloads/competition_change_box_img/img', 'Root directory to bc train dataset.') flags.DEFINE_string('train_labels', '/home/ai/Downloads/competition_change_box_img/box','(Relative) path to annotations directory.') flags.DEFINE_string('train_output', '../All_tf_record/competition_img_test.record', 'Path to output TFRecord') FLAGS = flags.FLAGSdef create_coordinate_info_of_content_list(image_dir,label_dir):content_list_all = []for item,file_name in enumerate(os.listdir(label_dir)):img = cv2.imread(os.path.join(image_dir,file_name.replace('.box','.jpg')))height = img.shape[0]width = img.shape[1]deepth = img.shape[2]content_list = [[file_name.replace('.box', '.jpg'), height, width, deepth]]with open(os.path.join(label_dir,file_name), 'r') as f: lines = f.readlines()for line in lines:new_line = line.split(' ')[:]content_one = [new_line[0],new_line[1],new_line[2],new_line[3],new_line[4]]content_list.append(content_one)a = reduce(operator.add,content_list)content_list_all.append(a)return content_list_alldef create_tf_example(content_list, imgs_dir):height = int(content_list[1])width = int(content_list[2])filename = content_list[0]img_path = os.path.join(imgs_dir, filename)with tf.gfile.GFile(img_path, 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = PIL.Image.open(encoded_jpg_io)if image.format != 'JPEG':raise ValueError('Image format not JPEG')key = hashlib.sha256(encoded_jpg).hexdigest()xmin = []ymin = []xmax = []ymax = []classes = []classes_text = []box_num = int((len(content_list) - 4) / 5) #一張圖上可能有多個labelfor i in range(box_num):xmin.append(float(content_list[5 * 
i + 4 + 0]) / width)ymin.append(float(content_list[5 * i + 4 + 1]) / height)xmax.append(float(content_list[5 * i + 4 + 2]) / width)ymax.append(float(content_list[5 * i + 4 + 3]) / height)classes_text.append(content_list[5 * i + 4 + 4].encode('utf8'))classes.append(classMap[content_list[5 * i + 4 + 4]])print('the class id is {} '.format(classMap[content_list[5 * i + 4 + 4]]))example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return exampledef main(_):# train tfrecord generateprint("Reading from {}".format(FLAGS.train_imgs_dir))writer = tf.python_io.TFRecordWriter(FLAGS.train_output)content_list_all = create_coordinate_info_of_content_list(FLAGS.train_imgs_dir, FLAGS.train_labels)for line in content_list_all:content_list = linetf_example = create_tf_example(content_list, FLAGS.train_imgs_dir)writer.write(tf_example.SerializeToString())writer.close()if __name__ == '__main__':tf.app.run()?
總結
以上是生活随笔為你收集整理的Tensorflow Object Detection API生成自己的tfrecord训练数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 访问其他计算机被拒绝,对端口com1的访
- 下一篇: 动画和3D