Tic商业评论

关注微信公众号【站长自定义模块】,定时推送前沿、专业、深度的商业资讯。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

微信登录

微信扫码,快速开始

YOLOv3(darknet)训练自己的数据集

0
回复
4460
查看
[复制链接]
已绑定手机

49

主题

4

回帖

1244

积分

管理员

积分
1244
QQ
来源: 2021-4-20 15:41:52 显示全部楼层 |阅读模式
一,制作自己的数据集,工具为 labelimg
    将数据集保存到文件夹,然后使用工具对所有数据进行标记,并保存为xml文件。里面包含种类的位置,图像的大小等信息。    具体参考上一篇博客。
二,下载darknet框架,并下载YOLO工程。在进行编译,如下:
git clone https://github.com/pjreddie/darknet

cd darknet
       修改Makefile文件,修改内容如下:
GPU=1    #使用GPU

CUDNN=1  #使用cudnn加速工具

OPENCV=1   #使用opencv,在读取视频,以及摄像头时使用

OPENMP=0  #零默认不使用

DEBUG=0

     这里我使用的GPU和cudnn,以及opencv所以将其置1,
#对其进行编译

make -j32   #-j 32 为使用32线程进行编译

三,数据集制作:


1.获取person数据
这里我们使用的是VOCdevkit数据集里person类别,首先要对person类别进行提取,包括图片和xml文件,定义文件名为get_person.py
import os

import shutil

import bs4



#存放VOCdevkit数据的图片位置以及xml文件位置,

ann_filepath = '/home/lijingle/darknet/darknet/scripts/VOCdevkit/VOC2012/Annotations/'

img_filepath = '/home/lijingle/darknet/darknet/scripts/VOCdevkit/VOC2012/JPEGImages/'



#保存图片以及xml位置,可以自己定义

img_savepath = './VOCPerson/JPEGImages/'

ann_savepath = './VOCPerson/Annotations/'

if not os.path.exists(img_savepath):

  os.mkdir(img_savepath)



if not os.path.exists(ann_savepath):

  os.mkdir(ann_savepath)

names = locals()

classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',

       'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',

       'dog', 'horse', 'motorbike', 'pottedplant',

       'sheep', 'sofa', 'train', 'tvmonitor', 'person']



for file in os.listdir(ann_filepath):

  print(file)



  fp = open(ann_filepath + '/' + file)  # 打开Annotations文件

  ann_savefile = ann_savepath + file

  fp_w = open(ann_savefile, 'w')

  lines = fp.readlines()



  ind_start = []

  ind_end = []

  lines_id_start = lines[:]



  lines_id_end = lines[:]



#定义classes的类,及要提取的类

  #   classes1 = '\t\t<name>bicycle</name>\n'

  #   classes2 = '\t\t<name>bus</name>\n'

  #   classes3 = '\t\t<name>car</name>\n'

  #   classes4 = '\t\t<name>motorbike</name>\n'

  classes5 = '\t\t<name>person</name>\n'



  # 在xml中找到object块,并将其记录下来

  while "\t<object>\n" in lines_id_start:

    a = lines_id_start.index("\t<object>\n")

    ind_start.append(a)  # ind_start是<object>的行数

    lines_id_start[a] = "delete"



  while "\t</object>\n" in lines_id_end:

    b = lines_id_end.index("\t</object>\n")

    ind_end.append(b)  # ind_end是</object>的行数

    lines_id_end[b] = "delete"



  # names中存放所有的object块

  i = 0

  for k in range(0, len(ind_start)):

    names['block%d' % k] = []

    for j in range(0, len(classes)):

      if classes[j] in lines[ind_start[i] + 1]:

        a = ind_start[i]

        for o in range(ind_end[i] - ind_start[i] + 1):

          names['block%d' % k].append(lines[a + o])

        break

    i += 1

    # print(names['block%d' % k])



  # xml头

  string_start = lines[0:ind_start[0]]



  # xml尾

  if ((file[2:4] == '09') | (file[2:4] == '10') | (file[2:4] == '11')):

    string_end = lines[(len(lines) - 11):(len(lines))]

  else:

    string_end = [lines[len(lines) - 1]]



    # 在给定的类中搜索,若存在则,写入object块信息

  a = 0

  for k in range(0, len(ind_start)):

    #     if classes1 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    #     if classes2 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    #     if classes3 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    #     if classes4 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    if classes5 in names['block%d' % k]:

      a += 1

      string_start += names['block%d' % k]



  string_start += string_end

  # print(string_start)

  for c in range(0, len(string_start)):

    fp_w.write(string_start[c])

  fp_w.close()

  # 如果没有我们寻找的模块,则删除此xml,有的话拷贝图片

  if a == 0:

    os.remove(ann_savepath + file)

  else:

    name_img = img_filepath + os.path.splitext(file)[0] + ".jpg"

    shutil.copy(name_img, img_savepath)

  fp.close()
