Tic商业评论

关注微信公众号【站长自定义模块】,定时推送前沿、专业、深度的商业资讯。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

微信登录

微信扫码,快速开始

pytorch SSD获取VOC行人数据

0
回复
6293
查看
[复制链接]
已绑定手机

49

主题

4

回帖

1228

积分

管理员

积分
1228
QQ
来源: 2021-4-30 17:04:08 显示全部楼层 |阅读模式
本文主要讲 pytorch版本的SSD训练行人的数据,以及中间遇到的问题,这里采用的方法并不是比较好的。但是可行。

新建相应的文件夹
2.jpg
首先是获取行人数据:
import os

import shutil

import bs4



#存放VOCdevkit数据的图片位置以及xml文件位置,

ann_filepath = '/home/lijingle/darknet/darknet/scripts/VOCdevkit/VOC2012/Annotations/'

img_filepath = '/home/lijingle/darknet/darknet/scripts/VOCdevkit/VOC2012/JPEGImages/'



#保存图片以及xml位置,可以自己定义

img_savepath = './VOC2012/JPEGImages/'

ann_savepath = './VOC2012/Annotations/'

if not os.path.exists(img_savepath):

  os.mkdir(img_savepath)



if not os.path.exists(ann_savepath):

  os.mkdir(ann_savepath)

names = locals()

classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',

       'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',

       'dog', 'horse', 'motorbike', 'pottedplant',

       'sheep', 'sofa', 'train', 'tvmonitor', 'person']



for file in os.listdir(ann_filepath):

  print(file)



  fp = open(ann_filepath + '/' + file)  # 打开Annotations文件

  ann_savefile = ann_savepath + file

  fp_w = open(ann_savefile, 'w')

  lines = fp.readlines()



  ind_start = []

  ind_end = []

  lines_id_start = lines[:]



  lines_id_end = lines[:]



  #   classes1 = '\t\t<name>bicycle</name>\n'

  #   classes2 = '\t\t<name>bus</name>\n'

  #   classes3 = '\t\t<name>car</name>\n'

  #   classes4 = '\t\t<name>motorbike</name>\n'

  classes5 = '\t\t<name>person</name>\n'



  # 在xml中找到object块,并将其记录下来

  while "\t<object>\n" in lines_id_start:

    a = lines_id_start.index("\t<object>\n")

    ind_start.append(a)  # ind_start是<object>的行数

    lines_id_start[a] = "delete"



  while "\t</object>\n" in lines_id_end:

    b = lines_id_end.index("\t</object>\n")

    ind_end.append(b)  # ind_end是</object>的行数

    lines_id_end[b] = "delete"



  # names中存放所有的object块

  i = 0

  for k in range(0, len(ind_start)):

    names['block%d' % k] = []

    for j in range(0, len(classes)):

      if classes[j] in lines[ind_start[i] + 1]:

        a = ind_start[i]

        for o in range(ind_end[i] - ind_start[i] + 1):

          names['block%d' % k].append(lines[a + o])

        break

    i += 1

    # print(names['block%d' % k])



  # xml头

  string_start = lines[0:ind_start[0]]



  # xml尾

  if ((file[2:4] == '09') | (file[2:4] == '10') | (file[2:4] == '11')):

    string_end = lines[(len(lines) - 11):(len(lines))]

  else:

    string_end = [lines[len(lines) - 1]]



    # 在给定的类中搜索,若存在则,写入object块信息

  a = 0

  for k in range(0, len(ind_start)):

    #     if classes1 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    #     if classes2 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    #     if classes3 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    #     if classes4 in names['block%d' % k]:

    #       a += 1

    #       string_start += names['block%d' % k]

    if classes5 in names['block%d' % k]:

      a += 1

      string_start += names['block%d' % k]



  string_start += string_end

  # print(string_start)

  for c in range(0, len(string_start)):

    fp_w.write(string_start[c])

  fp_w.close()

  # 如果没有我们寻找的模块,则删除此xml,有的话拷贝图片

  if a == 0:

    os.remove(ann_savepath + file)

  else:

    name_img = img_filepath + os.path.splitext(file)[0] + ".jpg"

    shutil.copy(name_img, img_savepath)

  fp.close()

这里得到了相应的xml文件和jpg文件,如果直接使用这里的数据会出现错误
RuntimeError: cannot perform reduction function max on tensor with no elements because the operation does not have an identity

