码迷,mamicode.com
首页 > 其他好文 > 详细

调用训练好的detectron模型

时间:2019-12-30 19:24:55      阅读:108      评论:0      收藏:0      [点我收藏+]

标签:训练   The   font   numpy   hat   range   while   img   return   

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import defaultdict
import cv2  # NOQA (Must import before importing caffe2 due to bug in cv2)
from caffe2.python import workspace
from detectron.core.config import assert_and_infer_cfg
# from detectron.core.config import cfg
from detectron.core.config import merge_cfg_from_file
from detectron.utils.io import cache_url
from detectron.utils.timer import Timer
import detectron.core.test_engine as infer_engine
import detectron.datasets.dummy_datasets as dummy_datasets
import detectron.utils.c2 as c2_utils
import detectron.utils.vis as vis_utils
import numpy as np
import pycocotools.mask as mask_util
c2_utils.import_detectron_ops()
# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
# NOTE: the call below is commented out, so OpenCL currently stays at
# OpenCV's default setting. Uncomment it if those issues show up.
# cv2.ocl.setUseOpenCL(False)

# --- Model selection: keep exactly one weights/config pair uncommented. ---
# coco
# weights = "/home/gaomh/Desktop/test/cocomodel/model_final.pkl"
# config = "/home/gaomh/Desktop/test/cocomodel/e2e_mask_rcnn_R-101-FPN_1x.yaml"
# hat (currently pointing at a keypoint R-CNN person model)
weights = "/home/gaomh/Desktop/test/models/kp-person/model_final.pkl"
config = "/home/gaomh/Desktop/test/models/kp-person/e2e_keypoint_rcnn_X-101-32x8d-FPN_1x.yaml"
# foot
# weights = "/home/gaomh/Desktop/test/trainMOdel/train/voc_2007_train/retinanet/model_final.pkl"
# config = "/home/gaomh/Desktop/test/trainMOdel/train/voc_2007_train/retinanet_R-50-FPN_1x.0.yaml"
gpuid = 0
# GlobalInit takes argv-style strings; the quotes are required.
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
merge_cfg_from_file(config)
# Weights/config are local paths here; presumably cache_urls=False skips
# resolving/caching remote URLs in the config -- confirm against Detectron.
assert_and_infer_cfg(cache_urls=False)

model = infer_engine.initialize_model_from_cfg(weights, gpuid)
# Class-id -> name lookup used to label detections; its class order must
# match the order the loaded model was trained with.
dataset = dummy_datasets.get_foot_dataset()


def convert_from_cls_format(cls_boxes, cls_segms):
    """Flatten per-class detection output into parallel sequences.

    Args:
        cls_boxes: Sequence indexed by class id; each entry holds that class's
            detections (possibly empty).
        cls_segms: Sequence of per-class segmentation lists parallel to
            ``cls_boxes``, or ``None``.

    Returns:
        ``(boxes, segms, classes)``:
        ``boxes`` is the row-wise concatenation of all non-empty per-class
        entries, or ``None`` if there are none. ``segms`` is the flattened
        segmentation list, or ``None``. ``classes`` gives the class id of
        each flattened detection, in the same order.
    """
    non_empty = [per_class for per_class in cls_boxes if len(per_class) > 0]
    boxes = np.concatenate(non_empty) if len(non_empty) > 0 else None

    segms = (
        None if cls_segms is None
        else [segm for per_class in cls_segms for segm in per_class]
    )

    # One class id per detection, repeated as many times as that class has rows.
    classes = [
        class_id
        for class_id, per_class in enumerate(cls_boxes)
        for _ in range(len(per_class))
    ]
    return boxes, segms, classes