如图所示:
2.jpg



2.将xml文件转化为txt文件,文件名为XML_to_TXT.py
import xml.etree.ElementTree as ET

import pickle

import os

from os import listdir, getcwd

from os.path import join

import bs4

from PIL import Image



classes = ["person"]  #为了获得cls id



def convert(size, box):

  dw = 1. / (size[0])

  dh = 1. / (size[1])

  x = (box[0] + box[1]) / 2.0 - 1

  y = (box[2] + box[3]) / 2.0 - 1

  w = box[1] - box[0]

  h = box[3] - box[2]

  x = x * dw

  w = w * dw

  y = y * dh

  h = h * dh

  return (x, y, w, h)



def convert_annotation(image_id):

  global none_counts

  # 输入文件xml

  in_file = open('./VOCPerson/Annotations/%s.xml' % (image_id))

  # 输出label txt

  out_file = open('./VOCPerson/labels/%s.txt' % (image_id), 'w')

  tree = ET.parse(in_file)

  root = tree.getroot()

  size = root.find('size')

  # 这里对不标准的xml文件(没有size字段)做了特殊处理,打开对应的图片,获取h, w

  if size == None:

    print('{}不存在size字段'.format(image_id))   

    img = Image.open('/home/lijingle/darknet/darknet/data/person/VOCPerson/JPEGImages/' + image_id + '.jpg')

    w, h = img.size  #大小/尺寸

    print('{}.xml缺失size字段, 读取{}图片得到对应 w:{} h:{}'.format(image_id, image_id, w, h))     

    none_counts += 1

  else:

    w = int(size.find('width').text)

    h = int(size.find('height').text)

    

  for obj in root.iter('object'):

    cls = obj.find('name').text

    if cls not in classes:

      continue

    cls_id = classes.index(cls)

#       print('cls_id is {}'.format(cls_id))

    xmlbox = obj.find('bndbox')

    b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),

       float(xmlbox.find('ymax').text))

    bb = convert((w, h), b)

    out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')



if __name__=='__main__':

  xml_count = 0

  none_counts = 0

  list_file = os.listdir('./VOCPerson/Annotations/')

  for file in list_file:

    image_id = file.replace('.xml', '')

    convert_annotation(image_id)

    xml_count = xml_count + 1

  print('没有size字段的xml文件数目:{}'.format(none_counts))

  print('转换的总xml个数是 {}'.format(xml_count))
如图所示:
2.jpg


3.将数据集进行train和valid划分,val_train.py,以及生成train.txt和valid.txt
# coding: utf-8



# 将YoloJPEGImages里面的图片随机划分成训练集,验证集,测试集



import os, random, shutil

def moveFile(fileDir):

  

    pathDir = os.listdir(fileDir)  #取图片的原始路径     

#     # 获取所有图片的名字前缀

#     total_pathdir_first = [i[:-4] for i in pathDir]

#     print(len(total_pathdir_first))

    filenumber=len(pathDir)

    

    

    # 训练集比率

    train_rate=0.8  #自定义抽取图片的比例,比方说100张抽10张,那就是0.1

    train_picknumber=int(filenumber*train_rate) #按照rate比例从文件夹中取一定数量图片

    

    

    # 验证集比率

    valid_rate=0.1  #自定义抽取图片的比例,比方说100张抽10张,那就是0.1

    valid_picknumber=int(filenumber*valid_rate) #按照rate比例从文件夹中取一定数量图片

    

    # 测试集比率

    test_rate=0.1  #自定义抽取图片的比例,比方说100张抽10张,那就是0.1

    test_picknumber=int(filenumber*test_rate) #按照rate比例从文件夹中取一定数量图片

    

    # 剪切训练集

    train_sample_list = random.sample(pathDir, train_picknumber)  #随机选取train_picknumber数量的样本图片     

    for name in train_sample_list:

        shutil.copy(fileDir+name, trainDir+name)

        

      

    pathDir = os.listdir(fileDir)  #取图片的原始路径  

    # 剪切验证集

    valid_sample_list = random.sample(pathDir, valid_picknumber)  #随机选取train_picknumber数量的样本图片     

    for name in valid_sample_list:

        shutil.copy(fileDir+name, validDir+name)



        

    pathDir = os.listdir(fileDir)  #取图片的原始路径  

    # 剪切测试集

    test_sample_list = random.sample(pathDir, test_picknumber)  #随机选取train_picknumber数量的样本图片     

    for name in test_sample_list:

        shutil.copy(fileDir+name, trainDir+name)

    

    return

  

# 指定划分数据集后的文件路径

trainDir = './images/train/'

testDir = './images/test/'

validDir = './images/valid/'



