标签:mode ice 预测 get float user module cond 优化
dataset.py
"""Prepare the MNIST dataset and data loaders."""
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Normalize
import torchvision

import config

# Per-channel normalization constants applied to every image after ToTensor.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)


def mnist_dataset(train, download=False):
    """Return the MNIST train or test split with tensor + normalization transforms.

    Args:
        train: True for the training split, False for the test split.
        download: forwarded to ``torchvision.datasets.MNIST``. Defaults to
            False (the original behaviour: the data must already exist under
            ``../mnist``).
    """
    transform = Compose([
        ToTensor(),
        Normalize(mean=MNIST_MEAN, std=MNIST_STD),
    ])
    # The dataset root is resolved relative to the current working directory.
    return MNIST(root="../mnist", train=train, download=download, transform=transform)


def get_dataloader(train=True):
    """Return a DataLoader over the requested MNIST split.

    The batch size comes from ``config``. Only the training split is shuffled.
    Shuffling the test split gains nothing, and it makes evaluation
    non-deterministic because batch composition changes between runs.
    """
    dataset = mnist_dataset(train)
    batch_size = config.train_batch_size if train else config.test_batch_size
    return DataLoader(dataset, batch_size=batch_size, shuffle=train)


if __name__ == "__main__":
    for images, labels in get_dataloader():
        print(images.size())
        print(labels)
        break
model.py
"""Define the model."""
import torch.nn as nn
import torch.nn.functional as F


class MnistModel(nn.Module):
    """Two-layer fully connected classifier for 28x28 images.

    Returns log-probabilities over 10 classes, the input format expected by
    ``F.nll_loss``.
    """

    def __init__(self, hidden_size=100):
        """Build the two linear layers.

        Args:
            hidden_size: width of the hidden layer. The default of 100 matches
                the original architecture, so existing checkpoints still load.
        """
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, image):
        """Return log-softmax class scores of shape (N, 10).

        ``image`` is flattened into rows of 28*28 values, so any tensor whose
        element count is a multiple of 784 is accepted, e.g. (N, 1, 28, 28).
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed ones, also work.
        image_flat = image.reshape(-1, 28 * 28)
        hidden = F.relu(self.fc1(image_flat))
        out = self.fc2(hidden)
        return F.log_softmax(out, dim=-1)
config.py
"""Project configuration."""
import torch

# Number of samples per batch for the training and test data loaders.
train_batch_size = 128
test_batch_size = 128

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train.py
"""Train the model."""
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm

import config
from dataset import get_dataloader
from eval import eval
from model import MnistModel

MODEL_DIR = "./model"
MODEL_PATH = "./model/mnist_net.pt"
OPTIMIZER_PATH = "./model/mnist_optimizer.pt"

# Instantiate the model and optimizer; the loss (F.nll_loss) is applied in train().
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Resume from previous checkpoints. Each file is checked separately so a
# missing optimizer file no longer crashes startup. map_location lets a
# checkpoint saved on a GPU load on a CPU-only machine.
if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=config.device))
if os.path.exists(OPTIMIZER_PATH):
    optimizer.load_state_dict(torch.load(OPTIMIZER_PATH, map_location=config.device))


def _save_checkpoint():
    """Write model and optimizer state to MODEL_DIR, creating the directory if needed."""
    os.makedirs(MODEL_DIR, exist_ok=True)
    torch.save(model.state_dict(), MODEL_PATH)
    torch.save(optimizer.state_dict(), OPTIMIZER_PATH)


def train(epoch):
    """Run one training epoch over the MNIST training set.

    Every 10 batches, the running mean loss is shown in the progress bar and a
    checkpoint is written. One more checkpoint is written after the last
    batch, so the file that eval() reloads reflects the whole epoch rather
    than stopping at the last multiple of 10.
    """
    model.train()
    train_dataloader = get_dataloader(train=True)
    bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    total_loss = []
    for idx, (images, target) in bar:
        images = images.to(config.device)
        target = target.to(config.device)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        output = model(images)  # log-probabilities from the model
        loss = F.nll_loss(output, target)
        total_loss.append(loss.item())
        loss.backward()  # backpropagate to compute gradients
        optimizer.step()  # update the parameters
        if idx % 10 == 0:
            bar.set_description("epoch:{} idx:{},loss:{}".format(epoch, idx, np.mean(total_loss)))
            _save_checkpoint()
    # Persist the final batches of the epoch before evaluation reloads the file.
    _save_checkpoint()


if __name__ == "__main__":
    for i in range(10):
        train(i)
        eval()
eval.py
"""Evaluate the trained model on the MNIST test set."""
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim

import config
from dataset import get_dataloader
from model import MnistModel

MODEL_PATH = "./model/mnist_net.pt"


def eval():
    """Evaluate the saved model on the test set and print the loss and accuracy.

    Weights are loaded from MODEL_PATH when it exists; otherwise the freshly
    initialised model is evaluated, as before. Loss and accuracy are averaged
    per sample, not per batch, so a smaller final batch is not over-weighted.
    No optimizer is needed for evaluation, so none is created or loaded.

    Returns:
        (mean_loss, accuracy) as floats; both are NaN if the test set is empty.
    """
    model = MnistModel().to(config.device)
    if os.path.exists(MODEL_PATH):
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=config.device))
    model.eval()

    test_dataloader = get_dataloader(train=False)
    loss_sum = 0.0
    correct = 0
    count = 0
    with torch.no_grad():
        for images, target in test_dataloader:
            images = images.to(config.device)
            target = target.to(config.device)
            output = model(images)  # log-probabilities from the model
            # Sum rather than average per batch, so the final division weights
            # every sample equally.
            loss_sum += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=-1)
            correct += pred.eq(target).sum().item()
            count += target.size(0)

    mean_loss = loss_sum / count if count else float("nan")
    accuracy = correct / count if count else float("nan")
    print("test loss:{},test acc:{}".format(mean_loss, accuracy))
    return mean_loss, accuracy


if __name__ == "__main__":
    eval()
D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
test loss:0.17968503131142147,test acc:0.9453125
epoch:1 idx:460,loss:0.15012750004513145: 100%|█████████▉| 468/469 [00:20<00:00, 22.10it/s]epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
test loss:0.12370304338916947,test acc:0.9624208860759493
epoch:2 idx:460,loss:0.10398845713577534: 99%|█████████▉| 464/469 [00:21<00:00, 22.78it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|█████████▉| 467/469 [00:21<00:00, 22.71it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
test loss:0.10385569722592077,test acc:0.9697389240506329
epoch:3 idx:460,loss:0.07973297938720653: 100%|█████████▉| 467/469 [00:22<00:00, 23.12it/s]epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
test loss:0.08691684670652015,test acc:0.9754746835443038
epoch:4 idx:460,loss:0.0650228117158285: 100%|█████████▉| 468/469 [00:21<00:00, 24.06it/s]epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
test loss:0.0803159438309413,test acc:0.9760680379746836
epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
test loss:0.08102699166423158,test acc:0.9759691455696202
epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
test loss:0.07991968260347089,test acc:0.9769580696202531
epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
test loss:0.07767781678917288,test acc:0.9774525316455697
epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
test loss:0.07755146227494071,test acc:0.9773536392405063
epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
test loss:0.07112929566845863,test acc:0.9802215189873418
标签:mode ice 预测 get float user module cond 优化
原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12314982.html