码迷,mamicode.com
首页 > 其他好文 > 详细

MNIST 数据集简单识别程序

时间:2020-10-21 20:38:41      阅读:15      评论:0      收藏:0      [点我收藏+]

标签:utils   数据   模型   测试数据   10个   als   sha   清零   test   

MNIST 数据集简单识别程序


"""
# @Time    :  2020/10/20
# @Author  :  Jimou Chen
"""
import torch
from torch import nn, optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 结构中只有输入和输出层
        self.fc1 = nn.Linear(784, 10)
        # 给一个激活函数,dim=1是第一个维度,即输出第一个维度的概率
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # 全连接层把(64, 1, 28, 28)转换为二维(64, 784),view相当于reshape,784=1*28*28
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.softmax(x)
        return x


if __name__ == ‘__main__‘:
    # 训练集
    train_data = datasets.MNIST(root=‘./‘,
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)
    # 测试集
    test_data = datasets.MNIST(root=‘./‘,
                               train=False,
                               transform=transforms.ToTensor(),
                               download=True)

    # 批次大小,即一次加载多少数据
    batch_size = 64
    # 装载训练集,shuffle=True将数据打乱
    train_load = DataLoader(dataset=train_data,
                            batch_size=batch_size,
                            shuffle=True)
    # 装载测试集,将数据打乱
    test_load = DataLoader(dataset=test_data,
                           batch_size=batch_size,
                           shuffle=True)

    # for i, data in enumerate(train_load):
    #     inputs, labels = data
    #     print(inputs.shape)
    #     print(labels.shape)
    #     print(labels)
    #     break

    # 定义模型,损失函数,优化器
    model = Net()
    mse_loss = nn.MSELoss()
    opt = optim.SGD(model.parameters(), lr=0.5)


    def train():
        for i, data in enumerate(train_load):
            # 每一次迭代都返回一组输入数据和标签
            input_data, labels = data
            # 获得模型的结果
            out = model(input_data)
            # (64)——>(64, 1)
            labels = labels.reshape(-1, 1)
            # 转换为独热编码
            one_hot = torch.zeros(input_data.shape[0], 10).scatter(1, labels, 1)
            # 计算loss,out, one_hot的shape要一致
            loss = mse_loss(out, one_hot)
            # 梯度清零
            opt.zero_grad()
            # 计算梯度
            loss.backward()
            # 修改权值
            opt.step()


    # 定义一个测试数据的函数
    def test():
        correct = 0
        for i, data in enumerate(test_load):
            # 每一次迭代都返回一组输入数据和标签
            input_data, labels = data
            # 获得模型的结果
            out = model(input_data)
            # 获得第一个维度的最大值,以及最大值所在的位置
            max_value, pred_index = torch.max(out, 1)
            # 用这64个预测数据与标签做一个对比,统计预测正确的数量
            correct += (pred_index == labels).sum()

        print(‘准确率:{0}‘.format(correct.item() / len(test_data)))


    # 训练和测试10个周期
    for i in range(10):
        print(i, ‘:‘, end=‘‘)
        train()
        test()

  • 结果:
0 :准确率:0.8881
1 :准确率:0.9025
2 :准确率:0.9067
3 :准确率:0.9104
4 :准确率:0.9147
5 :准确率:0.9159
6 :准确率:0.9165
7 :准确率:0.9184
8 :准确率:0.9187
9 :准确率:0.9199

Process finished with exit code 0

MNIST 数据集简单识别程序

标签:utils   数据   模型   测试数据   10个   als   sha   清零   test   

原文地址:https://www.cnblogs.com/jmchen/p/pytorch.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!