这里的错误是xml文件里difficult标签为1所致,这个标签解释是对于图片中的人标记是不容易分辨的。所以我们还是要去除其中difficult为1的图片和xml文件

我们在制作标签时可以对difficult为0的xml文件不制作标签。这样就可以得到包含1的图片标签。
import xml.etree.ElementTree as ET

import pickle

import random, shutil

import os

from os import listdir, getcwd

from os.path import join

import bs4

from PIL import Image



classes = ["person"]  #为了获得cls id

diffs = ["0"]

def convert(size, box):

  dw = 1. / (size[0])

  dh = 1. / (size[1])

  x = (box[0] + box[1]) / 2.0 - 1

  y = (box[2] + box[3]) / 2.0 - 1

  w = box[1] - box[0]

  h = box[3] - box[2]

  x = x * dw

  w = w * dw

  y = y * dh

  h = h * dh

  return (x, y, w, h)



def convert_annotation(image_id):

  global none_counts

  # 输入文件xml

  in_file = open('./VOC2012_bak/Annotations/%s.xml' % (image_id))

  # 输出label txt

  out_file = open('./VOC2012_bak/labels/%s.txt' % (image_id), 'w')

  tree = ET.parse(in_file)

  root = tree.getroot()

  size = root.find('size')

  # 这里对不标准的xml文件(没有size字段)做了特殊处理,打开对应的图片,获取h, w

  if size == None:

    print('{}不存在size字段'.format(image_id))   

    img = Image.open('/home/lijingle/fun/pytorch/SSD/datasets/VOC2012_bak/JPEGImages/' + image_id + '.jpg')

    w, h = img.size  #大小/尺寸

    print('{}.xml缺失size字段, 读取{}图片得到对应 w:{} h:{}'.format(image_id, image_id, w, h))     

    none_counts += 1

  else:

    w = int(size.find('width').text)

    h = int(size.find('height').text)

    

  for obj in root.iter('object'):

    cls = obj.find('name').text

    dif = obj.find('difficult').text

    if cls not in classes:

      continue

    if dif in diffs:

      continue

    cls_id = classes.index(cls)

#       print('cls_id is {}'.format(cls_id))

    xmlbox = obj.find('bndbox')

    b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),

       float(xmlbox.find('ymax').text))

    bb = convert((w, h), b)

    out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')



if __name__=='__main__':

  xml_count = 0

  none_counts = 0

  list_file = os.listdir('./VOC2012_bak/Annotations/')

  for file in list_file:

    image_id = file.replace('.xml', '')

    convert_annotation(image_id)

    xml_count = xml_count + 1

  print('没有size字段的xml文件数目:{}'.format(none_counts))

  print('转换的总xml个数是 {}'.format(xml_count))
获取的如图所示:
2.jpg

2.jpg
这样我们就可以区分difficult为1和0的图片。

然后获取相应的文件名字:
import os

import random

 

trainval_percent = 1

train_percent = 0.85

xmlfilepath = '/home/lijingle/fun/pytorch/SSD/datasets/VOC2012_bak/la'

txtsavepath = '/home/lijingle/fun/pytorch/SSD/datasets/VOC2012_bak'

total_xml = os.listdir(xmlfilepath)

 

num=len(total_xml)

list=range(num)

tv=int(num*trainval_percent)

trainval= random.sample(list,tv)



 

ftrainval = open(txtsavepath+'/trainval.txt', 'w')



 

for i  in list:

  name=total_xml[i][:-4]+'.xml'+'\n'

  if i in trainval:

    ftrainval.write(name)

  else:

    ftest.write(name)

 

ftrainval.close()
得到如图所示的文件:
2.jpg

对相应的xml文件和jpg文件进行清除:
# -*- coding: UTF-8 -*- 

#!/usr/bin/env python

import sys

import re

import shutil

import os

from PIL import Image

sys.path.append('./VOC2012_bak')

import numpy as np

 

data = []

for line in open("./VOC2012_bak/trainval.txt", "r"):  # 设置文件对象并读取每一行文件

  data.append(line)

for a in data:

  os.remove('./VOC2012_bak/Annotations/'+a[:-1])#打开改路径下的line3记录的的文件名

  print('./VOC2012_bak/Annotations/'+a[:-1])


这样就可以清除掉difficult为1的文件。对图像使用SSD进行训练如图所示:
2.jpg









回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册
电话咨询: 135xxxxxxx
关注微信