def vis_one_image(boxes, cls_segms, thresh=0.9):
    """Collect detections scoring at least ``thresh``, largest box first.

    Args:
        boxes: Per-class detection output (``cls_boxes`` from
            ``infer_engine.im_detect_all``): a list indexed by class id whose
            entries are arrays of ``[x1, y1, x2, y2, score]`` rows. ``None``
            is accepted and yields no detections.
        cls_segms: Per-class mask list parallel to ``boxes``, or ``None``.
            The entries are passed to pycocotools ``mask.decode``, so they are
            presumably COCO RLEs.
        thresh: Minimum score for a detection to be reported.

    Returns:
        ``(result_box, result_mask)``. Each ``result_box`` entry is
        ``[class_name, score, x1, y1, x2, y2]`` with integer coordinates and
        the class name looked up in the module-level ``dataset.classes``.
        ``result_mask`` holds the matching decoded mask, or ``[]`` when no
        mask exists for that detection.

    Raises:
        TypeError: if ``boxes`` is neither ``None`` nor a per-class list.
            Class ids cannot be recovered from a pre-flattened array.
    """
    result_box = []
    result_mask = []
    if boxes is None:
        return result_box, result_mask
    if not isinstance(boxes, list):
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``segms``/``classes``.
        raise TypeError(
            'vis_one_image expects the per-class box list from '
            'im_detect_all, got {}'.format(type(boxes).__name__))
    boxes, segms, classes = convert_from_cls_format(boxes, cls_segms)

    if boxes is None or boxes.shape[0] == 0 or boxes[:, 4].max() < thresh:
        return result_box, result_mask

    # Decode every mask once; the result is indexed as masks[:, :, i] below.
    # Skip decoding when there are no masks at all.
    masks = mask_util.decode(segms) if segms else None

    # Report in largest-to-smallest order to reduce occlusion when drawn.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    for i in np.argsort(-areas):
        score = boxes[i, -1]
        if score < thresh:
            continue
        x1, y1, x2, y2 = (int(v) for v in boxes[i, :4])
        result_box.append([dataset.classes[classes[i]], score, x1, y1, x2, y2])
        if masks is not None and i < len(segms):
            result_mask.append(masks[:, :, i])
        else:
            result_mask.append([])

    return result_box, result_mask

# Run the detector frame by frame and draw the labelled boxes; press "q" to quit.
# cap = cv2.VideoCapture("rtsp://192.168.123.231")
cap = cv2.VideoCapture("/home/gaomh/per.mp4")
# cv2.namedWindow("img", cv2.WINDOW_NORMAL)
try:
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            # End of file or a dropped stream: read() gives no frame, and
            # passing None to im_detect_all would crash.
            break
        timers = defaultdict(Timer)
        # Use the same GPU the model was initialized on.
        with c2_utils.NamedCudaScope(gpuid):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, frame, None, timers=timers
            )
        # Alternative: Detectron's renderer, e.g.
        # vis_utils.vis_one_image_opencv(im=frame, boxes=cls_boxes, ...)

        result_box, result_mask = vis_one_image(cls_boxes, cls_segms)
        print(result_box)
        for name, _score, left, top, right, bottom in result_box:
            cv2.rectangle(frame, (left, top), (right, bottom), (255, 0, 0), 1)
            cv2.putText(frame, name, (left - 10, top - 10),
                        cv2.FONT_HERSHEY_COMPLEX, 0.4, (0, 0, 255))
        cv2.imshow("img", frame)
        # Mask to the low byte; some platforms set extra bits in waitKey's result.
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()

修改 dummy_datasets.py，增加与训练数据对应的类别（索引 0 为背景，其余类别的顺序必须与训练时一致）：

def get_foot_dataset():
    """A dummy dataset that includes only the 'classes' field.

    ``classes`` maps class id to name. Id 0 is the background class; the
    remaining names must be listed in the order used when training the model.
    """
    ds = AttrDict()
    # String literals are required; bare names would raise NameError.
    classes = [
        '__background__', 'person', 'foot', 'car'
    ]
    ds.classes = {i: name for i, name in enumerate(classes)}
    return ds

效果图

技术图片

调用训练好的detectron模型

标签:训练   The   font   numpy   hat   range   while   img   return   

原文地址:https://www.cnblogs.com/answerThe/p/12121176.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!