标签:https argmax column ota mod sub sed ram transform
https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/rpn/rpn_target.py
def forward(self, ious):
    """Assign RPN training labels to anchors from an anchor/gt IoU matrix.

    RPNTargetSampler is only used in data transform with no batch dimension.

    Parameters
    ----------
    ious: (N, M) i.e. (num_anchors, num_gt).

    Returns
    -------
    samples: (num_anchors,) value 1: pos, -1: neg, 0: ignore.
    matches: (num_anchors,) value [0, M).

    Notes
    -----
    Labels are assigned in three passes. Each pass overrides the previous one:

    1. An anchor becomes positive if its IoU with some gt exceeds that gt's
       best IoU minus ``self._eps``.
    2. An anchor becomes positive if its best IoU is
       ``>= self._pos_iou_thresh``.
    3. An anchor becomes negative if its best IoU lies in
       ``[0, self._neg_iou_thresh)``. This pass runs last, so it can demote
       anchors picked by passes 1 and 2.

    Positives are then randomly thinned to at most ``self._max_pos``.
    Negatives are thinned to ``self._num_sample`` minus the number of kept
    positives. Thinning draws from NumPy's global RNG.
    """
    # Index of the best-overlapping gt for each anchor; labelling does not
    # change it.
    matches = mx.nd.argmax(ious, axis=1)

    best_iou = mx.nd.max(ious, axis=1)  # best IoU per anchor
    ones = mx.nd.ones_like(best_iou)
    labels = mx.nd.zeros_like(best_iou)  # every anchor starts as ignore (0)

    # Pass 1: per-gt best anchors, with ties admitted via eps. An anchor
    # chosen by several gts is still a single positive; `matches` keeps
    # pointing at its top gt.
    gt_best_iou = mx.nd.max(ious, axis=0, keepdims=True)  # (1, num_gt)
    hits = mx.nd.broadcast_greater(ious + self._eps, gt_best_iou)
    labels = mx.nd.where(mx.nd.sum(hits, axis=1), ones, labels)

    # Pass 2: anchors whose overlap clears the positive threshold.
    labels = mx.nd.where(best_iou >= self._pos_iou_thresh, ones, labels)

    # Pass 3: low-overlap anchors become negative. The ``>= 0`` guard keeps
    # anchors with a negative best IoU out of the negative pool. Presumably
    # these are anchors masked out upstream -- TODO confirm against the
    # caller.
    low = (best_iou < self._neg_iou_thresh) * (best_iou >= 0)
    labels = mx.nd.where(low, ones * -1, labels)

    # Random subsampling is done on the host with NumPy.
    labels = labels.asnumpy()

    pos_idx = np.flatnonzero(labels > 0)
    num_pos = int(pos_idx.size)
    if num_pos > self._max_pos:
        dropped = np.random.choice(pos_idx, size=num_pos - self._max_pos, replace=False)
        labels[dropped] = 0  # demote the surplus to ignore

    # If there are fewer positives than the quota, negatives fill the gap.
    neg_idx = np.flatnonzero(labels < 0)
    neg_room = self._num_sample - min(num_pos, self._max_pos)
    if neg_idx.size > neg_room:
        dropped = np.random.choice(neg_idx, size=neg_idx.size - neg_room, replace=False)
        labels[dropped] = 0  # demote the surplus to ignore

    return mx.nd.array(labels, ctx=matches.context), matches
标签:https argmax column ota mod sub sed ram transform
原文地址:https://www.cnblogs.com/TreeDream/p/10192410.html