标签:pycharm sha star enum 模型 targe tar time() pyplot
用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。
CNN_1:
runnig time:29.795 sec.
accuracy: 0.8688
CNN_2:
runnig time:165.101 sec.
accuracy: 0.8837
1 import time 2 import torch.nn as nn 3 from torchvision.datasets import FashionMNIST 4 import torch 5 import numpy as np 6 from torch.utils.data import DataLoader 7 import torch.utils.data as Data 8 import matplotlib.pyplot as plt 9 10 11 #import os 12 #os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 13 ‘‘‘数据集为FashionMNIST‘‘‘ 14 data=FashionMNIST(‘../pycharm_workspace/data/‘) 15 16 def train_test_split(data,test_pct=0.3): 17 test_len=int(data.data.size(0)*test_pct) 18 x_test=data.data[0:test_len].type(torch.float) 19 x_train=data.data[test_len:].type(torch.float) 20 21 y_test=data.targets[0:test_len] 22 y_train=data.targets[test_len:] 23 24 return x_train,y_train,x_test,y_test 25 26 def cal_accuracy(model,x_test,y_test,samples=10000): 27 ‘‘‘取一定数量的样本,用于评估‘‘‘ 28 y_pred=model(x_test[:samples]) 29 ‘‘‘把模型输出(向量)转为label形式‘‘‘ 30 y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy())) 31 ‘‘‘计算准确率‘‘‘ 32 acc=sum(y_pred_==y_test.numpy()[:samples])/samples 33 return acc 34 35 class CNN_1(nn.Module): 36 def __init__(self): 37 super().__init__() 38 self.conv1=nn.Sequential( 39 nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致 40 32,#out_channels,即filter的数量 41 4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width) 42 stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长 43 padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride 44 ), 45 nn.ReLU(), 46 nn.MaxPool2d(kernel_size=2) 47 ) 48 self.out=nn.Linear(32*7*7,10) 49 50 def forward(self,x): 51 x=self.conv1(x) 52 temp=x.view(x.shape[0],-1) 53 out=self.out(temp) 54 return out 55 56 class CNN_2(nn.Module): 57 def __init__(self): 58 super().__init__() 59 self.conv1=nn.Sequential( 60 nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致 61 32,#out_channels,即filter的数量 62 5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width) 63 stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长 64 padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride 65 ), 66 nn.ReLU(), 67 nn.MaxPool2d(kernel_size=2) 68 ) 69 self.conv2=nn.Sequential( 70 nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致 71 16,#out_channels,即filter的数量 72 5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width) 73 stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长 74 padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride 75 ), 76 nn.ReLU(), 77 nn.MaxPool2d(kernel_size=2) 78 ) 79 self.out=nn.Linear(16*7*7,10) 80 81 def forward(self,x): 82 x=self.conv1(x) 83 x=self.conv2(x) 84 x=x.view(x.size(0),-1) 85 out=self.out(x) 86 return out 87 88 def train_3(): 89 num_epoch=5 90 #t_data=data.data.type(torch.float) 91 x_train,y_train,x_test,y_test=train_test_split(data,0.2) 92 ‘‘‘使用DataLoader批量输入训练数据‘‘‘ 93 dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True) 94 ‘‘‘创建模型对象‘‘‘ 95 model=CNN_2() 96 ‘‘‘定义损失函数‘‘‘ 97 loss_func=torch.nn.CrossEntropyLoss() 98 ‘‘‘定义优化器‘‘‘ 99 optimizer=torch.optim.Adam(model.parameters(),lr=0.001) 100 start=time.time() 101 102 acc_hist=[] 103 loss_hist=[] 104 for i in range(num_epoch): 105 for index,(x_data,y_data) in enumerate(dl_train): 106 prediction=model(torch.unsqueeze(x_data, dim=1)) 107 loss=loss_func(prediction,y_data) 108 print(‘No.%s,loss=%.3f‘%(index+1,loss.data.numpy())) 109 optimizer.zero_grad() 110 loss.backward() 111 optimizer.step() 112 loss_val=loss.data.numpy() 113 if i==0: 114 acc=cal_acc(prediction,y_data) 115 acc_hist.append(acc) 116 loss_hist.append(loss_val) 117 print(‘No.%s,loss=%.3f‘%(i+1,loss_val)) 118 #loss_hist.append(loss_val) 119 #acc=cal_accuracy(model,x_test,y_test,samples=10000) 120 #acc_hist.append(acc) 121 print(‘acc=‘,acc) 122 123 end=time.time() 124 print(‘runnig time:%.3f sec.‘%(end-start)) 125 acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000) 126 print(‘accuracy:‘,acc) 127 128 if __name__==‘__main__‘: 129 train_3()
标签:pycharm sha star enum 模型 targe tar time() pyplot
原文地址:https://www.cnblogs.com/aaronhoo/p/11739835.html