标签:格式 ima 数据流 它的 poc 不能 iterable orm sub
本文先就DataBatch、DataDesc、DataIter三个主要用到的类进行介绍,然后引出Mxnet中常见的迭代器。
MXNet中的数据迭代器Data iterators类似于Python迭代器对象。在Python中,函数iter允许通过对可iterable对象(如Python列表)调用next()按顺序获取项。迭代器提供了一个抽象接口,用于遍历各种类型的iterable集合,而无需公开底层数据源的详细信息。
在MXNet中,数据迭代器在每次调用next()时返回一批数据作为DataBatch。数据批处理通常包含n个训练示例及其相应的标签。这里n是迭代器的批处理大小。在数据流结束时,当没有更多的数据可读取时,迭代器会引发像Python iter那样的StopIteration异常。DataBatch结构在这。
看看DataBatch类以及他的方法:
class mxnet.io.
DataBatch
(data, label=None, pad=None, index=None, bucket_key=None, provide_data=None, provide_label=None)[source]
参数:
这个类就是一个批量的样本,每次data iterator调用next(),就会返回一个DataBatch,也即一个批量的样本。如果输入的数据是图像的话,这些图像的shape取决于DataDesc中的provide_data参数:
class mxnet.io.
DataDesc
[source]
DataDesc用于存储数据的名字,形状,类型和格式信息。
参数:
方法:
get_batch_axis
(layout):获取与批处理大小相对应的维度。get_list
(shapes, types):从属性列表中获取DataDesc列表。每个训练样本的名称、形状、类型和布局等信息及其相应的标签可以通过DataBatch中的provide_data和provide_label属性作为DataDesc数据描述符对象提供。这里定义了DataDesc的结构。
class mxnet.io.
DataIter
(batch_size=0)[source]
是mxnet中数据迭代器dataiter的基类。mxnet中所有的数据IO都由该类的子类来处理。mxnet中的dataiter迭代器是和python中的iterators很像,每次调用nxet都会返回一个Databatch代表了一个批量中的数据。
参数:
方法:
MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理。本文将讨论MXNet提供的一些常用迭代器。
import mxnet as mx %matplotlib inline import os import sys import subprocess import numpy as np import matplotlib.pyplot as plt import tarfile import warnings warnings.filterwarnings("ignore", category=DeprecationWarning)
import numpy as np # fix the seed np.random.seed(42) mx.random.seed(42) data = np.random.rand(100,3) label = np.random.randint(0, 10, (100,)) data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30) for batch in data_iter: print([batch.data, batch.label, batch.pad])
#lets save `data` into a csv file first and try reading it back np.savetxt(‘data.csv‘, data, delimiter=‘,‘) data_iter = mx.io.CSVIter(data_csv=‘data.csv‘, data_shape=(3,), batch_size=30) for batch in data_iter: print([batch.data, batch.pad])
当所有内置的迭代器不能满足时,可以定制。
mxnet中的迭代器应当满足:
创建新迭代器时,可以从头开始定义迭代器,也可以重用现有迭代器之一。例如,在图像caption应用程序中,输入示例是图像,而标签是句子。因此,我们可以通过以下方法创建新的迭代器:
一个实例:
1 class SimpleIter(mx.io.DataIter): 2 def __init__(self, data_names, data_shapes, data_gen, 3 label_names, label_shapes, label_gen, num_batches=10): 4 self._provide_data = list(zip(data_names, data_shapes)) 5 self._provide_label = list(zip(label_names, label_shapes)) 6 self.num_batches = num_batches 7 self.data_gen = data_gen 8 self.label_gen = label_gen 9 self.cur_batch = 0 10 11 def __iter__(self): 12 return self 13 14 def reset(self): 15 self.cur_batch = 0 16 17 def __next__(self): 18 return self.next() 19 20 @property 21 def provide_data(self): 22 return self._provide_data 23 24 @property 25 def provide_label(self): 26 return self._provide_label 27 28 def next(self): 29 if self.cur_batch < self.num_batches: 30 self.cur_batch += 1 31 data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)] 32 label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)] 33 return mx.io.DataBatch(data, label) 34 else: 35 raise StopIteration
构建一个mlp:
import mxnet as mx num_classes = 10 net = mx.sym.Variable(‘data‘) net = mx.sym.FullyConnected(data=net, name=‘fc1‘, num_hidden=64) net = mx.sym.Activation(data=net, name=‘relu1‘, act_type="relu") net = mx.sym.FullyConnected(data=net, name=‘fc2‘, num_hidden=num_classes) net = mx.sym.SoftmaxOutput(data=net, name=‘softmax‘) print(net.list_arguments()) print(net.list_outputs())
通过mxnet的module模块来喂入数据。
import logging logging.basicConfig(level=logging.INFO) n = 32 data_iter = SimpleIter([‘data‘], [(n, 100)], [lambda s: np.random.uniform(-1, 1, s)], [‘softmax_label‘], [(n,)], [lambda s: np.random.randint(0, num_classes, s)]) mod = mx.mod.Module(symbol=net) mod.fit(data_iter, num_epoch=5)
因为data_iter是迭代器类型,所以可以有get_data()、get_label()、get_index()、next()等方法。
同样因为data_iter.next()返回的是一个DataBatch类型,所以可以有data_iter.next().data、data_iter.next().label等属性。
其余内容见:mxnet 数据读取
标签:格式 ima 数据流 它的 poc 不能 iterable orm sub
原文地址:https://www.cnblogs.com/king-lps/p/13057643.html