标签:框架 nump __init__ return ini 返回结果 module 黑白 range
如上篇文章所讲,将我们需用的环境搭建完成以后,我们就可以开始AI之路了,下面就让我们来看看第一个网络框架结构——全连接吧。
import torch.nn as nn
#导入所需库
class Net(nn.Module):
#初始化网络结构(设计神经网络)
def __init__(self):
super().__init__()
#设计一个多层结构的神经网络
self.layers = nn.Sequential(
nn.Linear(28*28,512),
#设计一层神经网络,有512个神经元,接受748个
nn.ReLU(),
nn.Linear(512,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,10),
nn.Softmax(dim=1)
)
# 前向计算(使用神经网络),将数据x输入到网络中,返回结果
def forward(self, x):
return self.layers(x)
***************************************************************************************************************************************
import torch
import torchvision
import torch.nn as nn
from PIL import Image
import torch.utils.data as data
from my_net import Net
import numpy as np
import os
save_path = "module/net_ps.pth"
train_data = torchvision.datasets.MNIST(
root="MNIST_data",#单通道28*28黑白图片(0-9数字)
train=True,
transform=torchvision.transforms.ToTensor(),
download=True
)
test_data = torchvision.datasets.MNIST(
root="MNIST_data",
train=False,
transform=torchvision.transforms.ToTensor(),
download=False
)
if __name__ == ‘__main__‘:
#创建数据加载器,每次从train_data里面取100张数据,打乱
train = data.DataLoader(dataset=train_data,batch_size=100,shuffle=True)#用数据加载器从train中每次加载100张图片并打乱
#实例化网络对象
net = Net()
#判断本地是否已经有网络的参数,如果有,那就加载之前的参数
if os.path.isfile(save_path):
net = torch.load(save_path)
#定义损失函数
loss_fun = nn.MSELoss()#对(h-y)^2求平均
#定义优化器,用这个优化器来优化网络内部的参数
optimizer = torch.optim.Adam(net.parameters())
#取数据,训练网络
for epoch in range(1000000):
for i,(x,y) in enumerate(train):#N C H W形状
#将图片变为100,784
x = x.reshape(-1,28*28)
#将图片输入到网络,得到结果
out = net(x)
#将标签y进行one-hot编码
target = torch.zeros(y.size()[0],10).scatter_(1,y.view(-1,1),1)
#将网络的结果和标签拿来做损失
loss = loss_fun(target,out)
#优化损失
optimizer.zero_grad()#清空梯度
loss.backward()#根据损失进行反向求导
optimizer.step()#更新梯度
#每训练10次,进行一次测试
if i%10 == 0:
out_put = torch.argmax(out,dim=1)
# print("target:",y)
# print("out:",out_put)
print("loss:",loss.item())
#计算准确度
acc = np.mean(np.array(out_put==y,dtype=np.float32))
print("精度:",acc)
#保存网络参数
torch.save(net,save_path)
标签:框架 nump __init__ return ini 返回结果 module 黑白 range
原文地址:https://www.cnblogs.com/wangyueyyy/p/11822340.html