码迷,mamicode.com
首页 > Windows程序 > 详细

Module API

时间:2020-06-08 19:14:26      阅读:85      评论:0      收藏:0      [点我收藏+]

标签:red   The   https   merge   情况   更新   lin   nbsp   logs   

module或简写为mod,提供一个用于执行Symbol算的中高级接口,可理解为module是执行Symbol定义好的程序的机器。

module.Module接受Symbol作为输入:

data = mx.sym.Variable(data)
fc1 = mx.sym.FullyConnected(data, name=fc1, num_hidden=128)
act1 = mx.sym.Activation(fc1, name=relu1, act_type="relu")
fc2 = mx.sym.FullyConnected(act1, name=fc2, num_hidden=10)
out = mx.sym.SoftmaxOutput(fc2, name = softmax)
mod = mx.mod.Module(out) # create a module by given a Symbol    根据symbol建立module

关于module的一套训练流程见这里。本节的目的是选择一些常用module API,包括一些重要的属性和方法做个分析。

module包提供了以下几个module:最主要的还是第二个。

技术图片

  

显然BaseModule是其他所有module的基类, 基类提供以下方法

1. 初始化空间:

技术图片

2.  参数操作:

技术图片

 3. 训练预测

技术图片

4. 前反向传播 

技术图片

5. 参数更新

技术图片

6. 输入输出

技术图片

7.  其他

技术图片

以上这些方法是所有类共有的,而Module类自己还有下面的内置方法:

技术图片

这么多方法选择一些常用重要的方法进行介绍。

 

module表示计算组件。可把module看作是一台计算机器。模块可以执行向前和向后传递并更新模型中的参数。 
一个module有几个状态:

  • 初始状态-initial state:内存尚未分配,因此模块尚未准备好进行计算。
  • 绑定-binded:输入、输出和参数的形状都是已知的,内存已分配,模块已准备好进行计算。
  • 参数已初始化:对于具有参数的模块,在初始化参数之前执行计算可能会导致未定义的输出。
  • 优化器已安装:优化器可以安装到module。在此之后,在计算梯度(前向后)之后,可以根据优化器更新模块的参数。

before binding:为了使module产生交互,它必须能够在初始状态(bind之前)已知以下信息:

  • 数据名称data_names:表示所需输入数据名称的字符串类型列表。
  • 输出名称output_names:表示所需输出名称的字符串类型列表。

after binding:绑定后,module应能提供以下更丰富的信息:

  • 状态信息:

binded:bool,指示是否已分配计算所需的内存缓冲区。

for_training:模块是否绑定进行训练。

params_initialized:bool,指示此模块的参数是否已初始化。

optimizer_initialized:bool,指示是否定义并初始化了优化器。

inputs_need_grad:bool,指示是否需要相输入数据的梯度。在实现模块组合时可能很有用。

  • 输入输出信息:

data_shapes:(name、shape)的列表。理论上,由于内存是分配的,可以直接提供数据数组。但在数据并行的情况下,数据数组的形状可能与从外部世界看的不同。
label_shapes:(name、shape)的列表。如果模块不需要标签(如顶部不包含loss函数),或者模块未绑定以进行训练,则此值可能为[]。
outpu_shapes:(name、shape)的列表。

  • 参数信息: 

get_params():返回一个(arg_params,aux_params)的元组。每一个都是一个name到NDArray的映射。由于NDArray总是使用CPU,用于计算的实际参数可能存在于其他设备(GPU)上,此函数将检索最新参数。
set_params(arg_params,aux_params):为执行计算的设备分配参数。
init_params(…):一个更灵活的接口来分配或初始化参数。

  • setup:

bind():为计算准备环境。
init_optimizer():安装用于参数更新的优化器。
prepare():根据当前数据批准备模块。

  • 计算:

forward(data_batch):前向操作。
backward(out_grads=None):反向操作。
update():根据安装的优化器更新参数。
get_outputs():获取上一个前向操作的输出。
get_input_grads():获取与上一个后向操作中计算的输入的梯度。
update_metric(metric,labels,pre_sliced=False):更新之前前向传播结果的性能度量,就是metric。

  

