码迷,mamicode.com
首页 > 其他好文 > 详细

pyTorch使用mnist数据集实现手写数字识别

时间:2020-06-14 19:02:50      阅读:76      评论:0      收藏:0      [点我收藏+]

标签:log   elf   orm   transform   加载   cal   class   函数   form   

使用mnist数据集实现手写数字识别是入门必做吧。这里使用pyTorch框架进行简单神经网络的搭建。

首先导入需要的包。

1 import torch
2 import torch.nn as nn
3 import torch.utils.data as Data
4 import torchvision

 

接下来需要下载mnist数据集。我们创建train_data。使用torchvision.datasets.MNIST进行数据集的下载。

1 train_data = torchvision.datasets.MNIST(
2     root=./mnist/,   #下载到该目录下
3     train=True,                                     #为训练数据
4     transform=torchvision.transforms.ToTensor(),    #将其装换为tensor的形式
5     download=True, #第一次设置为true表示下载,下载完成后,将其置成false
6 )

 之后将其导入data_loader中,这个数据加载类会自动帮我们进行数据集的切片。

 1 train_data = torchvision.datasets.MNIST(
 2     root=./mnist,
 3     train=True,
 4     transform=torchvision.transforms.ToTensor(),
 5     download=False
 6 )
 7 train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=0)
 8 test_data = torchvision.datasets.MNIST(
 9     root=./mnist,
10     train=False,
11     transform=torchvision.transforms.ToTensor(),
12 )
13 test_loader = Data.DataLoader(dataset=test_data, batch_size=32, shuffle=False, num_workers=0)
14 test_num = len(test_data)

之后开始定义我们的模型,由于minist数据集是灰度图像,并且图片的size都是(28, 28, 1),所以输入图片的时候不需要进行额外的修改。

 1 class Net(nn.Module):
 2     def __init__(self):
 3         super(Net, self).__init__()
 4         self.conv1 = nn.Sequential(#(1, 28, 28)
 5             nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),#(16, 28, 28)
 6             nn.ReLU(),#(16, 28, 28)
 7             nn.MaxPool2d(kernel_size=2)#(16, 14, 14)
 8         )
 9         self.conv2 = nn.Sequential(
10             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),#(32, 14, 14)
11             nn.ReLU(),#(32, 14, 14)
12             nn.MaxPool2d(kernel_size=2)#(32, 7, 7)
13         )
14         self.fc = nn.Linear(32 * 7 * 7, 10)
15     def forward(self, x):
16         x = self.conv1(x)
17         x = self.conv2(x)
18         x = x.view(x.size(0), -1)
19         x = self.fc(x)
20         return x

特别注意在最后传入全连接层时,最好自己将x的size改变以确保不会因为自适应而造成错误。因为在传入全连接层时会默认压缩成二维,例如[1, 2, 3, 4]会被压缩成[1*2, 3*4]。

之后开始训练。

 1 net = Net()
 2 loss_fn = nn.CrossEntropyLoss()
 3 optim = torch.optim.Adam(net.parameters(), lr = 0.001)
 4 
 5 save_path = ./mnist.pth
 6 best_acc = 0.0
 7 for epoch in range(3):
 8 
 9     net.train()
10     running_loss = 0.0
11     for step, data in enumerate(train_loader, start=0):
12         images, labels = data
13         optim.zero_grad()
14         logits = net(images)
15         loss = loss_fn(logits, labels)
16         loss.backward()
17         optim.step()
18 
19 
20         running_loss += loss.item()
21         rate = (step+1)/len(train_loader)
22         a = "*" * int(rate * 50)
23         b = "." * int((1 - rate) * 50)
24         print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
25     print()
26 
27     net.eval()
28     acc = 0.0
29     with torch.no_grad():
30         for data_test in test_loader:
31             test_images, test_labels = data_test
32             outputs = net(test_images)
33             predict_y = torch.max(outputs, dim=1)[1]#torch.max返回两个数值,一个是最大值,一个是最大值的下标
34             acc += (predict_y == test_labels).sum().item()
35         test_accurate = acc / test_num
36         if test_accurate > best_acc:
37             best_acc = test_accurate
38             torch.save(net.state_dict(), save_path)
39         print([epoch %d] train_loss: %.3f  test_accuracy: %.3f %
40               (epoch + 1, running_loss / step, test_accurate))
41 
42 print(Finished Training)

在完成训练后,训练的权重会保存在所设置路径下的文件中,进行预测的时候,建立模型,载入权重,照一张数字的图片,对其进行裁剪,灰度等操作之后加载入模型进行预测。

 1 from PIL import Image
 2 import  matplotlib.pyplot as plt
 3 from torchvision import transforms
 4 import torch
 5 from model import Net
 6 
 7 img = Image.open("./YLY2@}8UMGLW37S$)NCVZ23.png")
 8 
 9 plt.imshow(img)
10 
11 # [N, C, H, W]
12 
13 train_transform = transforms.Compose([
14         transforms.Grayscale(),
15         transforms.Resize((28, 28)),
16         transforms.ToTensor(),
17 ])
18 
19 img = train_transform(img)
20 # expand batch dimension
21 img = torch.unsqueeze(img, dim=0)
22 
23 # create model
24 model = Net()
25 # load model weights
26 model_weight_path = "./mnist.pth"
27 model.load_state_dict(torch.load(model_weight_path))
28 
29 index_to_class = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
30 
31 
32 model.eval()
33 with torch.no_grad():
34     # predict class
35     y = model(img)
36     #print(y.size())
37     output = torch.squeeze(y)
38     #print(output)
39     predict = torch.softmax(output, dim=0)
40     #print(predict)
41     predict_cla = torch.argmax(predict).numpy()
42     #print(predict_cla)
43 print(index_to_class[predict_cla], predict[predict_cla].numpy())
44 plt.show()

需要注意的是,载入模型的图片必须多一个维度batch,所以我们用img = torch.unsqueeze(img, dim=0)在图片的开头增加一个batch维度。

之后载入图片,得到输出,将输出的batch维度压缩掉,使用softmax函数得到概率分布,再用argmax函数得到最大值的下标,打印最大值所对应的类别及其概率。

pyTorch使用mnist数据集实现手写数字识别

标签:log   elf   orm   transform   加载   cal   class   函数   form   

原文地址:https://www.cnblogs.com/1-0001/p/12227295.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!