标签:red The https merge 情况 更新 lin nbsp logs
module或简写为mod,提供一个用于执行Symbol算的中高级接口,可理解为module是执行Symbol定义好的程序的机器。
module.Module接受Symbol作为输入:
data = mx.sym.Variable(‘data‘) fc1 = mx.sym.FullyConnected(data, name=‘fc1‘, num_hidden=128) act1 = mx.sym.Activation(fc1, name=‘relu1‘, act_type="relu") fc2 = mx.sym.FullyConnected(act1, name=‘fc2‘, num_hidden=10) out = mx.sym.SoftmaxOutput(fc2, name = ‘softmax‘) mod = mx.mod.Module(out) # create a module by given a Symbol 根据symbol建立module
关于module的一套训练流程见这里。本节的目的是选择一些常用module API,包括一些重要的属性和方法做个分析。
module包提供了以下几个module:最主要的还是第二个。
显然BaseModule是其他所有module的基类, 基类提供以下方法:
1. 初始化空间:
2. 参数操作:
3. 训练预测
4. 前反向传播
5. 参数更新
6. 输入输出
7. 其他
以上这些方法是所有类共有的,而Module类自己还有下面的内置方法:
这么多方法选择一些常用重要的方法进行介绍。
module表示计算组件。可把module看作是一台计算机器。模块可以执行向前和向后传递并更新模型中的参数。
一个module有几个状态:
before binding:为了使module产生交互,它必须能够在初始状态(bind之前)已知以下信息:
after binding:绑定后,module应能提供以下更丰富的信息:
binded:bool,指示是否已分配计算所需的内存缓冲区。
for_training:模块是否绑定进行训练。
params_initialized:bool,指示此模块的参数是否已初始化。
optimizer_initialized:bool,指示是否定义并初始化了优化器。
inputs_need_grad:bool,指示是否需要相输入数据的梯度。在实现模块组合时可能很有用。
data_shapes:(name、shape)的列表。理论上,由于内存是分配的,可以直接提供数据数组。但在数据并行的情况下,数据数组的形状可能与从外部世界看的不同。
label_shapes:(name、shape)的列表。如果模块不需要标签(如顶部不包含loss函数),或者模块未绑定以进行训练,则此值可能为[]。
outpu_shapes:(name、shape)的列表。
get_params():返回一个(arg_params,aux_params)的元组。每一个都是一个name到NDArray的映射。由于NDArray总是使用CPU,用于计算的实际参数可能存在于其他设备(GPU)上,此函数将检索最新参数。
set_params(arg_params,aux_params):为执行计算的设备分配参数。
init_params(…):一个更灵活的接口来分配或初始化参数。
bind():为计算准备环境。
init_optimizer():安装用于参数更新的优化器。
prepare():根据当前数据批准备模块。
forward(data_batch):前向操作。
backward(out_grads=None):反向操作。
update():根据安装的优化器更新参数。
get_outputs():获取上一个前向操作的输出。
get_input_grads():获取与上一个后向操作中计算的输入的梯度。
update_metric(metric,labels,pre_sliced=False):更新之前前向传播结果的性能度量,就是metric。
当这些中间层API被正确实现时,以下高级API将自动可用于模块:
1. score
(eval_data, eval_metric, num_batch=None, batch_end_callback=None, score_end_callback=None, reset=True, epoch=0, sparse_row_id_fn=None)[source]
这个方法用于预测eval_data的结果,并根据eval_metric提供的指标进行评估
一个例子:
# An example of using score for prediction. # Evaluate accuracy on val_dataiter metric = mx.metric.Accuracy() mod.score(val_dataiter, metric) mod.score(val_dataiter, [‘mse‘, ‘acc‘])
2. predict
(eval_data, num_batch=None, merge_batches=True, reset=True, always_output_list=False, sparse_row_id_fn=None)
这方法主要是得到eval_data的测试结果。
# An example of using `predict` for prediction. # Predict on the first 10 batches of val_dataiter mod.predict(eval_data=val_dataiter, num_batch=10)
3. fit
(train_data, eval_data=None, eval_metric=‘acc‘, epoch_end_callback=None, batch_end_callback=None, kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), ), eval_end_callback=None, eval_batch_end_callback=None, initializer=, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)[source]
这个是最重要的方法,用于训练网络。
# An example of using fit for training. # Assume training dataIter and validation dataIter are ready # Assume loading a previously checkpointed model sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3) mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer=‘sgd‘, optimizer_params={‘learning_rate‘:0.01, ‘momentum‘: 0.9}, arg_params=arg_params, aux_params=aux_params, eval_metric=‘acc‘, num_epoch=10, begin_epoch=3)
4. get_params()
返回一对字典。类型分别是arg_params和aux_params。每个字典都是从参数映射到NDArray:
# An example of getting module parameters. print mod.get_params()
5. init_params和set_params
init_params
(initializer=, arg_params=None, aux_params=None, allow_missing=False, force_init=False, allow_extra=False)
set_params
(arg_params, aux_params, allow_missing=False, force_init=True, allow_extra=False)[source]
前者初始化参数和辅助参数的状态,后者给参数和辅助参数赋值。
# An example of initializing module parameters. mod.init_params() # An example of setting module parameters. sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, n_epoch_load) mod.set_params(arg_params=arg_params, aux_params=aux_params)
6. load_params和save_params
载入和保存参数
# An example of saving module parameters. mod.save_params(‘myfile‘) # An example of loading module parameters. mod.load_params(‘myfile‘)
7. forawrd(data_batch, is_train=None)和backwar(out_grads=None)
data_batch是DataBatch类型。
import mxnet as mx
from collections import namedtuple
Batch = namedtuple(‘Batch‘, [‘data‘])
data = mx.sym.Variable(‘data‘)
out = data * 2
mod = mx.mod.Module(symbol=out, label_names=None)
mod.bind(data_shapes=[(‘data‘, (1, 10))])
mod.init_params()
data1 = [mx.nd.ones((1, 10))]
mod.forward(Batch(data1))
print mod.get_outputs()[0].asnumpy()
# Forward with data batch of different shape
data2 = [mx.nd.ones((3, 5))]
mod.forward(Batch(data2))
print mod.get_outputs()[0].asnumpy()
# An example of backward computation. mod.backward() print mod.get_input_grads()[0].asnumpy() ]
上面在前向和反向传播时分别用到了get_outputs(merge_multi_context=True)和get_input_grads(merge_multi_context=True)两个函数:
前者得到前向计算后的输出,后者得到反向传播后关于输入的梯度。
8. init_optimizer
(kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), ), force_init=False)
指定了这个还需指定fit里面那个optimizer吗???
# An example of initializing optimizer. mod.init_optimizer(optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.005),))
9. update()和update_metric(eval_metric, labels,pre_sliced=False)
update根据已配置的优化器和上一个前向-反向批处理中计算的梯度更新参数。
# An example of updating module parameters. mod.init_optimizer(kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), )) mod.backward() mod.update() print mod.get_params()[0][‘fc3_weight‘].asnumpy() ]
update_metric对上次前向计算的输出求值并累积评估metric。
# An example of updating evaluation metric. mod.forward(data_batch) mod.update_metric(metric, data_batch.label)
10. bind
(data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req=‘write‘)
非常重要,不用fit这个高级api的话,就需要bind来搞起训练。
# An example of binding symbols. mod.bind(data_shapes=[(‘data‘, (1, 10, 10))]) # Assume train_iter is already created. mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
标签:red The https merge 情况 更新 lin nbsp logs
原文地址:https://www.cnblogs.com/king-lps/p/13066148.html