Tic商业评论

关注微信公众号【站长自定义模块】,定时推送前沿、专业、深度的商业资讯。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

微信登录

微信扫码,快速开始

  • QQ空间
  • 回复
  • 收藏

darknet框架解析(一)-框架概述

lijingle 深度学习 2022-2-12 13:14 1694人围观

概述
darknet是一个很优秀的深度学习框架。由于其代码量少,框架简单并且使用的是c语言进行编写,所以可以很好的移植到其他平台上,这对于工程技术人员来说是很友好的。另外一个优点是其不依赖其他的库,甚至可以不依赖opencv。如果不使用GPU加速,只在cpu上进行训练可以不依赖cuda。当然也是有缺点的正是因为其代码使用c语言进行编写,并且里面有很多超参是很难进行修改的。如果想要修改网络,就要修改很多东西,这导致对模型进行更改,或者使用框架实现新网络就会有很大的工程量。

本系列主要介绍darknet框架,其中包括darknet的数据结构部分,代码运行流程,图像卷积操作,YOLO损失函数,以及深度学习中使用到的一些其它操作。如果时间够用,也会添加一些cuda编程的分析。由于本人能力有限,内容难免有些错误,请担待。


darknet框架分析主线
由于darknet是c语言进行编写,所以首先我们找到框架的main函数。本文主要介绍的是AlexeyAB DarkNet。我们可以在src/darknet.c文件中找到main函数。在main函数中我们可以看到函数run_yolo() run_detector() run_super() run_classifier()等函数。run_yolo()在yolo.c文件中实现,只提供了YOLO算法的原始实现。而run_detector()是AlexeyAB修改后添加了很多trick的YOLO算法。本文主要沿着run_detector()这条线进行分析,其它线路感兴趣可以自己去分析。函数实现在src/detector.c文件中。run_super()函数是高分辨率重建,在src/super.c文件中实现。run_classifer()函数是用于图像分类的网络,在src/classifier.c文件中实现。
下面是代码部分(只贴出部分代码)
int main(int argc, char **argv)
{
if (0 == strcmp(argv[1], "average")){
        average(argc, argv);
    } else if (0 == strcmp(argv[1], "yolo")){
        run_yolo(argc, argv);
    } else if (0 == strcmp(argv[1], "super")){
        run_super(argc, argv);
    } else if (0 == strcmp(argv[1], "detector")){
        run_detector(argc, argv);
    } else if (0 == strcmp(argv[1], "detect")){
        float thresh = find_float_arg(argc, argv, "-thresh", .24);
		int ext_output = find_arg(argc, argv, "-ext_output");
        char *filename = (argc > 4) ? argv[4]: 0;
        test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL, 0, 0);
    } else if (0 == strcmp(argv[1], "cifar")){
        run_cifar(argc, argv);
    } else if (0 == strcmp(argv[1], "go")){
        run_go(argc, argv);
}

运行主线函数
run_detector()->train_detector()  [->read_data_cfg() ->parse_network_cfg_custom() 
->load_data() ->train_network()]->train_network_waitkey()->[train_network_datum()
->update_network()]->update_network()

