码迷,mamicode.com
首页 > 其他好文 > 详细

利用 PyTorch 复现 Spatial Pyramid Pooling 层

时间:2018-05-09 20:52:21      阅读:439      评论:0      收藏:0      [点我收藏+]

标签:examples   idp   pytorch   from   UNC   network   scribe   code   default   

SPPNet 的原理这里就不展开讲了，直接上代码：

from math import floor, ceil

import torch
import torch.nn as nn
import torch.nn.functional as F

 6 class SpatialPyramidPooling2d(nn.Module):
 7     r"""apply spatial pyramid pooling over a 4d input(a mini-batch of 2d inputs 
 8     with additional channel dimension) as described in the paper
 9     ‘Spatial Pyramid Pooling in deep convolutional Networks for visual recognition‘
10     Args:
11         num_level:
12         pool_type: max_pool, avg_pool, Default:max_pool
13     By the way, the target output size is num_grid:
14         num_grid = 0
15         for i in range num_level:
16             num_grid += (i + 1) * (i + 1)
17         num_grid = num_grid * channels # channels is the channel dimension of input data
18     examples:
19         >>> input = torch.randn((1,3,32,32), dtype=torch.float32)
20         >>> net = torch.nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1),21                                       nn.ReLU(),22                                       SpatialPyramidPooling2d(num_level=2,pool_type=‘avg_pool‘),23                                       nn.Linear(32 * (1*1 + 2*2), 10))
24         >>> output = net(input)
25     """
26     
27     def __init__(self, num_level, pool_type=max_pool):
28         super(SpatialPyramidPooling2d, self).__init__()
29         self.num_level = num_level
30         self.pool_type = pool_type
31 
32     def forward(self, x):
33         N, C, H, W = x.size()
34         for i in range(self.num_level):
35             level = i + 1
36             kernel_size = (ceil(H / level), ceil(W / level))
37             stride = (ceil(H / level), ceil(W / level))
38             padding = (floor((kernel_size[0] * level - H + 1) / 2), floor((kernel_size[1] * level - W + 1) / 2))
39 
40             if self.pool_type == max_pool:
41                 tensor = (F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
42             else:
43                 tensor = (F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
44             
45             if i == 0:
46                 res = tensor
47             else:
48                 res = torch.cat((res, tensor), 1)
49         return res
50     def __repr__(self):
51         return self.__class__.__name__ + ( 52             + num_level =  + str(self.num_level) 53             + , pool_type =  + str(self.pool_type) + )
54     
55 
class SPPNet(nn.Module):
    """Small demo CNN: conv feature extractor -> spatial pyramid pooling ->
    two Linear layers producing 10 outputs.

    Because the SPP layer emits a fixed-length vector, the network accepts
    3-channel images of varying height/width. The conv/pool stack needs the
    input to be at least 8x8.

    Args:
        num_level: number of pyramid levels passed to the SPP layer.
        pool_type: ``'max_pool'`` (default) or another value for average
            pooling; forwarded to the SPP layer.
    """

    def __init__(self, num_level=3, pool_type='max_pool'):
        super(SPPNet, self).__init__()
        self.num_level = num_level
        self.pool_type = pool_type
        self.feature = nn.Sequential(nn.Conv2d(3, 64, 3),
                                     nn.ReLU(),
                                     nn.MaxPool2d(2),
                                     nn.Conv2d(64, 64, 3),
                                     nn.ReLU())
        self.num_grid = self._cal_num_grids(num_level)
        # Forward pool_type: previously it was stored but never passed on,
        # so the SPP layer always used max pooling.
        self.spp_layer = SpatialPyramidPooling2d(num_level, pool_type=pool_type)
        # 64 = out_channels of the last Conv2d in self.feature.
        # NOTE(review): there is no activation between these two Linear layers,
        # so they compose to a single affine map; kept as-is to preserve the
        # module layout (state_dict keys) of existing checkpoints.
        self.linear = nn.Sequential(nn.Linear(self.num_grid * 64, 512),
                                    nn.Linear(512, 10))

    def _cal_num_grids(self, level):
        """Return the number of pooling bins per channel: sum of l*l for l in 1..level."""
        return sum((i + 1) * (i + 1) for i in range(level))

    def forward(self, x):
        # The debug print of the pooled size was removed: forward() should
        # not write to stdout.
        x = self.feature(x)
        x = self.spp_layer(x)
        return self.linear(x)

if __name__ == '__main__':
    # Smoke test: push one random 64x64 RGB image through the network.
    a = torch.rand((1, 3, 64, 64))
    net = SPPNet()
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        output = net(a)
    print(output)

 

利用 PyTorch 复现 Spatial Pyramid Pooling 层

标签:examples   idp   pytorch   from   UNC   network   scribe   code   default   

原文地址:https://www.cnblogs.com/qinduanyinghua/p/9016235.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!