码迷,mamicode.com
首页 > Web开发 > 详细

mxnet 线性模型

时间:2018-02-18 16:23:20      阅读:222      评论:0      收藏:0      [点我收藏+]

标签:eth   lower   vsc   mode   head   word   tag   san   color   

mxnet 线性模型

<wiz_code_mirror>
 
 
 
74
def data_loader(batch_size, X, y, shuffle=False):
 
 
 
1
import mxnet
2
import mxnet.ndarray as nd
3
from mxnet import gluon
4
from mxnet import autograd
5

6

7
# create data
8

9
def set_data(true_w, true_b, num_examples, *args, **kwargs):
    """Generate a synthetic linear-regression dataset.

    Features are sampled with ``nd.random_normal``. Each label is the
    weighted sum of that row's features, plus ``true_b``, plus Gaussian
    noise scaled by 0.1.

    Args:
        true_w: Sequence of per-feature weights; its length sets the number
            of feature columns.
        true_b: Scalar bias added to every label.
        num_examples: Number of rows to generate.
        *args, **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(X, y)``. X has shape ``(num_examples, len(true_w))``;
        y is built from column slices of X.
    """
    feature_count = len(true_w)
    features = nd.random_normal(shape=(num_examples, feature_count))
    # Weighted column sum, accumulated left to right starting from 0.
    labels = sum(
        weight * features[:, col] for col, weight in enumerate(true_w)
    )
    labels = labels + true_b
    labels = labels + 0.1 * nd.random_normal(shape=labels.shape)
    return features, labels
19

20

21
# create data loader
22
def data_loader(batch_size, X, y, shuffle=False):
    """Wrap feature and label arrays in a Gluon mini-batch iterator.

    Args:
        batch_size: Number of samples per yielded batch.
        X: Feature array, paired by index with ``y``.
        y: Label array.
        shuffle: Whether the DataLoader shuffles sample order.

    Returns:
        A ``gluon.data.DataLoader`` over ``gluon.data.ArrayDataset(X, y)``.
    """
    return gluon.data.DataLoader(
        dataset=gluon.data.ArrayDataset(X, y),
        batch_size=batch_size,
        shuffle=shuffle,
    )
26

27

28
# create net
29
def set_net(node_num):
    """Build and initialize a network holding a single dense layer.

    Args:
        node_num: Number of output units of the ``gluon.nn.Dense`` layer
            (this script passes 1 for scalar regression).

    Returns:
        An initialized ``gluon.nn.Sequential`` containing one Dense layer.
    """
    model = gluon.nn.Sequential()
    dense_layer = gluon.nn.Dense(node_num)
    model.add(dense_layer)
    # Default initializer. No input size is given to Dense, so Gluon
    # infers it from the first batch it sees.
    model.initialize()
    return model
34

35

36
# create trainer
37
def trainer(net, loss_method, learning_rate):
    """Create a ``gluon.Trainer`` that updates all of ``net``'s parameters.

    Args:
        net: Gluon block whose ``collect_params()`` will be optimized.
        loss_method: Name of the optimizer handed to ``gluon.Trainer``
            (for example ``'sgd'``). Despite its name this is NOT a loss
            function; the parameter name is kept so existing keyword
            callers keep working.
        learning_rate: Step size placed in the optimizer parameter dict.

    Returns:
        The configured ``gluon.Trainer``.
    """
    # Plain ASCII quotes: typographic quotes here are a SyntaxError.
    optimizer_params = {'learning_rate': learning_rate}
    return gluon.Trainer(net.collect_params(), loss_method, optimizer_params)
42

43

44
# Loss object handed to start_train() as `loss_method`; it is called as
# loss_method(output, label), and the result is summed with nd.sum to
# accumulate the epoch loss.
square_loss = gluon.loss.L2Loss()
45

46

47
# start train
48
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
    """Train ``net`` with mini-batch updates and report the first layer's parameters.

    Args:
        epochs: Number of full passes over ``data_iter``.
        batch_size: Nominal batch size. Kept for interface compatibility;
            each update is normalized by the actual length of the batch, so
            a short final batch is not mis-weighted.
        data_iter: Iterable yielding ``(data, label)`` batches.
        net: Gluon network; ``net[0]`` is the layer whose ``weight`` and
            ``bias`` are printed and returned.
        loss_method: Callable ``loss_method(output, label)`` returning an
            array of losses that is summed with ``nd.sum``.
        tariner: The ``gluon.Trainer`` used to apply updates. The name is
            misspelled but kept so keyword callers still work.
        num_examples: Total number of samples in ``data_iter``. Used to turn
            the summed epoch loss into a per-sample average.

    Returns:
        Tuple ``(weight, bias)`` holding ``net[0]``'s parameter data.
    """
    for epoch in range(epochs):
        total_loss = 0.0
        for data, label in data_iter:
            with autograd.record():
                output = net(data)
                loss = loss_method(output, label)
            loss.backward()
            # Step the trainer that was passed in (not a module-level global),
            # normalizing by the real number of samples in this batch.
            tariner.step(data.shape[0])
            total_loss += nd.sum(loss).asscalar()
        print("第 %d次训练, 平均损失: %f" % (epoch, total_loss / num_examples))
    dense = net[0]
    weight = dense.weight.data()
    bias = dense.bias.data()
    print(weight)
    print(bias)
    return weight, bias
64

65

66
# Ground-truth parameters of the synthetic linear model the net should recover.
true_w = [5, 8, 6]
true_b = 6
# Shared run settings, defined once so every call below agrees.
NUM_EXAMPLES = 1000
BATCH_SIZE = 10

X, y = set_data(true_w=true_w, true_b=true_b, num_examples=NUM_EXAMPLES)
data_iter = data_loader(batch_size=BATCH_SIZE, X=X, y=y, shuffle=True)
net = set_net(1)
# NOTE: this rebinds the module-level name `trainer` from the factory function
# to the Trainer object it returns. The global name is kept because code
# elsewhere in the file may look it up directly.
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
start_train(epochs=5, batch_size=BATCH_SIZE, data_iter=data_iter, net=net,
            loss_method=square_loss, tariner=trainer,
            num_examples=NUM_EXAMPLES)
74

 
 

mxnet 线性模型

标签:eth   lower   vsc   mode   head   word   tag   san   color   

原文地址:https://www.cnblogs.com/liaoxianfu/p/a00a6b34a5de7cdbd50ef91b7121cef3.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!