run_detector函数
这里跟进run_detector函数,这个函数也是目标检测函数的集合其中包含train函数,valid函数,
以及map函数等。在这里我们只选择跟进train函数。具体run_detector的内容解释如下:
void run_detector(int argc, char **argv)
{
    int dont_show = find_arg(argc, argv, "-dont_show");//展示图像界面
    int benchmark = find_arg(argc, argv, "-benchmark");//评估模型的表现
    int benchmark_layers = find_arg(argc, argv, "-benchmark_layers");
    //if (benchmark_layers) benchmark = 1;
    if (benchmark) dont_show = 1;
    int show = find_arg(argc, argv, "-show");
    int letter_box = find_arg(argc, argv, "-letter_box");//是否对图像做letter-box变换
    int calc_map = find_arg(argc, argv, "-map");//是否计算map值
    int map_points = find_int_arg(argc, argv, "-points", 0);
    check_mistakes = find_arg(argc, argv, "-check_mistakes");//检查数据是否有误
    int show_imgs = find_arg(argc, argv, "-show_imgs");//显示图片
    int mjpeg_port = find_int_arg(argc, argv, "-mjpeg_port", -1);
    int json_port = find_int_arg(argc, argv, "-json_port", -1);
    char *http_post_host = find_char_arg(argc, argv, "-http_post_host", 0);
    int time_limit_sec = find_int_arg(argc, argv, "-time_limit_sec", 0);
    char *out_filename = find_char_arg(argc, argv, "-out_filename", 0);
    char *outfile = find_char_arg(argc, argv, "-out", 0);
    char *prefix = find_char_arg(argc, argv, "-prefix", 0);//模型保存的前缀
    float thresh = find_float_arg(argc, argv, "-thresh", .25);    // 置信度
    float iou_thresh = find_float_arg(argc, argv, "-iou_thresh", .5);    // 0.5 for mAP
    float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
    int cam_index = find_int_arg(argc, argv, "-c", 0);//摄像头编号
    int frame_skip = find_int_arg(argc, argv, "-s", 0);//跳帧检测间隔
    int num_of_clusters = find_int_arg(argc, argv, "-num_of_clusters", 5);
    int width = find_int_arg(argc, argv, "-width", -1);// 输入网络的图像宽度
    int height = find_int_arg(argc, argv, "-height", -1);// 输入网络的图像高度
    // extended output in test mode (output of rect bound coords)
    // and for recall mode (extended output table-like format with results for best_class fit)
    int ext_output = find_arg(argc, argv, "-ext_output");
    int save_labels = find_arg(argc, argv, "-save_labels");
    if (argc < 4) {
        fprintf(stderr, "usage: %s %s [train/test/valid/demo/map] [data] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }    // 解析输入参数,获取GPU使用情况,如果使用单个GPU,那么调用时不需要指明GPU卡号,默认使用卡号0上的GPU;
    // 如果使用多块GPU,那么在调用时,其中有两个参数必须为:-gpus 0,1,2...(以逗号隔开)
    // 前者指明是GPU卡号参数,后者为多块GPU的卡号,find_char_arg就是将0,1,2...读入gpu_list
    char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);// 多个gpu训练
    int *gpus = 0;
    int gpu = 0;
    int ngpus = 0;
    if (gpu_list) {
        printf("%s\n", gpu_list);
        int len = (int)strlen(gpu_list);
        ngpus = 1;
        int i;
        for (i = 0; i < len; ++i) {
            if (gpu_list[i] == ',') ++ngpus;
        }
        gpus = (int*)xcalloc(ngpus, sizeof(int));
        for (i = 0; i < ngpus; ++i) {
            gpus[i] = atoi(gpu_list);
            gpu_list = strchr(gpu_list, ',') + 1;
        }
    }
    else {
        gpu = gpu_index;
        gpus = &gpu;
        ngpus = 1;
    }

    int clear = find_arg(argc, argv, "-clear");

    char *datacfg = argv[3];//存储训练集,验证集,以及类别对应名字等信息的cfg文件
    char *cfg = argv[4];//要训练的网络cfg文件
    char *weights = (argc > 5) ? argv[5] : 0;//是否有预训练模型
    if (weights)
        if (strlen(weights) > 0)
            if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
    char *filename = (argc > 6) ? argv[6] : 0;
    if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box, benchmark_layers);//执行目标检测模型测试
    else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs, benchmark_layers);//目标检测模型训练
    else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);//目标检测模型验证
    else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);///计算验证集的召回率
    else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, letter_box, NULL);//计算验证集的map值
    else if (0 == strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show);//计算验证集的anchors
    else if (0 == strcmp(argv[2], "demo")) {//demo展示
        list *options = read_data_cfg(datacfg);
        int classes = option_find_int(options, "classes", 20);
        char *name_list = option_find_str(options, "names", "data/names.list");
        char **names = get_labels(name_list);
        if (filename)
            if (strlen(filename) > 0)
                if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0;
        demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename,
            mjpeg_port, json_port, dont_show, ext_output, letter_box, time_limit_sec, http_post_host, benchmark, benchmark_layers);

        free_list_contents_kvp(options);
        free_list(options);
    }
    else printf(" There isn't such command: %s", argv[2]);

    if (gpus && gpu_list && ngpus > 1) free(gpus);
}

