标签:shu data ref 访问 rabl href 初始化 加载 tom
pytorch 的数据加载到模型的操作顺序如下:
1.创建一个 Dataset 对象
2.创建一个 DataLoader对象
3.循环这个 DataLoader对象,将data, label加载到模型中进行训练
pytorch中文文档中的torch.utils.data
表示Dataset的抽象类。所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
源码参考点击
class Dataset(object):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
参考今夜无风的博客 和 pytorch之dataloader深入剖析
数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
DataLoader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;使用iter(DataLoader)返回的是一个迭代器,然后可以使用next访问。
DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存。
标签:shu data ref 访问 rabl href 初始化 加载 tom
原文地址:https://www.cnblogs.com/ytxwzqin/p/13086436.html