码迷,mamicode.com
首页 > 其他好文 > 详细

【Pytorch】CIFAR-10分类任务

时间:2019-01-10 15:46:28      阅读:550      评论:0      收藏:0      [点我收藏+]

标签:应该   one   输入   compose   有关   python版本   count   with open   load   

https://blog.csdn.net/weixin_39837402/article/details/81054106

【Pytorch】CIFAR-10分类任务

CIFAR-10数据集共有60000张32*32彩色图片,分为10类,每类有6000张图片。其中50000张用于训练,构成5个训练batch,每一批次10000张图片,其余10000张图片用于测试。

技术分享图片

CIFAR-10数据集下载地址:点击下载

数据读取,这里选择下载python版本的数据集,解压后得到如下文件:

技术分享图片

其中data_batch_1~data_batch_5为训练集的5个批次,test_batch为测试集。

这些文件是python的序列化模型,这里使用python3,可以使用pickle模块读取这些数据:

  1. def unpickle(file):
  2. import pickle
  3. with open(file, ‘rb‘) as fo:
  4. dict = pickle.load(fo, encoding=‘bytes‘)
  5. return dict

    每一个batch文件包括一个字典,字典的元素是:
data:一个尺寸为10000*3072,数据格式为uint8的numpy array,每
一行数据存储了一张32*32彩色图片的数据,前1024位是图像的红色
通道数据,接着是绿色通道和蓝色通道。

label:一个包含10000个0-9数字的列表,对应data里每张图片的标签。

    技术分享图片   

技术分享图片

    此外,数据集中还有一个batches.meta文件,它保存了一个python字典,
该字典对标签的10个数字0-9所代表的意义做了解释,比如0代表airplane,
1代表automobile。

技术分享图片

 

