
Classifying FashionMNIST with AlexNet


from mxnet import gluon,init,nd,autograd
from mxnet.gluon import data as gdata,nn
from mxnet.gluon import loss as gloss
import mxnet as mx
import time
import os
import sys


# Build the network
net = nn.Sequential()
# Use a larger 11 x 11 window to capture objects, and a stride of 4 to greatly
# reduce the output height and width. The number of input/output channels is
# also much larger than in LeNet.
net.add(nn.Conv2D(96, kernel_size=11, strides=4, activation="relu"),
        nn.MaxPool2D(pool_size=3, strides=2),
        # Use a smaller convolution window; padding of 2 keeps the input and output
        # height/width the same, and the number of output channels is increased.
        nn.Conv2D(256, kernel_size=5, padding=2, activation="relu"),
        nn.MaxPool2D(pool_size=3, strides=2),
        # Three consecutive convolutional layers with an even smaller window.
        # Except for the last one, they further increase the number of output channels.
        # The first two are not followed by pooling layers, so height and width are kept.
        nn.Conv2D(384, kernel_size=3, padding=1, activation="relu"),
        nn.Conv2D(384, kernel_size=3, padding=1, activation="relu"),
        nn.Conv2D(256, kernel_size=3, padding=1, activation="relu"),
        nn.MaxPool2D(pool_size=3, strides=2),
        # The fully connected layers here are several times wider than in LeNet.
        # Dropout layers are used to mitigate overfitting.
        nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
        nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
        # Output layer. Since we use Fashion-MNIST, the number of classes is 10
        # rather than the 1000 in the original paper.
        nn.Dense(10))

X = nd.random.uniform(shape=(1,1,224,224))
net.initialize()
for layer in net:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)
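# For reference, with the (1, 1, 224, 224) input above the loop should print
# shapes along these lines (computed from the layer hyperparameters; the exact
# layer names depend on how Gluon auto-names the blocks):
#   conv  96,  11x11, stride 4      -> (1, 96, 54, 54)
#   pool  3x3, stride 2             -> (1, 96, 26, 26)
#   conv  256, 5x5, pad 2           -> (1, 256, 26, 26)
#   pool  3x3, stride 2             -> (1, 256, 12, 12)
#   conv  384 / 384 / 256, 3x3, pad 1 -> (1, 384, 12, 12), (1, 384, 12, 12), (1, 256, 12, 12)
#   pool  3x3, stride 2             -> (1, 256, 5, 5)
#   dense 4096 + dropout (twice)    -> (1, 4096)
#   dense 10                        -> (1, 10)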


# Load the data
# Fashion-MNIST images are 28*28; resize them to 224*224
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
        '~', '.mxnet', 'datasets', 'fashion-mnist')):
    root = os.path.expanduser(root)  # expand the user path '~'
    transformer = []
    if resize:
        transformer += [gdata.vision.transforms.Resize(resize)]
    transformer += [gdata.vision.transforms.ToTensor()]
    transformer = gdata.vision.transforms.Compose(transformer)
    mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
    mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
    num_workers = 0 if sys.platform.startswith('win32') else 4
    train_iter = gdata.DataLoader(
        mnist_train.transform_first(transformer), batch_size, shuffle=True,
        num_workers=num_workers)
    test_iter = gdata.DataLoader(
        mnist_test.transform_first(transformer), batch_size, shuffle=False,
        num_workers=num_workers)
    return train_iter, test_iter

batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)
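# A quick optional sanity check: pull one batch to confirm the resize to
# 224 x 224 and the batching took effect (shapes assume batch_size = 128).
for X, y in train_iter:
    print(X.shape, y.shape)  # expect (128, 1, 224, 224) and (128,)
    break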


def accuracy(y_hat,y):
    return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()

def evaluate_accuracy(data_iter,net,ctx):
    acc = nd.array([0],ctx=ctx)
    for X,y in data_iter:
        X = X.as_in_context(ctx)
        y = y.as_in_context(ctx)
        acc+=accuracy(net(X),y)
    return acc.asscalar() / len(data_iter)
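# A tiny worked example of accuracy() with illustrative values: argmax over
# axis 1 gives [1, 0]; compared with labels [1, 1], one of the two predictions
# matches, so the result is 0.5.
y_hat_demo = nd.array([[0.1, 0.9], [0.8, 0.2]])
y_demo = nd.array([1, 1])
print(accuracy(y_hat_demo, y_demo))  # 0.5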


# Train the model
def train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs):
    print('training on', ctx)
    loss = gloss.SoftmaxCrossEntropyLoss()

    for epoch in range(num_epochs):
        train_l_sum = 0
        train_acc_sum = 0
        start = time.time()
        for X,y in train_iter:
            X = X.as_in_context(ctx)
            y = y.as_in_context(ctx)

            with autograd.record():
                y_hat = net(X)
                l = loss(y_hat,y)

            l.backward()
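            # step(batch_size) divides the accumulated gradient by the batch size,
            # since SoftmaxCrossEntropyLoss returns one un-averaged loss per example.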
            trainer.step(batch_size)

            train_l_sum += l.mean().asscalar()
            train_acc_sum += accuracy(y_hat, y)
        test_acc = evaluate_accuracy(test_iter,net,ctx)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / len(train_iter),
                 train_acc_sum / len(train_iter), test_acc, time.time() - start))

def try_gpu():
    try:
        ctx = mx.gpu()
        _ = nd.zeros((1,),ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx


lr = 0.01
num_epochs = 5
ctx = try_gpu()

net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs)
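# Optionally persist the trained weights so the model can be reloaded later;
# the file name below is an arbitrary choice for illustration.
net.save_parameters('alexnet_fashion_mnist.params')
# To restore: net.load_parameters('alexnet_fashion_mnist.params', ctx=ctx)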


Original post: https://www.cnblogs.com/TreeDream/p/10045670.html
