码迷,mamicode.com
首页 > 其他好文 > 详细

faster-rcnn代码阅读2

时间:2018-12-17 02:20:31      阅读:229      评论:0      收藏:0      [点我收藏+]

标签:调用   net   for   size   mos   hold   average   ast   cat   

二、训练

接下来回到train.py第160行,通过调用sw.train_model方法进行训练:

 1     def train_model(self, max_iters):
 2         """Network training loop."""
 3         last_snapshot_iter = -1
 4         timer = Timer()
 5         model_paths = []
 6         while self.solver.iter < max_iters:
 7             # Make one SGD update
 8             timer.tic()
 9             self.solver.step(1)
10             timer.toc()
11             if self.solver.iter % (10 * self.solver_param.display) == 0:
12                 print speed: {:.3f}s / iter.format(timer.average_time)
13 
14             if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
15                 last_snapshot_iter = self.solver.iter
16                 model_paths.append(self.snapshot())
17 
18         if last_snapshot_iter != self.solver.iter:
19             model_paths.append(self.snapshot())
20         return model_paths

方法中的self.solver.step(1)即是网络进行一次前向传播和反向传播。前向传播时,数据流会从第一层流动到最后一层,最后计算出loss,然后loss相对于各层输入的梯度会从最后一层计算回第一层。下面逐层来介绍faster-rcnn算法的运行过程。

2.1、input-data layer

第一层是由python代码构成的,其prototxt描述为:

从中可以看出,input-data层有三个输出:data、im_info、gt_boxes。其实现为faster-rcnn/lib/roi_data_layer/layer.py中的RoIDataLayer类。网络在构造过程中(即self.solver = caffe.SGDSolver(solver_prototxt))会调用该类的setup方法:

 1 __C.TRAIN.IMS_PER_BATCH = 1
 2 __C.TRAIN.SCALES = [600]
 3 __C.TRAIN.MAX_SIZE = 1000
 4 __C.TRAIN.HAS_RPN = True
 5 __C.TRAIN.BBOX_REG = True
 6 
 7     def setup(self, bottom, top):
 8         """Setup the RoIDataLayer."""
 9 
10         # parse the layer parameter string, which must be valid YAML
11         layer_params = yaml.load(self.param_str_)
12 
13         self._num_classes = layer_params[num_classes]
14 
15         self._name_to_top_map = {}
16 
17         # data blob: holds a batch of N images, each with 3 channels
18         idx = 0
19         top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3,
20             max(cfg.TRAIN.SCALES), cfg.TRAIN.MAX_SIZE)
21         self._name_to_top_map[data] = idx
22         idx += 1
23 
24         if cfg.TRAIN.HAS_RPN:
25             top[idx].reshape(1, 3)
26             self._name_to_top_map[im_info] = idx
27             idx += 1
28 
29             top[idx].reshape(1, 4)
30             self._name_to_top_map[gt_boxes] = idx
31             idx += 1
32         else: # not using RPN
33             # rois blob: holds R regions of interest, each is a 5-tuple
34             # (n, x1, y1, x2, y2) specifying an image batch index n and a
35             # rectangle (x1, y1, x2, y2)
36             top[idx].reshape(1, 5)
37             self._name_to_top_map[rois] = idx
38             idx += 1
39 
40             # labels blob: R categorical labels in [0, ..., K] for K foreground
41             # classes plus background
42             top[idx].reshape(1)
43             self._name_to_top_map[labels] = idx
44             idx += 1
45 
46             if cfg.TRAIN.BBOX_REG:
47                 # bbox_targets blob: R bounding-box regression targets with 4
48                 # targets per class
49                 top[idx].reshape(1, self._num_classes * 4)
50                 self._name_to_top_map[bbox_targets] = idx
51                 idx += 1
52 
53                 # bbox_inside_weights blob: At most 4 targets per roi are active;
54                 # thisbinary vector sepcifies the subset of active targets
55                 top[idx].reshape(1, self._num_classes * 4)
56                 self._name_to_top_map[bbox_inside_weights] = idx
57                 idx += 1
58 
59                 top[idx].reshape(1, self._num_classes * 4)
60                 self._name_to_top_map[bbox_outside_weights] = idx
61                 idx += 1
62 
63         print RoiDataLayer: name_to_top:, self._name_to_top_map
64         assert len(top) == len(self._name_to_top_map)

主要是对输出的shape进行定义(同时申请内存)。要说明的是,在前向传播的过程中,仍然会对输出的各top的shape进行重定义,并且二者定义的shape往往都是不同的。

 

 

faster-rcnn代码阅读2

标签:调用   net   for   size   mos   hold   average   ast   cat   

原文地址:https://www.cnblogs.com/pursuiting/p/10129049.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!