这次使用Pytorch框架来进行实验,总体流程是,建立网络(这次小demo用Lenet),自定义数据集读取框架,虽然pytorch已经有关于cifar10的Dataset实例,但还是自己实现了一遍,接着用DataLoader分批读取数据集,定义损失函数和优化器,进行批次训练。

  1. import torch
  2. import torchvision
  3. from torch.autograd import Variable
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import torch.optim as optim
  7. import torch.utils.data as Data
  8. import torchvision.transforms as transforms
  9. import numpy as np
  10. from PIL import Image
  11. import matplotlib.pyplot as plt
  12.  
  13. #预设参数
  14. CLASS_NUM = 10
  15. BATCH_SIZE = 128
  16. EPOCH = 30
  17.  
  18. #Lenet网络代码
  19. class Lenet(nn.Module):
  20. def __init__(self):
  21. super(Lenet,self).__init__()
  22. #定义网络层
  23. #入通道数,出通道数,卷积尺寸
  24. self.conv1 = nn.Conv2d(3,6,5)
  25. self.conv2 = nn.Conv2d(6,16,5)
  26. self.fc1 = nn.Linear(16*5*5,120)
  27. self.fc2 = nn.Linear(120,84)
  28. self.fc3 = nn.Linear(84,10)
  29.  
  30. #将二维数据展开成一维数据以输入到全连接层
  31. def num_flat_features(self,x):
  32. #size为[batch_size,num_channels,height,width]
  33. #除去batch_size,num_channels*height*width就是展开后维度
  34. size = x.size()[1:]
  35. num_features = 1
  36. for s in size:
  37. num_features = num_features*s
  38. return num_features
  39.  
  40. def forward(self,x):
  41. #定义前向传播
  42. #输入 和 窗口尺寸
  43. x = F.max_pool2d(F.relu(self.conv1(x)), 2)
  44. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  45. x = x.view(-1, self.num_flat_features(x))
  46. x = F.relu(self.fc1(x))
  47. x = F.relu(self.fc2(x))
  48. x = self.fc3(x)
  49. return x
  50.  
  51. def unpickle(file):
  52. import pickle
  53. with open(file, ‘rb‘) as fo:
  54. dict = pickle.load(fo, encoding=‘bytes‘)
  55. return dict
  56.  
  57. #从源文件读取数据
  58. #返回 train_data[50000,3072]和labels[50000]
  59. # test_data[10000,3072]和labels[10000]
  60. def get_data(train=False):
  61. data = None
  62. labels = None
  63. if train == True:
  64. for i in range(1,6):
  65. batch = unpickle(‘data/cifar-10-batches-py/data_batch_‘+str(i))
  66. if i == 1:
  67. data = batch[b‘data‘]
  68. else:
  69. data = np.concatenate([data,batch[b‘data‘]])
  70.  
  71. if i == 1:
  72. labels = batch[b‘labels‘]
  73. else:
  74. labels = np.concatenate([labels,batch[b‘labels‘]])
  75. else:
  76. batch = unpickle(‘data/cifar-10-batches-py/test_batch‘)
  77. data = batch[b‘data‘]
  78. labels = batch[b‘labels‘]
  79. return data,labels
  80.  
  81. #图像预处理函数,Compose会将多个transform操作包在一起
  82. #对于彩色图像,色彩通道不存在平稳特性
  83. transform = transforms.Compose([
  84. # ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C)
  85. # 从0到255的值映射到0到1的范围内,并转化成Tensor格式。
  86. transforms.ToTensor(),
  87. #Normalize函数将图像数据归一化到[-1,1]
  88. transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
  89. ]
  90. )
  91.  
  92. #将标签转换为torch.LongTensor
  93. def target_transform(label):
  94. label = np.array(label)
  95. target = torch.from_numpy(label).long()
  96. return target
  97.  
  98. ‘‘‘
  99. 自定义数据集读取框架来载入cifar10数据集
  100. 需要继承data.Dataset
  101. ‘‘‘
  102. class Cifar10_Dataset(Data.Dataset):
  103. def __init__(self,train=True,transform=None,target_transform=None):
  104. #初始化文件路径
  105. self.transform = transform
  106. self.target_transform = target_transform
  107. self.train = train
  108. #载入训练数据集
  109. if self.train:
  110. self.train_data,self.train_labels = get_data(train)
  111. self.train_data = self.train_data.reshape((50000, 3, 32, 32))
  112. # 将图像数据格式转换为[height,width,channels]方便预处理
  113. self.train_data = self.train_data.transpose((0, 2, 3, 1))
  114. #载入测试数据集
  115. else:
  116. self.test_data,self.test_labels = get_data()
  117. self.test_data = self.test_data.reshape((10000, 3, 32, 32))
  118. self.test_data = self.test_data.transpose((0, 2, 3, 1))
  119. pass
  120. def __getitem__(self, index):
  121. #从数据集中读取一个数据并对数据进行
  122. #预处理返回一个数据对,如(data,label)
  123. if self.train:
  124. img, label = self.train_data[index], self.train_labels[index]
  125. else:
  126. img, label = self.test_data[index], self.test_labels[index]
  127.  
  128. img = Image.fromarray(img)
  129. #图像预处理
  130. if self.transform is not None:
  131. img = self.transform(img)
  132. #标签预处理
  133. if self.target_transform is not None:
  134. target = self.target_transform(label)
  135.  
  136. return img, target
  137. def __len__(self):
  138. #返回数据集的size
  139. if self.train:
  140. return len(self.train_data)
  141. else:
  142. return len(self.test_data)
  143.  
  144. if __name__ == ‘__main__‘:
  145. #读取训练集和测试集
  146. train_data = Cifar10_Dataset(True,transform,target_transform)
  147. print(‘size of train_data:{}‘.format(train_data.__len__()))
  148. test_data = Cifar10_Dataset(False,transform,target_transform)
  149. print(‘size of test_data:{}‘.format(test_data.__len__()))
  150. train_loader = Data.DataLoader(dataset=train_data, batch_size = BATCH_SIZE, shuffle=True)
  151.  
  152. net = Lenet()
  153. optimizer = optim.Adam(net.parameters(), lr = 0.001, betas=(0.9, 0.99))
  154. #在使用CrossEntropyLoss时target直接使用类别索引,不适用one-hot
  155. loss_fn = nn.CrossEntropyLoss()
  156.  
  157. loss_list = []
  158. for epoch in range(1,EPOCH+1):
  159. #训练部分
  160. for step,(x,y) in enumerate(train_loader):
  161. b_x = Variable(x)
  162. b_y = Variable(y)
  163. output = net(b_x)
  164. loss = loss_fn(output,b_y)
  165. optimizer.zero_grad()
  166. loss.backward()
  167. optimizer.step()
  168. #记录loss
  169. if step%50 == 0:
  170. loss_list.append(loss)
  171. #每完成一个epoch进行一次测试观察效果
  172. pre_correct = 0.0
  173. test_loader = Data.DataLoader(dataset=test_data, batch_size = 100, shuffle=True)
  174. for (x,y) in (test_loader):
  175. b_x = Variable(x)
  176. b_y = Variable(y)
  177. output = net(b_x)
  178. pre = torch.max(output,1)[1]
  179. pre_correct = pre_correct+float(torch.sum(pre==b_y))
  180.  
  181. print(‘EPOCH:{epoch},ACC:{acc}%‘.format(epoch=epoch,acc=(pre_correct/float(10000))*100))
  182.  
  183. #保存网络模型
  184. torch.save(net,‘lenet_cifar_10.model‘)
  185. #绘制loss变化曲线
  186. plt.plot(loss_list)
  187. plt.show()

第一个pytorch demo跑通了,但是训练模型效果很不好,应该是Lenet作用于Cifar10有些过于力不从心了,刚开始接触深度学习的图像领域还不怎么懂,下次换一个更强大的网络。

技术分享图片

技术分享图片

【Pytorch】CIFAR-10分类任务

标签:应该   one   输入   compose   有关   python版本   count   with open   load   

原文地址:https://www.cnblogs.com/shuimuqingyang/p/10249711.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!