标签:输入 包括 waitkey bre rap apt ali mode append
上一篇中我们导出了训练好的模型，这一篇中我们通过调用训练好的模型来完成测试工作。
在 object_detection 目录下创建 test.py，并输入以下内容：
"""Run a frozen TF Object Detection API model on a video or a single image.

Intended to live in models/research/object_detection/ (see the paths below);
the module loads the frozen graph and label map at import time, and running
it as a script calls video_test().
"""
import os
import sys

import cv2
import numpy as np
import tensorflow as tf

# Append the parent directory to the import path. This is relative to the
# current working directory, so it presumably assumes the script is launched
# from inside object_detection/ -- TODO confirm.
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util

# Class id treated as the enemy: 1 = blue side, 2 = red side.
# NOTE(review): the original comment said "set blue as the enemy", but the value
# is 2 (red). The value is kept unchanged; confirm which side is intended.
ENERMY = 2
# True: pic_test draws every detection via vis_util; False: only the best enemy box.
DEBUG = False
# Minimum score the top detection must exceed to be drawn in pic_test (non-DEBUG).
THRE_VAL = 0.2
PATH_TO_CKPT = '/home/xueaoru/models/research/inference_graph_v2/frozen_inference_graph.pb'
PATH_TO_LABELS = '/home/xueaoru/models/research/object_detection/car_label_map.pbtxt'
NUM_CLASSES = 2

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the frozen inference graph (TF1 API) into its own Graph and open a
# session on it; the tensors below are looked up by name in that graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
    sess = tf.Session(graph=detection_graph)

image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')


def video_test(source="/home/xueaoru/下载/RoboMaster2.mp4"):
    """Run the detector on every frame of a video and display annotated frames.

    Args:
        source: anything cv2.VideoCapture accepts -- a video file path, or a
            camera index such as 1. Defaults to the original demo video.

    Stops at end of stream or when Esc (key code 27) is pressed. Prints the
    per-frame processing time in milliseconds.
    """
    cap = cv2.VideoCapture(source)
    try:
        while True:
            start = cv2.getTickCount()
            ret, image = cap.read()
            if not ret:
                break
            # Add a batch dimension: (h, w, c) -> (1, h, w, c).
            # NOTE(review): OpenCV frames are BGR; confirm the model was trained
            # on BGR input, otherwise convert with cv2.cvtColor(..., COLOR_BGR2RGB).
            image_expanded = np.expand_dims(image, axis=0)
            boxes, scores, classes, _ = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_expanded})
            vis_util.visualize_boxes_and_labels_on_image_array(
                image,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8,
                min_score_thresh=0.4)
            cv2.imshow('Object detector', image)
            key = cv2.waitKey(1) & 0xff
            elapsed = cv2.getTickCount() - start
            print("处理时间:" + str(elapsed * 1000 / cv2.getTickFrequency()))
            if key == 27:
                break
    finally:
        # Release the capture device/file even if inference or display raises.
        cap.release()
        cv2.destroyAllWindows()


def pic_test(image_path="/home/xueaoru/models/research/images/image12.jpg"):
    """Run the detector on a single image and display the result.

    With DEBUG set, all detections scoring at least 0.80 are drawn via vis_util.
    Otherwise, only the highest-scoring detection is drawn (a yellow rectangle),
    and only if its score exceeds THRE_VAL and its class equals ENERMY.

    Args:
        image_path: path of the image to read with cv2.imread. Defaults to the
            original demo image.

    Raises:
        FileNotFoundError: if cv2.imread cannot read image_path.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("could not read image: %s" % image_path)
    image_expanded = np.expand_dims(image, axis=0)  # (h, w, c) -> (1, h, w, c)
    boxes, scores, classes, _ = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_expanded})
    if DEBUG:
        vis_util.visualize_boxes_and_labels_on_image_array(
            image,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8,
            min_score_thresh=0.80)
    else:
        flat_scores = np.squeeze(scores)
        best = np.argmax(flat_scores)
        best_score = flat_scores[best]
        detected_class = np.squeeze(classes).astype(np.int32)[best]
        if best_score > THRE_VAL and detected_class == ENERMY:
            # Box is (ymin, xmin, ymax, xmax) in normalized [0, 1] coordinates;
            # scale to pixels. cv2 points are (x, y).
            box = np.squeeze(boxes)[best]
            h, w, _ = image.shape
            min_point = (int(box[1] * w), int(box[0] * h))
            max_point = (int(box[3] * w), int(box[2] * h))
            cv2.rectangle(image, min_point, max_point, (0, 255, 255), 2)
    cv2.imshow('Object detector', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    try:
        video_test()
    finally:
        sess.close()
好了，本篇暂时就到这里。最后一篇将详细讲解如何利用这些识别到的框，一步步计算出炮台偏转角度的代码；本篇这段代码的详细讲解也会放在后面。
[神经网络]一步一步使用Mobile-Net完成视觉识别(五)
标签:输入 包括 waitkey bre rap apt ali mode append
原文地址:https://www.cnblogs.com/aoru45/p/9868350.html