码迷,mamicode.com
首页 > 其他好文 > 详细

detectron2训练visdrone记录

时间:2020-06-08 16:12:50      阅读:124      评论:0      收藏:0      [点我收藏+]

标签:cal   配置   split   width   always   continue   enc   加载   pac   

准备

VOC格式标签的转换方法参见这篇文章(原文此处为超链接)。
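注意:需将转换脚本中的 object_name = name_dict[box[4]] 改为 object_name = name_dict[box[5]]。VisDrone 标注文件每行依次为 bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, truncation, occlusion,类别位于索引 5,索引 4 是 score。
为了与detectron2的VOC数据加载方式统一,标签文件夹命名为Annotations,图片文件夹命名为JPEGImages,train.txt 位于 xxx/ImageSets/Main/。验证集和测试集同样需要对应的 val.txt 和 test.txt,因为加载代码读取的是 ImageSets/Main/<split>.txt。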

train

构建数据集:以下代码把VisDrone注册为detectron2数据集,保存为visdrone_voc.py,训练脚本中通过 from visdrone_voc import * 导入。

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
import os
import xml.etree.ElementTree as ET
from fvcore.common.file_io import PathManager

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode

__all__ = ["register_visdrone_voc"]

# Category names, indexed by category_id.
# load_voc_instances maps each XML <name> to its position in this list.
# register_visdrone_voc stores this list as the dataset's thing_classes.
# NOTE(review): '__background__' occupies index 0, so real objects get ids
# 1..10. detectron2 does not reserve a category id for background, so this
# slot is never used by annotations and will likely show up as an empty
# class in VOC-style evaluation. Confirm this is intended before changing it,
# because released weights depend on these ids.
CLASS_NAMES = ['__background__',  # always index 0
               'pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
               'tricycle', 'awning-tricycle', 'bus', 'motor']


def load_voc_instances(dirname: str, split: str):
    """
    Load Pascal VOC-style detection annotations into Detectron2's dataset format.

    Args:
        dirname: dataset root containing "Annotations", "ImageSets" and "JPEGImages".
        split (str): name of the image-id list ``ImageSets/Main/<split>.txt``,
            e.g. "train", "val", "test" or "trainval".

    Returns:
        list[dict]: one record per image id. Each record has these keys:
            - "file_name": path of ``JPEGImages/<id>.jpg``.
            - "image_id": the id itself.
            - "height", "width": read from the XML ``<size>`` element.
            - "annotations": a list of dicts with "category_id", "bbox"
              (absolute XYXY) and "bbox_mode".

    Raises:
        ValueError: if an object's ``<name>`` is not listed in CLASS_NAMES.
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        # dtype=str replaces the np.str alias, which NumPy 1.24 removed.
        # ndmin=1 keeps a one-line split file iterable; without it,
        # loadtxt returns a 0-d array.
        fileids = np.loadtxt(f, dtype=str, ndmin=1)

    # Many small annotation files are read, so resolve a local path once.
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            # continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original VOC annotations are integers in the range [1, W or H].
            # Treating them as 1-based inclusive pixel indices, a box with
            # (xmin=1, xmax=W) covers the whole image. In continuous coordinate
            # space that box is (xmin=0, xmax=W), so only the minimum corner
            # is shifted.
            # NOTE(review): if the VisDrone->VOC conversion writes 0-based
            # coordinates, this shift can produce -1; confirm against the
            # conversion script.
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": CLASS_NAMES.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts


def register_visdrone_voc(name, dirname, split, year):
    """
    Register a VisDrone split (converted to VOC layout) with detectron2.

    DatasetCatalog receives a zero-argument loader that, when called, returns
    ``load_voc_instances(dirname, split)``. The dataset's MetadataCatalog
    entry is then filled with these fields:
        - thing_classes: CLASS_NAMES.
        - dirname, year, split: the arguments passed in.

    Args:
        name: dataset name to register, e.g. "VISDRONE_VOC".
        dirname: root directory containing Annotations/ImageSets/JPEGImages.
        split: image-list name under ImageSets/Main, without the ".txt" suffix.
        year: only stored in the metadata; not used for loading.
    """
    def _loader():
        return load_voc_instances(dirname, split)

    DatasetCatalog.register(name, _loader)
    metadata = MetadataCatalog.get(name)
    metadata.set(thing_classes=CLASS_NAMES, dirname=dirname, year=year, split=split)

训练采用带FPN的Faster R-CNN,backbone使用ResNeXt-101 32x8d:分组卷积共32个group,每个group的宽度为8个通道。以下为训练与评估脚本:

from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultPredictor
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
from train_no_voc import *
from visdrone_voc import *
import os
import cv2
import torch

# Root directory holding the VisDrone splits converted to VOC layout.
VISDRONE_ROOT = '/home/chenzhengxi/data/VisDrone'

# year=2012 is only stored in each dataset's metadata. detectron2's
# PascalVOCDetectionEvaluator reads it to pick the VOC2007 11-point metric
# (year 2007) or the area-under-curve AP (other years).
register_visdrone_voc('VISDRONE_VOC', os.path.join(VISDRONE_ROOT, 'VisDrone2018-DET-train'),
                      'train', 2012)
register_visdrone_voc('VISDRONE_VAL', os.path.join(VISDRONE_ROOT, 'VisDrone2018-DET-val'),
                      'val', 2012)
register_visdrone_voc('VISDRONE_TEST', os.path.join(VISDRONE_ROOT, 'VisDrone2019-DET-test-dev'),
                      'test', 2012)
# register_train_no_voc('TRAIN_NO_VOC', os.path.join('/home/chenzhengxi/data/TrainNO/VOCdevkit2007/TRAIN2020'),
#                       'train', 2020)
cfg = get_cfg()
# NOTE(review): this script never sets DATASETS.TRAIN / DATASETS.TEST or
# MODEL.ROI_HEADS.NUM_CLASSES. They presumably come from this custom YAML,
# which must reference the VISDRONE_* names registered above. Confirm.
cfg.merge_from_file('configs/faster_rcnn_X_101_32x8d_FPN_3x.yaml')

# cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 4
# cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# cfg.SOLVER.IMS_PER_BATCH = 8
# cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
# cfg.SOLVER.MAX_ITER = 3000
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

# cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
# cfg.OUTPUT_DIR = 'output'
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# resume=True continues training from the latest checkpoint in OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()

# To evaluate a specific checkpoint instead of the final weights:
# cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0229999.pth")
# checkpointer = DetectionCheckpointer(trainer.model)
# checkpointer.load(cfg.MODEL.WEIGHTS)

# Testing threshold for models built from cfg after this point, such as the
# DefaultPredictor example below. detectron2 reads it when a model is
# constructed. trainer.model already exists, so this does NOT change the
# trainer.test() evaluation below.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

evaluator = PascalVOCDetectionEvaluator(cfg.DATASETS.TEST[0])
# val_loader = build_detection_test_loader(cfg, "VISDRONE_VAL")
# result_val = inference_on_dataset(trainer.model, val_loader, evaluator)
# print(result_val)
print(trainer.test(cfg, trainer.model, evaluator))

# predictor = DefaultPredictor(cfg)
# im = cv2.imread('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val/JPEGImages/0000026_03500_d_0000031.jpg')
# outputs = predictor(im)
# ooo = outputs['instances'].to(torch.device("cpu"))
# boxes = ooo.pred_boxes.tensor.numpy()
# print(boxes)
# for box in boxes:
#     x0, y0, x1, y1 = (int(v) for v in box)  # cv2 drawing needs int points
#     cv2.rectangle(im, (x0, y0), (x1, y1), (0, 255, 0), 2)
#
# cv2.imshow('visdrone', im)
# cv2.waitKey(0)
# Evaluation log (AP / AP50 / AP75, checkpoint iteration, split):
# AP            AP50       AP75
# 20.8397,39.0423,19.9213  94999  val
# 17.0480,32.8580,16.1755  94999 test
# 21.6114,39.1688,21.1583 149999  val
# 17.7422,33.4552,17.0088 149999 test
# 22.2992,40.0259,21.7078 169999  val
# 18.0168,33.5292,17.6131 169999 test
# 22.9258,41.0643,22.4904 214999  val
# 18.1249,33.7228,17.6461 214999 test
# 22.8556,40.9861,22.3667 269999  val
# 18.0256,33.5866,17.5259 269999 test

从上面的测试记录可以看出,效果远高于YOLO(YOLO的结果未在本文列出)。最终的配置文件和权重可从网盘下载(原文此处为链接),提取码: 74s4。

detectron2训练visdrone记录

标签:cal   配置   split   width   always   continue   enc   加载   pac   

原文地址:https://www.cnblogs.com/chenzhengxi/p/13065792.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!