码迷,mamicode.com
首页 > Web开发 > 详细

LeNet

时间:2018-11-30 14:14:41      阅读:168      评论:0      收藏:0      [点我收藏+]

标签:return   start   jpg   ali   cal   kernel   oat   bat   worker   

技术分享图片

import mxnet as mx
import sys
from mxnet import autograd,nd
from mxnet import gluon,init
from mxnet.gluon import nn,loss as gloss
from mxnet.gluon import data as gdata

# Load the Fashion-MNIST training and test splits.
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

batch_size = 256

# ToTensor is applied only to the first element of each sample
# (see transform_first below).
# NOTE(review): the name is a misspelling of "transformer"; it is kept as-is
# in case code outside this view refers to it.
trainsformer = gdata.vision.transforms.ToTensor()

# Windows gets 0 DataLoader workers, every other platform gets 4.
# The platform prefix must be a string literal: the bare name `win` is
# undefined and raised NameError.
num_workers = 0 if sys.platform.startswith('win') else 4

# Training batches are shuffled; test batches keep the dataset order.
train_iter = gdata.DataLoader(
    mnist_train.transform_first(trainsformer),
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_iter = gdata.DataLoader(
    mnist_test.transform_first(trainsformer),
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Use the GPU when available.
def try_gpu():
    """Return ``mx.gpu()`` if an array can be allocated on it, else ``mx.cpu()``.

    The GPU context is probed by allocating a one-element NDArray on it; an
    ``mx.base.MXNetError`` raised during the probe selects the CPU fallback.
    """
    try:
        device = mx.gpu()
        # Probe allocation: the result is discarded, only success matters.
        nd.zeros((1,), ctx=device)
        return device
    except mx.base.MXNetError:
        return mx.cpu()

# Compute the classification accuracy.
def accuracy(y_hat, y):
    """Return the fraction of predictions in ``y_hat`` that match labels ``y``.

    Args:
        y_hat: model outputs; the predicted class is taken as the arg-max
            along axis 1 (assumes one row per sample - TODO confirm).
        y: ground-truth labels, cast to float32 before the comparison.

    Returns:
        The mean of the element-wise equality test, as a Python scalar.
    """
    # The comparison is parenthesised so that mean()/asscalar() apply to the
    # match mask rather than to the labels, and the dtype is a string literal.
    return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()
def evaluate_accuracy(data_iter, net, ctx):
    """Return the mean per-batch accuracy of ``net`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(features, labels)`` batches that also
            supports ``len()``.
        net: callable mapping a feature batch to model outputs.
        ctx: context each batch is copied to before evaluation.

    Returns:
        Sum of the per-batch ``accuracy`` values divided by the number of
        batches (every batch is weighted equally).
    """
    # The running total lives in an NDArray on ``ctx``.
    total = nd.array([0], ctx=ctx)
    for features, labels in data_iter:
        features = features.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        total += accuracy(net(features), labels)
    batch_count = len(data_iter)
    return total.asscalar() / batch_count

# LeNet: build the convolutional neural network.
# Activation names must be string literals; the bare name `sigmoid` is
# undefined and raised NameError.
net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
        nn.MaxPool2D(pool_size=2, strides=2),
        # Dense by default flattens an input of shape
        # (batch size, channels, height, width) into
        # (batch size, channels * height * width).
        nn.Dense(120, activation='sigmoid'),
        nn.Dense(84, activation='sigmoid'),
        nn.Dense(10))

# Push one random single-channel 28x28 sample through the network layer by
# layer and print each layer's output shape.
X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()

for layer in net:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)

# Stack K, K + 1 and K + 2 into one array and print the result.
K = nd.array([[[0, 1], [2, 3]], [[1, 2], [3, 4]]])
K = nd.stack(K, K + 1, K + 2)
print(K)

 

LeNet

标签:return   start   jpg   ali   cal   kernel   oat   bat   worker   

原文地址:https://www.cnblogs.com/TreeDream/p/10043092.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!