当这些中间层API被正确实现时,以下高级API将自动可用于模块:

  • fit:在数据集上训练模块参数。
  • predict:在数据集上运行预测并收集输出。
  • score:在数据集上运行预测并评估性能

 

1. score(eval_data, eval_metric, num_batch=None, batch_end_callback=None, score_end_callback=None, reset=True, epoch=0, sparse_row_id_fn=None)[source]

这个方法用于预测eval_data的结果,并根据eval_metric提供的指标进行评估 

 

  • eval_data(DataIte类型):要运行预测的评估数据。
  • eval_metric(EvalMetric或EvalMetrics列表):要使用的评估度量。
  • num_batch(int):要运行的批数。默认为“无”,表示在DataIter完成之前运行。
  • batch_end_callback:(函数)也可以是函数列表。
  • reset(bool):默认为True。指示在开始计算之前是否应重置计算数据。
  • epoch(int):默认为0。为了兼容性,这将传递给回调(如果有的话)。在训练期间,这将与训练epoch数相对应。
  • sparse_row_id_fn(回调函数):函数将数据批作为输入并返回str->NDArray的dict。结果dict用于从kvstore中提取行稀疏参数,其中str键是参数的名称,值是要提取的参数的行id。

一个例子: 

# An example of using score for prediction.
# Evaluate accuracy on val_dataiter
metric = mx.metric.Accuracy()
mod.score(val_dataiter, metric)
mod.score(val_dataiter, [mse, acc])

 

2. predict(eval_data, num_batch=None, merge_batches=True, reset=True, always_output_list=False, sparse_row_id_fn=None)

这方法主要是得到eval_data的测试结果。

# An example of using `predict` for prediction.
# Predict on the first 10 batches of val_dataiter
mod.predict(eval_data=val_dataiter, num_batch=10)

 

3. fit(train_data, eval_data=None, eval_metric=‘acc‘, epoch_end_callback=None, batch_end_callback=None, kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), ), eval_end_callback=None, eval_batch_end_callback=None, initializer=, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)[source]

这个是最重要的方法,用于训练网络。

  • train_data(DataIter型)训练集数据迭代器。
  • eval_data(DataIter型):如果不是None,则将用作验证集,并评估每个epoch之后的性能。
  • eval_metric(str或EvalMetric):默认为“准确率”。用于在训练期间显示的性能度量。其他可能的预定义度量是:“ce”(交叉熵)、“f1”、“mae”、“mse”、“rmse”、“top_k_accurity”。
  • epoch_end_callback(函数或函数列表):将使用当前epoch、symbol、arg_params和aux_params调用每个回调。
  • batch_end_callback(函数或函数列表):将使用BatchEndParam调用每个回调。
  • kvstore(str或kvstore):默认为“local”。
  • optimizer(str或optimizer):默认为“sgd”。
  • optimizer_params(dict)–默认为((“learninf_rate”,0.01),)。优化器构造函数的参数。
  • eval_end_callback(函数或函数列表):这些函数将在每次完整评估结束时调用,度量值覆盖整个eval集。
  • eval_batch_end_callback(函数或函数列表):在评估期间,将在每个batch之后调用这些函数。
  • initializer(initializer):调用初始值设定项以在尚未初始化模块参数时对其进行初始化。
  • arg_params(dict)–默认值为None,如果不是None,则应该是来自经过训练的模型的现有参数或从checkpoint(以前保存的模型)加载的参数。在这种情况下,这里的值将用于初始化模块参数,除非它们已经由用户通过调用init_params或fit进行了初始化。arg_params的优先级高于初始值设定项
  • aux_params(dict)-默认为无。类似于arg_params,除了用于辅助状态。
  • allow_missing(bool)-默认为False。指示当arg_params和aux_params不是None时是否允许缺少参数。如果这是真的,那么丢失的参数将通过初始化器初始化。
  • force_rebind(bool)-默认为False。是否强制重新绑定已绑定的执行器。
  • force_init(bool)-默认为False。指示是否强制初始化,即使参数已初始化。
  • begin_epoch(int)-默认为0。通常,如果从epoch N上一个训练阶段保存的checkpoint恢复,则该值应为N+1。
  • num_epoch(int)–训练的epoch数。
  • sparse_row_id_fn(回调函数)–函数将数据批作为输入并返回str->NDArray的dict。结果dict用于从kvstore中提取行稀疏参数,其中str键是参数的名称,值是要提取的参数的行id。