if __name__ == '__main__':

  fileDir = "./VOCPerson/JPEGImages/"  # 源图片文件夹路径

  moveFile(fileDir)

  list_train_data=os.listdir(trainDir)

  file = open('‪‪finally_persontrain.txt', 'w+')

  for i in list_train_data:

    file.write(i+'\n')

  file.close()

  list_valid_data=os.listdir(validDir)

  file = open('‪‪finally_personvalid.txt', 'w+')

  for i in list_valid_data:

    file.write(i+'\n')

  file.close()
如图所示:
2.jpg

4,将标签文件放到对于的文件夹movelabel.py(相应文件夹需要自己新建):
# 将训练集,验证集的数据对应的label文件移到label文件夹下面对应的文件中



import os, random, shutil





train_data_label = './labels/train/'

valid_data_label = './labels/valid/'



# 原始的所有label文件夹



root_label = './VOCPerson/labels/'



# 操作训练集

count = 0

for i in open('./‪finally_persontrain.txt', 'r'):

  temp = i[:-5]

  #print(temp)

  shutil.copy(root_label + temp + '.txt', train_data_label + temp + '.txt')

  # count += 1

  # print(count)



# 操作验证集

for i in open('./‪finally_personvalid.txt', 'r'):

  temp = i[:-5]

  # print(temp)

  shutil.copy(root_label + temp + '.txt', valid_data_label + temp + '.txt')
如图所示:
2.jpg
6,将图片的路径以及名称放到txt文件下 例如persontrain.txt以及personvalid.txt  运行代码文件creatrtxt.py
# 根据训练数据集和验证数据集persontrain.txt and personvalid.txt

import os, random, shutil



trainDir = '/home/lijingle/darknet/darknet/data/person/images/train/'

validDir = '/home/lijingle/darknet/darknet/data/person/images/valid/'



train_pathDir = os.listdir(trainDir)  # 取图片的原始路径

print('训练集图片数目: {}'.format(len(train_pathDir)))



valid_pathDir = os.listdir(validDir)  # 取图片的原始路径

print('验证集图片数目: {}'.format(len(valid_pathDir)))





# 删除persontrain.txt and personvalid.txt



if(os.path.exists('./persontrain.txt')):

   os.remove('./persontrain.txt')

   print('删除persontrain.txt成功')





if(os.path.exists('./personvalid.txt')):

   os.remove('./personvalid.txt')

   print('删除personvalid.txt成功')





def text_save(root, filename, data):  # filename为写入CSV文件的路径,data为要写入数据列表.

  file = open(filename, 'a')

  for i in range(len(data)):

    s = str(data[i]).replace('[', '').replace(']', '')  # 去除[],这两行按数据不同,可以选择

    s = root + s.replace("'", '').replace(',', '') + '\n'  # 去除单引号,逗号,每行末尾追加换行符

    file.write(s)

  file.close()

  print("保存文件成功")





if __name__ == '__main__':

  text_save(trainDir, './persontrain.txt', train_pathDir)

  text_save(validDir, './personvalid.txt', valid_pathDir)



  print('persontrain.txt 有 {} 行'.format(len([i for i in open('./persontrain.txt', 'r')])))

  print('personvalid.txt 有 {} 行'.format(len([i for i in open('./personvalid.txt', 'r')])))
至此数据基本制作完成;如图所示:
2.jpg

四. 模型训练
制作模型训练所需要的文件:person.names以及person.data,这里person.data文件需要从原理文件中进行copy然后修改,不要自己新建这个文件
内容如下:
classes= 1

train  = /home/lijingle/darknet/darknet/data/person/persontrain.txt

valid  = /home/lijingle/darknet/darknet/data/person/personvalid.txt

names = /home/lijingle/darknet/darknet/data/person/person.names

backup = backup  #模型存放路径
person.names如下
person

这里还需要对配置文件yolov3.cfg进行修改
种类改变需要修改三个地方:
每个地方都必须要改2处, filters:3*(5+len(classes));这里我们只有person一个类别所以改为18
2.jpg

可以采取搜索的方式,搜索classes进行定位;

因为是训练模型,所以在开始处还要修改:
[net]

# Testing

# batch=1

# subdivisions=1

# Training

batch=64      #每batch个样本更新一次参数。

subdivisions=3  #如果内存不够大,将batch分割为subdivisions个子batch,每个子batch的大小为batch/subdivisions。

width=416

height=416

channels=3

momentum=0.9

decay=0.0005

angle=0

saturation = 1.5

exposure = 1.5

hue=.1


至此可以训练了,如下图所示:
2.jpg

creatrtxt.py (1.61 KB, 下载次数: 0) get_person.py (3.74 KB, 下载次数: 0) movelabel.py (713 Bytes, 下载次数: 0) person.data (227 Bytes, 下载次数: 0) person.names (6 Bytes, 下载次数: 0) train_valid.py (2.71 KB, 下载次数: 0) XML_to_TXT.py (2.3 KB, 下载次数: 0)


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册
电话咨询: 135xxxxxxx
关注微信