码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch 手写数字识别项目 增量式训练

时间:2020-02-15 23:15:44      阅读:108      评论:0      收藏:0      [点我收藏+]

标签:mode   ice   预测   get   float   user   module   cond   优化   

dataset.py

 

'''
准备数据集
'''
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor,Compose,Normalize
import torchvision
import config

def mnist_dataset(train, root="../mnist", download=False):
    """Return the MNIST train or test split with tensor conversion and normalization.

    Args:
        train: True for the training split, False for the test split.
        root: directory that holds the MNIST files. The default is resolved relative
            to the current working directory.
        download: passed through to torchvision's MNIST. With the default (False) the
            data must already exist under ``root``.
    """
    transform = Compose([
        ToTensor(),
        # Presumably the MNIST training-set pixel mean / std (single channel).
        Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    return MNIST(root=root, train=train, download=download, transform=transform)

def get_dataloader(train=True):
    """Build a DataLoader over the MNIST training split (default) or test split.

    The batch size comes from config.train_batch_size or config.test_batch_size.
    Only the training split is shuffled. Evaluation gains nothing from a random
    order, and a fixed order keeps evaluation runs reproducible.
    """
    mnist = mnist_dataset(train)
    batch_size = config.train_batch_size if train else config.test_batch_size
    return DataLoader(mnist, batch_size=batch_size, shuffle=bool(train))

if __name__ == '__main__':
    # Smoke test: fetch one training batch and print its image size and labels.
    for (images, labels) in get_dataloader():
        print(images.size())
        print(labels)
        break

 

  models.py

 

'''定义模型'''

import torch.nn as nn
import torch.nn.functional as F

class MnistModel(nn.Module):
    """Two-layer fully connected classifier for 28x28 digit images.

    Layers: flatten -> Linear(784, 100) -> ReLU -> Linear(100, 10) -> log_softmax.
    The output is log-probabilities, which is the input F.nll_loss expects.
    """

    def __init__(self):
        super(MnistModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 100)  # flattened pixels -> hidden units
        self.fc2 = nn.Linear(100, 10)       # hidden units -> one score per class

    def forward(self, image):
        """Return log-probabilities with shape (N, 10).

        Args:
            image: tensor whose element count is a multiple of 28*28. Each run of 784
                consecutive elements is treated as one image. Presumably a
                (N, 1, 28, 28) batch from the dataloader.
        """
        # reshape (not view) so non-contiguous inputs, such as transposed or sliced
        # tensors, are accepted. A copy is made only when one is required.
        image_flat = image.reshape(-1, 28 * 28)
        hidden = F.relu(self.fc1(image_flat))
        out = self.fc2(hidden)
        # log_softmax pairs with F.nll_loss to give cross-entropy.
        return F.log_softmax(out, dim=-1)

 

  config.py

 

'''
项目配置
'''
import torch

# Mini-batch sizes read by dataset.get_dataloader, one per split.
train_batch_size = 128  # training split
test_batch_size = 128   # test split

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

 

  train.py

 

'''
进行模型的训练
'''
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
from tqdm import tqdm
import numpy as np
import torch
import os
from eval import eval

# Instantiate the model and optimizer (the loss, F.nll_loss, is applied inside train()).
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Incremental training: resume from the last checkpoint when one is present.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
if os.path.exists("./model/mnist_net.pt"):
    model.load_state_dict(torch.load("./model/mnist_net.pt", map_location=config.device))
    # The optimizer state is optional. Without it, Adam restarts its moment
    # estimates, but the trained weights are still reused.
    if os.path.exists("./model/mnist_optimizer.pt"):
        optimizer.load_state_dict(
            torch.load("./model/mnist_optimizer.pt", map_location=config.device)
        )


#迭代训练

def _save_checkpoint():
    """Write the module-level model and optimizer state dicts under ./model/."""
    # torch.save does not create missing directories, so without this a fresh
    # checkout fails on the very first save.
    os.makedirs("model", exist_ok=True)
    torch.save(model.state_dict(), "model/mnist_net.pt")
    torch.save(optimizer.state_dict(), "model/mnist_optimizer.pt")


def train(epoch):
    """Run one epoch of training on the MNIST training split.

    Updates the module-level ``model`` with the module-level ``optimizer``.
    The tqdm bar shows the running mean loss. A checkpoint is written every
    10 batches and once more after the last batch, so eval() (which reloads
    the checkpoint from disk) sees the final weights of the epoch.

    Args:
        epoch: epoch index, used only in the progress-bar description.
    """
    model.train()
    train_dataloader = get_dataloader(train=True)
    bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    total_loss = []
    for idx, (images, target) in bar:
        images = images.to(config.device)
        target = target.to(config.device)
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Forward pass: the model returns log-probabilities.
        output = model(images)
        loss = F.nll_loss(output, target)
        total_loss.append(loss.item())
        # Backpropagate, then apply the parameter update.
        loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            bar.set_description("epoch:{} idx:{},loss:{}".format(epoch, idx, np.mean(total_loss)))
            _save_checkpoint()

    # Also persist the updates made after the last multiple-of-10 batch.
    _save_checkpoint()

if __name__ == '__main__':
    # Train for 10 epochs, evaluating the checkpoint saved on disk after each one.
    for i in range(10):
        train(i)
        eval()

 

  eval.py

 

'''
进行模型的评估
'''
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import config
import numpy as np
import torch
import os




#模型评估

def eval():
    """Evaluate the saved checkpoint on the MNIST test split and print loss/accuracy.

    Loads ./model/mnist_net.pt into a fresh MnistModel when that file exists.
    Otherwise the freshly initialized weights are evaluated.

    Returns:
        (mean_loss, accuracy): NLL loss averaged per sample and the fraction of
        correctly classified samples over the whole test split. Both are NaN if
        the dataloader yields nothing.
    """
    model = MnistModel().to(config.device)
    if os.path.exists("./model/mnist_net.pt"):
        # map_location lets a GPU-saved checkpoint be evaluated on a CPU-only machine.
        model.load_state_dict(torch.load("./model/mnist_net.pt", map_location=config.device))
    model.eval()

    test_dataloader = get_dataloader(train=False)
    loss_sum = 0.0
    correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, target in test_dataloader:
            images = images.to(config.device)
            target = target.to(config.device)
            output = model(images)
            # Accumulate sums rather than per-batch means so a smaller final batch
            # is weighted by its actual size.
            loss_sum += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=-1)
            correct += pred.eq(target).sum().item()
            n_samples += target.size(0)

    mean_loss = loss_sum / n_samples if n_samples else float("nan")
    accuracy = correct / n_samples if n_samples else float("nan")
    print("test loss:{},test acc:{}".format(mean_loss, accuracy))
    return mean_loss, accuracy

if __name__ == '__main__':
    # Evaluate whichever checkpoint is currently saved under ./model/.
    eval()

 

  

D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
test loss:0.17968503131142147,test acc:0.9453125
epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
test loss:0.12370304338916947,test acc:0.9624208860759493
epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
test loss:0.10385569722592077,test acc:0.9697389240506329
epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
test loss:0.08691684670652015,test acc:0.9754746835443038
epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
test loss:0.0803159438309413,test acc:0.9760680379746836
epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
test loss:0.08102699166423158,test acc:0.9759691455696202
epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
test loss:0.07991968260347089,test acc:0.9769580696202531
epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
test loss:0.07767781678917288,test acc:0.9774525316455697
epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
test loss:0.07755146227494071,test acc:0.9773536392405063
epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
test loss:0.07112929566845863,test acc:0.9802215189873418

  

 

pytorch 手写数字识别项目 增量式训练

标签:mode   ice   预测   get   float   user   module   cond   优化   

原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12314982.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!