# Tags: shuffle mic tran pytorch pool UNC loading super ORC
# Code:
import torch
import numpy as np
import torchvision #torch的视觉包
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import PIL.Image as Image
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
class CNNModel(nn.Module):
    """Small two-conv-layer CNN classifier for 1x28x28 images (used on MNIST below).

    Layout:
        conv(1->20, 5x5) -> ReLU -> maxpool(2)   # 1x28x28  -> 20x24x24 -> 20x12x12
        conv(20->12, 5x5) -> ReLU -> maxpool(2)  # 20x12x12 -> 12x8x8   -> 12x4x4
        fc(192->100) -> ReLU -> fc(100->10)

    The submodule attribute names (conv1, conv2, fc1, fc2) are the keys of the
    saved state_dict, so they must not be renamed or checkpoints stop loading.
    """

    def __init__(self):
        super().__init__()
        # 2-D convolution: 1 input channel (1x28x28), 20 output feature maps, 5x5 kernel.
        self.conv1 = nn.Conv2d(1, 20, 5)
        # 20 input channels -> 12 output feature maps; 12x12 input gives 12x8x8.
        self.conv2 = nn.Conv2d(20, 12, 5)
        # Fully connected layers over the flattened 12x4x4 = 192 features.
        self.fc1 = nn.Linear(12 * 4 * 4, 100, bias=True)
        self.fc2 = nn.Linear(100, 10, bias=True)

    def forward(self, x):
        """Return 10 unnormalized class scores (logits) per sample.

        Args:
            x: image tensor; the layer sizes require 1x28x28 images,
               e.g. a batch of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10); no softmax is applied.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # take the max: 20x24x24 -> 20x12x12
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # take the max: 12x8x8 -> 12x4x4
        # Flatten each sample's 12x4x4 feature map into a 192-long vector.
        x = x.reshape(-1, 12 * 4 * 4)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# --- Load the saved models ---------------------------------------------------
# cnnmodel.pkl is presumably a whole pickled model (its type is printed below
# for comparison with the .pt state_dict). Unpickling a full module needs the
# CNNModel class defined above and weights_only=False, because PyTorch >= 2.6
# defaults to weights_only=True (the keyword itself needs PyTorch >= 1.13).
# Unpickling can execute arbitrary code, so only load trusted files here.
# map_location="cpu" lets checkpoints saved on a GPU load on a CPU-only machine.
cnnmodel = torch.load("D:/Project_Encyclopedia/cnnmodel.pkl", map_location="cpu", weights_only=False)
print(cnnmodel)
cnnmodel1 = torch.load("D:/Project_Encyclopedia/cnnmodel.pt", map_location="cpu")
print(cnnmodel1)
print(type(cnnmodel))
print(type(cnnmodel1))

# Rebuild the network and fill it with the parameters stored in cnnmodel.pt.
cnnmodel1 = CNNModel()
cnnmodel1.load_state_dict(torch.load("D:/Project_Encyclopedia/cnnmodel.pt", map_location="cpu"))
cnnmodel1.eval()  # inference mode; harmless now, correct if dropout/batch-norm are ever added
print(type(cnnmodel1))

# --- Predict one batch of the MNIST test split --------------------------------
root = r"D:\Project_Encyclopedia"  # raw string: "\P" is not a valid escape sequence
mnist = torchvision.datasets.MNIST(root, train=False, transform=ToTensor(), target_transform=None, download=False)
bs = 8
# Pinned memory only helps host-to-GPU copies; skip it (and its warning) without CUDA.
mnist_loader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=bs, shuffle=True, pin_memory=torch.cuda.is_available()
)
print("test set size:", len(mnist))  # this is the test split (train=False)

batch = next(iter(mnist_loader))
image, labels = batch
with torch.no_grad():  # inference only: don't build an autograd graph
    out = cnnmodel1(image)
print("output shape:", out.shape)

grid = torchvision.utils.make_grid(image, nrow=8)  # tile the batch into one image
plt.figure(figsize=(15, 15))
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0).numpy())
print("labels:", labels)
print('predicts:', out.argmax(dim=1))
plt.show()  # a plain script never displays the figure without this

# Tags: shuffle mic tran pytorch pool UNC loading super ORC
# Original article: https://www.cnblogs.com/jgg54335/p/14589866.html