跟进train_detector
这个函数是网络训练的核心,其中包含数据的读取,网络文件的读取,以及网络传输和网络反向传输。
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs, int benchmark_layers)
{
    // 从options找出训练图片路径信息,如果没找到,默认使用"data/train.list"路径下的图片信息(train.list含有标准的信息格式:<object-class> <x> <y> <width> <height>),
    // 该文件可以由darknet提供的scripts/voc_label.py根据自行在网上下载的voc数据集生成,所以说是默认路径,其实也需要使用者自行调整,也可以任意命名,不一定要为train.list,
    // 甚至可以不用voc_label.py生成,可以自己不厌其烦的制作一个
    // 读入后,train_images将含有训练图片中所有图片的标签以及定位信息
    list *options = read_data_cfg(datacfg);
    char *train_images = option_find_str(options, "train", "data/train.txt");
    char *valid_images = option_find_str(options, "valid", train_images);
    char *backup_directory = option_find_str(options, "backup", "/backup/");

    network net_map;
    //如果要计算map
    if (calc_map) {
        FILE* valid_file = fopen(valid_images, "r");
        if (!valid_file) {
            printf("\n Error: There is no %s file for mAP calculation!\n Don't use -map flag.\n Or set valid=%s in your %s file. \n", valid_images, train_images, datacfg);
            getchar();
            exit(-1);
        }
        else fclose(valid_file);

        cuda_set_device(gpus[0]);
        printf(" Prepare additional network for mAP calculation...\n");
        net_map = parse_network_cfg_custom(cfgfile, 1, 1);
        //分类数
        const int net_classes = net_map.layers[net_map.n - 1].classes;

        int k;  // free memory unnecessary arrays
        for (k = 0; k < net_map.n - 1; ++k) free_layer_custom(net_map.layers[k], 1);

        char *name_list = option_find_str(options, "names", "data/names.list");
        int names_size = 0;
        //获取类别对应的名字
        char **names = get_labels_custom(name_list, &names_size);
        if (net_classes != names_size) {
            printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
                name_list, names_size, net_classes, cfgfile);
            if (net_classes > names_size) getchar();
        }
        free_ptrs((void**)names, net_map.layers[net_map.n - 1].classes);
    }

    srand(time(0));
     // 提取配置文件名称中的主要信息,用于输出打印(并无实质作用),比如提取cfg/yolo.cfg中的yolo,用于下面的输出打印
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    float avg_loss = -1;
    // 构建网络:用多少块GPU,就会构建多少个相同的网络(不使用GPU时,ngpus=1)
    network* nets = (network*)xcalloc(ngpus, sizeof(network));
	
	//设定随机数种子
    srand(time(0));
    int seed = rand();
    int i;
      // for循环次数为ngpus,使用多少块GPU,就循环多少次(不使用GPU时,ngpus=1,也会循环一次)
    // 这里每一次循环都会构建一个相同的神经网络,如果提供了初始训练参数,也会为每个网络导入相同的初始训练参数
    for (i = 0; i < ngpus; ++i) {
        srand(seed);
#ifdef GPU
        cuda_set_device(gpus[i]);
#endif
		//解析网络配置文件
        nets[i] = parse_network_cfg(cfgfile);
        //测试某一个网络层的相关指标如运行时间
        nets[i].benchmark_layers = benchmark_layers;
        //如果有预训练模型则加载
        if (weightfile) {
            load_weights(&nets[i], weightfile);
        }
        //
        if (clear) *nets[i].seen = 0;
        nets[i].learning_rate *= ngpus;
    }
    ...
}
到这里我们了解了网络读取配置文件。具体怎么解析我们还要了解darknet的网络文件怎么命名,已经读
取配置文件的规则,已经darknet里的数据结构






路过

雷人

握手

鲜花

鸡蛋
我有话说......

TA还没有介绍自己。

电话咨询: 135xxxxxxx
关注微信