标签:ice lob end ali numpy text div seq 设置
1、快速开始
1.1 定义神经网络类,继承torch.nn.Module,文件名为digit_recog.py
1 import torch.nn as nn 2 3 4 class Net(nn.Module): 5 def __init__(self): 6 super(Net, self).__init__() 7 self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 5, 1, 2) 8 , nn.ReLU() 9 , nn.MaxPool2d(2, 2)) 10 self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5) 11 , nn.ReLU() 12 , nn.MaxPool2d(2, 2)) 13 self.fc1 = nn.Sequential( 14 nn.Linear(16 * 5 * 5, 120), 15 # nn.Dropout2d(), 16 nn.ReLU() 17 ) 18 self.fc2 = nn.Sequential( 19 nn.Linear(120, 84), 20 nn.Dropout2d(), 21 nn.ReLU() 22 ) 23 self.fc3 = nn.Linear(84, 10) 24 25 # 前向传播 26 def forward(self, x): 27 x = self.conv1(x) 28 x = self.conv2(x) 29 # 线性层的输入输出都是一维数据,所以要把多维度的tensor展平成一维 30 x = x.view(x.size()[0], -1) 31 x = self.fc1(x) 32 x = self.fc2(x) 33 x = self.fc3(x) 34 return x
上面的类定义了一个3层的网络结构,根据问题类型,最后一层是确定的
1.2 开始训练:
import torch import torchvision as tv import torchvision.transforms as transforms import torch.nn as nn import torch.optim as optim import os import copy import time from digit_recog import Net from digit_recog_mydataset import MyDataset # 读取已保存的模型 def getmodel(pth, net): state_filepath = pth if os.path.exists(state_filepath): # 加载参数 nn_state = torch.load(state_filepath) # 加载模型 net.load_state_dict(nn_state) # 拷贝一份 return copy.deepcopy(nn_state) else: return net.state_dict() # 构建数据集 def getdataset(batch_size): # 定义数据预处理方式 transform = transforms.ToTensor() # 定义训练数据集 trainset = tv.datasets.MNIST( root=‘./data/‘, train=True, download=True, transform=transform) # 去掉注释,加入自己的数据集 # trainset += MyDataset(os.path.abspath("./data/myimages/"), ‘train.txt‘, transform=transform) # 定义训练批处理数据 trainloader = torch.utils.data.DataLoader( trainset, batch_size=batch_size, shuffle=True, ) # 定义测试数据集 testset = tv.datasets.MNIST( root=‘./data/‘, train=False, download=True, transform=transform) # 去掉注释,加入自己的数据集 # testset += MyDataset(os.path.abspath("./data/myimages/"), ‘test.txt‘, transform=transform) # 定义测试批处理数据 testloader = torch.utils.data.DataLoader( testset, batch_size=batch_size, shuffle=False, ) return trainloader, testloader # 训练 def training(device, net, model, dataset_loader, epochs, criterion, optimizer, save_model_path): trainloader, testloader = dataset_loader # 最佳模型 best_model_wts = model # 最好分数 best_acc = 0.0 # 计时 since = time.time() for epoch in range(epochs): sum_loss = 0.0 # 训练数据集 for i, data in enumerate(trainloader): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) # 梯度清零,避免带入下一轮累加 optimizer.zero_grad() # 神经网络运算 outputs = net(inputs) # 损失值 loss = criterion(outputs, labels) # 损失值反向传播 loss.backward() # 执行优化 optimizer.step() # 损失值汇总 sum_loss += loss.item() # 每训练完100条数据就显示一下损失值 if i % 100 == 99: print(‘[%d, %d] loss: %.03f‘ % (epoch + 1, i + 1, sum_loss / 100)) sum_loss = 0.0 # 每训练完一轮测试一下准确率 with torch.no_grad(): correct = 0 total = 0 for data in testloader: images, labels = data images, labels = images.to(device), labels.to(device) outputs = net(images) # 取得分最高的 _, predicted = torch.max(outputs.data, 1) # print(labels) # print(torch.nn.Softmax(dim=1)(outputs.data).detach().numpy()[0]) # print(torch.nn.functional.normalize(outputs.data).detach().numpy()[0]) total += labels.size(0) correct += (predicted == labels).sum() print(‘测试结果:{}/{}‘.format(correct, total)) epoch_acc = correct.double() / total print(‘当前分数:{} 最高分数:{}‘.format(epoch_acc, best_acc)) if epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(net.state_dict()) print(‘第%d轮的识别准确率为:%d%%‘ % (epoch + 1, (100 * correct / total))) time_elapsed = time.time() - since print(‘训练完成于 {:.0f}m {:.0f}s‘.format( time_elapsed // 60, time_elapsed % 60)) print(‘最高分数: {:4f}‘.format(best_acc)) # 保存训练模型 if save_model_path is not None: save_state_path = os.path.join(‘model/‘, ‘net.pth‘) torch.save(best_model_wts, save_state_path) # 基于cpu还是gpu DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") NET = Net().to(DEVICE) # 超参数设置 EPOCHS = 8# 训练多少轮 BATCH_SIZE = 64 # 数据集批处理数量 64 LR = 0.001 # 学习率 # 交叉熵损失函数,通常用于多分类问题上 CRITERION = nn.CrossEntropyLoss() # 优化器 # OPTIMIZER = optim.SGD(net.parameters(), lr=LR, momentum=0.9) OPTIMIZER = optim.Adam(NET.parameters(), lr=LR) MODEL = getmodel(os.path.join(‘model/‘, ‘net.pth‘), NET) training(DEVICE, NET, MODEL, getdataset(BATCH_SIZE), 1, CRITERION, OPTIMIZER, os.path.join(‘model/‘, ‘net.pth‘))
利用标准的mnist数据集跑出来的识别率能达到99%
2、参与进来
目的是为了识别自己的图片,增加参与感
2.1 打开windows附件中的画图工具,用鼠标画几个数字,然后用截图工具保存下来
2.2 实现自己的数据集:
digit_recog_mydataset.py
from PIL import Image import torch import os # 实现自己的数据集 class MyDataset(torch.utils.data.Dataset): def __init__(self, root, datafile, transform=None, target_transform=None): super(MyDataset, self).__init__() fh = open(os.path.join(root, datafile), ‘r‘) datas = [] for line in fh: # 删除本行末尾的字符 line = line.rstrip() # 通过指定分隔符对字符串进行拆分,默认为所有的空字符,包括空格、换行、制表符等 words = line.split() # words[0]是图片信息,words[1]是标签 datas.append((words[0], int(words[1]))) self.datas = datas self.transform = transform self.target_transform = target_transform self.root = root # 必须实现的方法,用于按照索引读取每个元素的具体内容 def __getitem__(self, index): # 获取图片及标签,即上面每行中word[0]和word[1]的信息 img, label = self.datas[index] # 打开图片,重设尺寸,转换为灰度图 img = Image.open(os.path.join(self.root, img)).resize((28, 28)).convert(‘L‘) # 数据预处理 if self.transform is not None: img = self.transform(img) return img, label # 必须实现的方法,返回数据集的长度 def __len__(self): return len(self.datas)
2.3 在图片文件夹中新建两个文件,train.txt和test.txt,分别写上训练与测试集的数据,格式如下
训练与测试的数据要严格区分开,否则训练出来的模型会有问题
2.4 加入训练、测试数据集
反注释训练方法中的这两行
# trainset += MyDataset(os.path.abspath("./data/myimages/"), ‘train.txt‘, transform=transform) # testset += MyDataset(os.path.abspath("./data/myimages/"), ‘test.txt‘, transform=transform)
继续执行训练,这里我训练出来的最高识别率是98%
2.5 测试模型
# -*- coding: utf-8 -*- # encoding:utf-8 import torch import numpy as np from PIL import Image import os import matplotlib import matplotlib.pyplot as plt import glob from digit_recog import Net device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = Net().to(device) # 加载参数 nn_state = torch.load(os.path.join(‘model/‘, ‘net.pth‘)) # 参数加载到指定模型 net.load_state_dict(nn_state) # 指定默认字体 matplotlib.rcParams[‘font.sans-serif‘] = [‘SimHei‘] matplotlib.rcParams[‘font.family‘] = ‘sans-serif‘ # 解决负号‘-‘显示为方块的问题 matplotlib.rcParams[‘axes.unicode_minus‘] = False # 要识别的图片 file_list = glob.glob(os.path.join(‘data/test_image/‘, ‘*‘)) grid_rows = len(file_list) / 5 + 1 for i, file in enumerate(file_list): # 读取图片并重设尺寸 image = Image.open(file).resize((28, 28)) # 灰度图 gray_image = image.convert(‘L‘) # 图片数据处理 im_data = np.array(gray_image) im_data = torch.from_numpy(im_data).float() im_data = im_data.view(1, 1, 28, 28) # 神经网络运算 outputs = net(im_data) # 取最大预测值 _, pred = torch.max(outputs, 1) # print(torch.nn.Softmax(dim=1)(outputs).detach().numpy()[0]) # print(torch.nn.functional.normalize(outputs).detach().numpy()[0]) # 显示图片 plt.subplot(grid_rows, 5, i + 1) plt.imshow(gray_image) plt.title(u"你是{}?".format(pred.item()), fontsize=8) plt.axis(‘off‘) print(‘[{}]预测数字为: [{}]‘.format(file, pred.item())) plt.show()
可视化结果
这批图片是经过图片增强后识别的结果,准确率有待提高
3、优化
3.1 更多样本:
收集难度大
3.2 数据增强:
简单地处理一下自己手写的数字图片
# -*- coding: utf-8 -*- # encoding:utf-8 import torch import numpy as np from PIL import Image import os import matplotlib import matplotlib.pyplot as plt import glob from scipy.ndimage import filters class ImageProcceed: def __init__(self, image_folder): self.image_folder = image_folder def save(self, rotate, filter=None, to_gray=True): file_list = glob.glob(os.path.join(self.image_folder, ‘*.png‘)) print(len(file_list)) for i, file in enumerate(file_list): # 读取图片数据 image = Image.open(file) # .resize((28, 28)) # 灰度图 if to_gray == True: image = image.convert(‘L‘) # 旋转 image = image.rotate(rotate) if filter is not None: image = filters.gaussian_filter(image, 0.5) image = Image.fromarray(image) filename = os.path.basename(file) fileext = os.path.splitext(filename)[1] savefile = filename.replace(fileext, ‘-rt{}{}‘.format(rotate, fileext)) print(savefile) image.save(os.path.join(self.image_folder, savefile)) ip = ImageProcceed(‘data/myimages/‘) ip.save(20, filter=0.5)
3.3 改变网络大小:
比如把上面的Net类中的3层改为2层
3.4 调参:
改变学习率,训练更多次数等
标签:ice lob end ali numpy text div seq 设置
原文地址:https://www.cnblogs.com/migomiddle/p/11811356.html