# An example of using fit for training.
# Assume training dataIter and validation dataIter are ready
# Assume loading a previously checkpointed model
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer=sgd,
optimizer_params={learning_rate:0.01, momentum: 0.9},
arg_params=arg_params, aux_params=aux_params,
eval_metric=acc, num_epoch=10, begin_epoch=3)

 

4. get_params()

返回一对字典。类型分别是arg_params和aux_params。每个字典都是从参数映射到NDArray:

# An example of getting module parameters.
print mod.get_params()

 

5. init_params和set_params

init_params(initializer=, arg_params=None, aux_params=None, allow_missing=False, force_init=False, allow_extra=False)

set_params(arg_params, aux_params, allow_missing=False, force_init=True, allow_extra=False)[source]

前者初始化参数和辅助参数的状态,后者给参数和辅助参数赋值。

# An example of initializing module parameters.
mod.init_params()

# An example of setting module parameters.
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, n_epoch_load)
mod.set_params(arg_params=arg_params, aux_params=aux_params)

 

6. load_params和save_params

载入和保存参数

# An example of saving module parameters.
mod.save_params(myfile)

# An example of loading module parameters.
mod.load_params(myfile)

 

7. forawrd(data_batch, is_train=None)和backwar(out_grads=None)

data_batch是DataBatch类型。

import mxnet as mx
from collections import namedtuple
Batch = namedtuple(Batch, [data])
data = mx.sym.Variable(data)
out = data * 2
mod = mx.mod.Module(symbol=out, label_names=None)
mod.bind(data_shapes=[(data, (1, 10))])
mod.init_params()
data1 = [mx.nd.ones((1, 10))]
mod.forward(Batch(data1))
print mod.get_outputs()[0].asnumpy()
# Forward with data batch of different shape
data2 = [mx.nd.ones((3, 5))]
mod.forward(Batch(data2))
print mod.get_outputs()[0].asnumpy()
# An example of backward computation.
mod.backward()
print mod.get_input_grads()[0].asnumpy()
]

上面在前向和反向传播时分别用到了get_outputs(merge_multi_context=True)和get_input_grads(merge_multi_context=True)两个函数:

前者得到前向计算后的输出,后者得到反向传播后关于输入的梯度。

 

8. init_optimizer(kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), ), force_init=False)

指定了这个还需指定fit里面那个optimizer吗???

# An example of initializing optimizer.
mod.init_optimizer(optimizer=sgd, optimizer_params=((learning_rate, 0.005),))

 

9. update()和update_metric(eval_metric, labels,pre_sliced=False)

update根据已配置的优化器和上一个前向-反向批处理中计算的梯度更新参数。

# An example of updating module parameters.
mod.init_optimizer(kvstore=local, optimizer=sgd,
optimizer_params=((learning_rate, 0.01), ))
mod.backward()
mod.update()
print mod.get_params()[0][fc3_weight].asnumpy()
]

update_metric对上次前向计算的输出求值并累积评估metric。

# An example of updating evaluation metric.
mod.forward(data_batch)
mod.update_metric(metric, data_batch.label)

 

10. bind(data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req=‘write‘)

非常重要,不用fit这个高级api的话,就需要bind来搞起训练

# An example of binding symbols.
mod.bind(data_shapes=[(data, (1, 10, 10))])
# Assume train_iter is already created.
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)

Module API

标签:red   The   https   merge   情况   更新   lin   nbsp   logs   

原文地址:https://www.cnblogs.com/king-lps/p/13066148.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!