码迷,mamicode.com
首页 > Web开发 > 详细

mxnet 神经网络训练和预测

时间:2018-12-04 18:58:13      阅读:200      评论:0      收藏:0      [点我收藏+]

标签:asi   net   ram   back   loaded   expect   dict   cpu   ack   

https://mxnet.incubator.apache.org/tutorials/basic/module.html

技术分享图片

# MXNet Module API walkthrough. It trains a small MLP on the letter-recognition
# dataset with the high-level `fit` API; the intermediate-level API is shown,
# commented out, for reference. It then predicts, scores, checkpoints, and
# resumes training from a saved checkpoint.
# Adapted from https://mxnet.incubator.apache.org/tutorials/basic/module.html
import logging
import random

import mxnet as mx
import numpy as np

# Surface the per-epoch metrics that Module.fit logs at INFO level.
logging.getLogger().setLevel(logging.INFO)

# Seed every RNG in play so runs are reproducible.
mx.random.seed(1234)
np.random.seed(1234)
random.seed(1234)

# --- Prepare data ---------------------------------------------------------
DATA_URL = (
    "https://s3.us-east-2.amazonaws.com/mxnet-public/"
    "letter_recognition/letter-recognition.data"
)
fname = mx.test_utils.download(DATA_URL)
# Each CSV row starts with a letter followed by numeric features. With the
# default float dtype the letter column parses as NaN, so drop column 0.
data = np.genfromtxt(fname=fname, delimiter=",")[:, 1:]
# Map the letter in column 0 to an integer class id ("A" -> 0, "B" -> 1, ...).
# Blank lines are skipped, as genfromtxt does, so labels stay row-aligned with
# `data`; the `with` block guarantees the file handle is closed.
with open(fname, "r") as f:
    label = np.array(
        [ord(line.split(",")[0]) - ord("A") for line in f if line.strip()]
    )

batch_size = 32
# First 80% of rows for training, the remaining rows for validation.
ntrain = int(data.shape[0] * 0.8)
nval = data.shape[0] - ntrain

train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)


# --- Define the network ---------------------------------------------------
# data -> FC(64) -> ReLU -> FC(26) -> softmax: one output per letter A-Z.
net = mx.sym.Variable("data")
net = mx.sym.FullyConnected(net, name="fc1", num_hidden=64)
net = mx.sym.Activation(net, name="relu1", act_type="relu")
net = mx.sym.FullyConnected(net, name="fc2", num_hidden=26)
# SoftmaxOutput named "softmax" expects its label input as "softmax_label".
net = mx.sym.SoftmaxOutput(net, name="softmax")
# Returns a graph object that renders inline in a notebook; unused in a script.
mx.viz.plot_network(net, node_attrs={"shape": "oval", "fixedsize": "false"})


# --- Intermediate-level API (reference only; not executed) ----------------
# mod = mx.mod.Module(symbol=net,
#                     context=mx.cpu(),
#                     data_names=["data"],
#                     label_names=["softmax_label"])
# mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# mod.init_params(initializer=mx.init.Uniform(scale=.1))
# mod.init_optimizer(optimizer="sgd", optimizer_params=(("learning_rate", 0.1),))
# metric = mx.metric.create("acc")
#
# for epoch in range(100):
#     train_iter.reset()
#     metric.reset()
#     for batch in train_iter:
#         mod.forward(batch, is_train=True)
#         mod.update_metric(metric, batch.label)
#         mod.backward()
#         mod.update()
#     print("Epoch %d, Training %s" % (epoch, metric.get()))

# --- High-level API: Module.fit -------------------------------------------
# Rewind the training iterator in case a manual loop above consumed it.
train_iter.reset()
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=["data"],
                    label_names=["softmax_label"])

mod.fit(train_iter,
        eval_data=val_iter,
        optimizer="sgd",
        optimizer_params={"learning_rate": 0.1},
        eval_metric="acc",
        num_epoch=10)


# --- Predict and evaluate -------------------------------------------------
# One row of 26 class scores per validation sample.
y = mod.predict(val_iter)
assert y.shape == (nval, 26)

# score() returns a list of (metric_name, value) pairs.
score = mod.score(val_iter, ["acc"])
print("Accuracy score is %f" % (score[0][1]))
assert score[0][1] > 0.76, "Achieved accuracy (%f) is less than expected (0.76)" % score[0][1]

# --- Save and load --------------------------------------------------------
# Callback that checkpoints the symbol and parameters at the end of each epoch.
model_prefix = "mx_mlp"
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

# Reload the network definition and the parameters saved for epoch 3.
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()

# Assign the loaded parameters to the module.
mod.set_params(arg_params, aux_params)

# Resume training from epoch 3, starting from the loaded parameters.
mod = mx.mod.Module(symbol=sym)
mod.fit(train_iter,
        num_epoch=21,
        arg_params=arg_params,
        aux_params=aux_params,
        begin_epoch=3)
# Re-score the resumed model so the check below measures it rather than the
# model trained earlier.
# NOTE(review): this fit uses fit()'s default optimizer settings, unlike the
# lr=0.1 run above -- confirm the 0.77 threshold holds for the resumed model.
score = mod.score(val_iter, ["acc"])
assert score[0][1] > 0.77, "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]

 

mxnet 神经网络训练和预测

标签:asi   net   ram   back   loaded   expect   dict   cpu   ack   

原文地址:https://www.cnblogs.com/TreeDream/p/10065616.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!