# mxnet 线性模型 (linear regression with MXNet Gluon)
import mxnet
import mxnet.ndarray as nd
from mxnet import autograd
from mxnet import gluon
# create data
def set_data(true_w, true_b, num_examples, *args, noise_std=0.1, **kwargs):
    """Generate a synthetic linear-regression dataset.

    Features ``X`` are drawn from a standard normal distribution with shape
    ``(num_examples, len(true_w))``. Labels are built as
    ``y = X . true_w + true_b + noise_std * N(0, 1)``.

    Args:
        true_w: Sequence of ground-truth weights, one per input feature.
        true_b: Ground-truth bias added to every label.
        num_examples: Number of rows to generate.
        *args, **kwargs: Accepted and ignored (kept so existing callers that
            pass extra arguments keep working).
        noise_std: Scale of the Gaussian noise added to the labels. Defaults
            to 0.1, the value that used to be hard-coded.

    Returns:
        Tuple ``(X, y)`` of NDArrays with shapes
        ``(num_examples, len(true_w))`` and ``(num_examples,)``.

    Raises:
        ValueError: If ``true_w`` is empty.
    """
    num_inputs = len(true_w)
    if num_inputs == 0:
        raise ValueError("true_w must contain at least one weight")
    X = nd.random_normal(shape=(num_examples, num_inputs))
    # One matrix-vector product replaces the per-feature Python loop:
    # (num_examples, num_inputs) . (num_inputs,) -> (num_examples,)
    y = nd.dot(X, nd.array(true_w, dtype=X.dtype)) + true_b
    y += noise_std * nd.random_normal(shape=y.shape)
    return X, y
# create data loader
def data_loader(batch_size, X, y, shuffle=False):
    """Wrap features ``X`` and labels ``y`` in a Gluon DataLoader.

    Args:
        batch_size: Number of samples per yielded mini-batch.
        X: Feature array, paired element-wise with ``y``.
        y: Label array.
        shuffle: Whether the loader shuffles samples each epoch.

    Returns:
        A ``gluon.data.DataLoader`` over ``gluon.data.ArrayDataset(X, y)``.
    """
    return gluon.data.DataLoader(
        dataset=gluon.data.ArrayDataset(X, y),
        batch_size=batch_size,
        shuffle=shuffle,
    )
# create net
def set_net(node_num):
    """Build and initialise a single-layer (linear) Gluon network.

    Args:
        node_num: Number of output units of the Dense layer.

    Returns:
        An initialised ``gluon.nn.Sequential`` holding one ``Dense`` layer.
    """
    model = gluon.nn.Sequential()
    model.add(gluon.nn.Dense(node_num))
    # Uses Gluon's default initializer for all parameters.
    model.initialize()
    return model
# create trainer
def trainer(net, loss_method, learning_rate):
    """Create a ``gluon.Trainer`` over all parameters of ``net``.

    Args:
        net: Gluon block whose parameters will be optimised.
        loss_method: Name of the *optimizer* (e.g. ``'sgd'``), passed
            straight to ``gluon.Trainer``. Despite the name this is not a loss
            function; the parameter name is kept for backward compatibility.
        learning_rate: Learning rate handed to the optimizer.

    Returns:
        The configured ``gluon.Trainer``.
    """
    # optimizer_params must be a dict keyed by a plain ASCII string literal.
    return gluon.Trainer(
        net.collect_params(), loss_method, {'learning_rate': learning_rate}
    )
# Squared-error loss used for training: Gluon's L2Loss with its default
# arguments written out explicitly.
square_loss = gluon.loss.L2Loss(weight=1.0, batch_axis=0)
# start train
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
    """Train ``net`` with mini-batch gradient descent and report progress.

    Each epoch iterates over ``data_iter``. For every batch it records the
    forward pass, back-propagates ``loss_method(output, label)`` and takes
    one optimizer step. After each epoch it prints the mean per-example
    loss. After training it prints the first layer's weight and bias.

    Args:
        epochs: Number of passes over ``data_iter``.
        batch_size: Nominal batch size. Kept for interface compatibility; the
            optimizer step uses the actual length of each batch, so a smaller
            final batch is normalised correctly.
        data_iter: Iterable yielding ``(data, label)`` NDArray pairs.
        net: Gluon network. ``net[0]`` is read as the layer whose ``weight``
            and ``bias`` are reported, so it must have those parameters.
        loss_method: Callable ``loss_method(output, label)`` returning the
            per-example loss.
        tariner: ``gluon.Trainer`` that updates ``net``. The spelling is kept
            for keyword callers.
        num_examples: Total number of examples yielded by ``data_iter`` per
            epoch; used to average the epoch loss.

    Returns:
        Tuple ``(weight, bias)`` holding the data of ``net[0]``'s parameters.

    Raises:
        ValueError: If ``num_examples`` is not positive.
    """
    if num_examples <= 0:
        raise ValueError("num_examples must be positive")

    for e in range(epochs):
        total_loss = 0
        for data, label in data_iter:
            with autograd.record():
                output = net(data)
                loss = loss_method(output, label)
            loss.backward()
            # Step the trainer that was passed in (not a module-level global),
            # normalising gradients by the real number of rows in this batch.
            tariner.step(data.shape[0])
            total_loss += nd.sum(loss).asscalar()
        print("第 %d次训练, 平均损失: %f" % (e, total_loss / num_examples))

    dense = net[0]
    weight = dense.weight.data()
    bias = dense.bias.data()
    print(weight)
    print(bias)
    return weight, bias
if __name__ == "__main__":
    # Ground-truth parameters the model is expected to recover.
    true_w = [5, 8, 6]
    true_b = 6
    num_examples = 1000
    batch_size = 10

    X, y = set_data(true_w=true_w, true_b=true_b, num_examples=num_examples)
    data_iter = data_loader(batch_size=batch_size, X=X, y=y, shuffle=True)
    net = set_net(1)
    # NOTE: this rebinds the module-level name ``trainer`` from the factory
    # function to the Trainer instance it returns (names assigned under this
    # guard are still module globals). After this line the factory can no
    # longer be called by name.
    trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
    start_train(
        epochs=5,
        batch_size=batch_size,
        data_iter=data_iter,
        net=net,
        loss_method=square_loss,
        tariner=trainer,
        num_examples=num_examples,
    )