码迷,mamicode.com
首页 > 其他好文 > 详细

gluon 实现多层感知机MLP分类FashionMNIST

时间:2018-11-28 19:05:14      阅读:260      评论:0      收藏:0      [点我收藏+]

标签:win   start   读取   ==   info   eval   backward   size   print   

from mxnet import gluon,init
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
from mxnet import nd,autograd
import gluonbook as gb

import sys

# Load the Fashion-MNIST training and test sets.
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

batch_size = 256
# Per-sample transform applied to the image (first element) of each example;
# converts images into tensors the network can consume.
transformer = gdata.vision.transforms.ToTensor()
# No worker processes on Windows, 4 elsewhere. Presumably because
# multi-process data loading is unreliable on Windows (TODO confirm).
num_workers = 0 if sys.platform.startswith('win') else 4

# Mini-batch data iterators; only the training set is shuffled.
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
                              batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),
                             batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

# Network: one hidden Dense layer of 256 ReLU units followed by a 10-unit
# output layer (one score per class). Weights start from N(0, 0.01^2).
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'), nn.Dense(10))
net.initialize(init.Normal(sigma=0.01))

# Loss function and optimizer: softmax cross-entropy with plain SGD, lr = 0.5.
loss = gloss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})


def accuracy(y_hat, y):
    """Return the fraction of samples whose predicted class equals the label.

    Args:
        y_hat: class scores with one row per sample; the predicted class is
            the argmax over axis 1.
        y: true labels, one per sample.

    Returns:
        The accuracy as a scalar in [0, 1].
    """
    # Labels are cast to float32 before comparing, presumably because
    # MXNet's argmax returns float indices (TODO confirm for the MXNet
    # version in use).
    return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()

def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over all samples in ``data_iter``.

    Correct predictions are counted sample by sample and divided by the total
    number of samples. A smaller final batch therefore carries its true weight
    instead of counting as much as a full batch.

    Args:
        data_iter: iterable of ``(X, y)`` mini-batches.
        net: callable mapping ``X`` to class scores, one row per sample
            (the predicted class is the argmax over axis 1).

    Returns:
        Accuracy in [0, 1].

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    n_correct = 0.0
    n_total = 0
    for X, y in data_iter:
        # Labels are cast to float32 to match argmax's output dtype
        # (assumed MXNet behaviour -- TODO confirm).
        n_correct += (net(X).argmax(axis=1) == y.astype('float32')).sum().asscalar()
        n_total += y.size
    if n_total == 0:
        raise ValueError('evaluate_accuracy: data_iter yielded no samples')
    return n_correct / n_total

# Number of full passes over the training data.
num_epochs = 5

def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
          params=None, lr=None, trainer=None):
    """Train ``net`` with mini-batch SGD and print progress after each epoch.

    Args:
        net: the model; called on each input batch ``X``.
        train_iter: iterable of ``(X, y)`` training mini-batches.
        test_iter: iterable of ``(X, y)`` batches used to report test accuracy.
        loss: callable ``loss(y_hat, y)``; presumably returns one value per
            sample, as Gluon losses do.
        num_epochs: number of passes over ``train_iter``.
        batch_size: nominal batch size. Kept for interface compatibility; the
            update is normalised by the actual number of samples in each
            batch, so a short final batch is not under-weighted.
        params: parameters updated by ``gb.sgd`` when ``trainer`` is None.
        lr: learning rate passed to ``gb.sgd`` when ``trainer`` is None.
        trainer: optional Gluon ``Trainer``; when given it performs the step.
    """
    for epoch in range(num_epochs):
        train_l_sum = 0.0
        train_acc_sum = 0.0
        n = 0
        for X, y in train_iter:
            with autograd.record():
                y_hat = net(X)
                l = loss(y_hat, y)
            l.backward()

            # Normalise the gradient by the real size of this batch; the last
            # batch of an epoch may hold fewer than `batch_size` samples.
            step_size = y.size
            if trainer is None:
                gb.sgd(params, lr, step_size)
            else:
                trainer.step(step_size)

            # Accumulate per-sample totals so the epoch averages are exact.
            train_l_sum += l.sum().asscalar()
            train_acc_sum += (y_hat.argmax(axis=1) == y.astype('float32')).sum().asscalar()
            n += step_size

        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

# Run training with the Gluon trainer. `params` and `lr` are only used by the
# manual gb.sgd path, so they are left at their default of None.
train(net, train_iter, test_iter, loss, num_epochs, batch_size, trainer=trainer)

技术分享图片

gluon 实现多层感知机MLP分类FashionMNIST

标签:win   start   读取   ==   info   eval   backward   size   print   

原文地址:https://www.cnblogs.com/TreeDream/p/10033557.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!