
Pytorch-Load Predict

Posted: 2021-03-30 13:16:42


  • Load a trained model and use it to make predictions
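The script in this post loads two checkpoint files whose creation is not shown. Judging from how they are used, cnnmodel.pkl holds a whole pickled model (written with torch.save(model, path)) while cnnmodel.pt holds only the parameters (written with torch.save(model.state_dict(), path)). The sketch below is an assumption about how such files are produced, not the author's training code: it saves and reloads both kinds of file using a hypothetical nn.Sequential stand-in for CNNModel and a temporary directory instead of D:/Project_Encyclopedia. The file names and the load_full_model and make_model helpers are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_full_model(path):
    """Load a file written by torch.save(model, path), i.e. a whole pickled nn.Module."""
    try:
        # PyTorch 2.6+ defaults to weights_only=True, which refuses pickled modules.
        # Only do this for files you trust: unpickling can execute arbitrary code.
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Releases older than 1.13 do not accept the weights_only argument.
        return torch.load(path, map_location="cpu")


def make_model():
    # Hypothetical stand-in for CNNModel, built only from stock torch.nn layers.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


torch.manual_seed(0)
model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    full_path = os.path.join(tmp, "model_full.pkl")       # illustrative file names
    weights_path = os.path.join(tmp, "model_weights.pt")

    torch.save(model, full_path)                  # whole module object (class reference + weights)
    torch.save(model.state_dict(), weights_path)  # parameters only

    full = load_full_model(full_path)                      # an nn.Module, ready to call
    state = torch.load(weights_path, map_location="cpu")   # an OrderedDict of tensors

# A state_dict must be copied into a freshly built model of the same architecture.
rebuilt = make_model()
rebuilt.load_state_dict(state)

print(type(full).__name__, type(state).__name__)

x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.allclose(model(x), full(x)) and torch.allclose(model(x), rebuilt(x))
print(same)
```

The state_dict route is generally the more portable of the two: a whole-model pickle stores a reference to the class by module and name, so it can only be loaded where that class is importable under the same name.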

Code:

import os

# Common workaround for "OMP: Error #15" (a duplicate OpenMP runtime), which can occur
# when PyTorch and other MKL-linked packages share one environment, e.g. Anaconda on Windows.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision  # torch's computer-vision package
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

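# The class must be defined (same name, same layers) before either checkpoint is loaded.
class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()  # initialize the nn.Module parent class
        # 2-D convolution: 1 input channel (1x28x28 image), 20 output feature maps, 5x5 kernel -> 20x24x24
        self.conv1 = nn.Conv2d(1, 20, 5)
        # 20 input channels, 12 output channels, 5x5 kernel: 20x12x12 -> 12x8x8
        self.conv2 = nn.Conv2d(20, 12, 5)
        self.fc1 = nn.Linear(12 * 4 * 4, 100, bias=True)  # fully connected layer: 192 -> 100
        self.fc2 = nn.Linear(100, 10, bias=True)          # output layer: one logit per digit class

    def forward(self, x):
        x = self.conv1(x)  # convolution: 1x28x28 -> 20x24x24
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # 20x24x24 -> 20x12x12

        x = self.conv2(x)  # 20x12x12 -> 12x8x8
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # max over each 2x2 window: 12x8x8 -> 12x4x4

        x = x.reshape(-1, 12 * 4 * 4)  # flatten each sample into a 192-element vector
        x = self.fc1(x)
        x = F.relu(x)

        x = self.fc2(x)  # raw logits; no softmax is applied here
        return x
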
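# cnnmodel.pkl presumably holds the whole model, saved with torch.save(model, path).
# Unpickling it needs the CNNModel class defined above. On PyTorch 2.6+ torch.load
# defaults to weights_only=True and rejects whole models; there, also pass
# weights_only=False (only for files you trust).
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
cnnmodel = torch.load('D:/Project_Encyclopedia/cnnmodel.pkl', map_location='cpu')  # load the whole model
print(cnnmodel)
print(type(cnnmodel))  # a CNNModel instance

# cnnmodel.pt holds only the parameters (a state_dict), so torch.load returns an OrderedDict.
state_dict = torch.load('D:/Project_Encyclopedia/cnnmodel.pt', map_location='cpu')
print(state_dict)
print(type(state_dict))

# To use a state_dict, build a fresh CNNModel and copy the parameters into it.
cnnmodel1 = CNNModel()
cnnmodel1.load_state_dict(state_dict)
print(type(cnnmodel1))  # now a CNNModel instance as well
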
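root = 'D:/Project_Encyclopedia'  # forward slashes avoid backslash-escape problems in Windows paths
# train=False selects the MNIST test split; download=False expects the data to exist under root already.
mnist = torchvision.datasets.MNIST(root, train=False, transform=ToTensor(), target_transform=None, download=False)
print(len(mnist))  # the test set holds 10,000 images

bs = 8
# shuffle=True draws a random batch; pin_memory only speeds up later copies to a GPU.
mnist_loader = DataLoader(dataset=mnist, batch_size=bs, shuffle=True, pin_memory=True)

batch = next(iter(mnist_loader))  # fetch one batch
image, labels = batch             # image: (8, 1, 28, 28), labels: (8,)

cnnmodel1.eval()                  # inference mode (no dropout/batch norm here, but a good habit)
with torch.no_grad():             # gradients are not needed for prediction
    out = cnnmodel1(image)        # logits, shape (8, 10)
print(out.shape)
predicts = out.argmax(dim=1)      # index of the largest logit = predicted digit

print("labels:", labels)
print("predicts:", predicts)

grid = torchvision.utils.make_grid(image, nrow=8)  # tile the batch into one (C, H, W) image grid
plt.figure(figsize=(15, 15))
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # matplotlib expects (H, W, C)
plt.show()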
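The layer sizes in CNNModel follow from the usual output-size formula for convolution and pooling, out = floor((in + 2*padding - kernel) / stride) + 1. For a 1x28x28 MNIST digit this gives 28 -> 24 -> 12 -> 8 -> 4, which is why fc1 expects 12*4*4 = 192 inputs. A small check of that arithmetic, assuming padding 0 and dilation 1 as in the post; conv_out is a hypothetical helper name:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d / MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Trace a 28x28 input through the layer settings used in CNNModel.
s = 28
s = conv_out(s, kernel=5)            # conv1: 28 -> 24
s = conv_out(s, kernel=2, stride=2)  # max_pool2d: 24 -> 12
s = conv_out(s, kernel=5)            # conv2: 12 -> 8
s = conv_out(s, kernel=2, stride=2)  # max_pool2d: 8 -> 4
flat_features = 12 * s * s           # conv2 produces 12 channels
print(s, flat_features)              # 4 192

# Cross-check with real layers that use the same hyperparameters.
conv1 = nn.Conv2d(1, 20, 5)
conv2 = nn.Conv2d(20, 12, 5)
x = torch.zeros(8, 1, 28, 28)        # dummy batch shaped like the MNIST loader output
x = F.max_pool2d(F.relu(conv1(x)), kernel_size=2, stride=2)
x = F.max_pool2d(F.relu(conv2(x)), kernel_size=2, stride=2)
print(tuple(x.shape))                # (8, 12, 4, 4)
```

If the input size ever changes, x.flatten(1) in place of x.reshape(-1, 12*4*4) keeps the batch dimension fixed, so a mismatch surfaces as a shape error in fc1 instead of a silently altered batch size.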

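The prediction step relies on argmax over dim=1 picking, for each image, the class with the largest logit. A slightly fuller sketch that also puts the model in eval mode, disables gradient tracking, and reports the softmax confidence and the batch accuracy is shown here. predict_batch is a hypothetical helper, and the usage lines feed hand-made logits through nn.Identity so the result is predictable without downloading MNIST.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def predict_batch(model, images, labels=None):
    """Hypothetical helper: predicted classes, their softmax confidence, optional accuracy."""
    model.eval()                      # disable dropout/batch-norm training behaviour
    with torch.no_grad():             # no autograd graph is needed for inference
        logits = model(images)        # shape (batch, num_classes)
        probs = F.softmax(logits, dim=1)
        confidence, preds = probs.max(dim=1)  # same indices as logits.argmax(dim=1)
    accuracy = None
    if labels is not None:
        accuracy = (preds == labels).float().mean().item()
    return preds, confidence, accuracy


# Usage: nn.Identity returns its input, so each row below acts as a logit vector.
identity = nn.Identity()
fake_logits = torch.eye(10)[[3, 1, 4, 1]] * 5.0
preds, confidence, accuracy = predict_batch(identity, fake_logits, torch.tensor([3, 1, 4, 0]))
print(preds.tolist(), accuracy)  # [3, 1, 4, 1] 0.75
```

With the real script, the same call would be predict_batch(cnnmodel1, image, labels).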
Original post: https://www.cnblogs.com/jgg54335/p/14589866.html
