A Detailed Look at the OneFlow DataLoader Implementation

Written by Zhao Luyang

The latest OneFlow v0.5.0 release adds a number of new features.

One of the most important of these is that OneFlow's dynamic graph mode is now almost fully aligned with PyTorch, from Tensor, nn.Module, and autograd to the functional API, including a DataLoader/Dataset design that matches torch almost exactly. I had the pleasure of developing this module in OneFlow:

https://github.com/Oneflow-Inc/oneflow/pull/5406
https://github.com/Oneflow-Inc/oneflow/pull/5500
https://github.com/Oneflow-Inc/oneflow/pull/5644
https://github.com/Oneflow-Inc/oneflow/pull/6280

This article walks through the principles and workflow of the DataLoader in OneFlow/PyTorch.

1. Introduction

Simply put, the DataLoader is an indispensable part of deep learning: it is the data loader that turns a Dataset into a batch of data and labels for each training iteration. As the PyTorch documentation describes it: the DataLoader combines a Dataset and a Sampler and provides an iterable over the given dataset. It supports both single-process and multi-process data loading.

2. DataLoader Core Components

A quick summary of how the DataLoader works:

1. DataLoader is the core of data loading; DataLoaderIter is the concrete executor. On every iteration, the DataLoader delegates the actual loading work to a DataLoaderIter;

2. Dataset is the base class for datasets. Any custom dataset must inherit from it and override the __getitem__ method to define how a sample is read;

3. Sampler is the index sampler: on every iteration it produces the indices of the samples to load;

4. Fetcher acts as a data collector: it takes the batch of indices produced by the Sampler, fetches the corresponding samples from the dataset, collects them via the collate_fn, and packs them into a final usable batch that is returned to the DataLoader (a toy sketch of how these pieces fit together follows this list).
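Before walking through the real examples below, here is a minimal, self-contained sketch (a hypothetical ToyDataset, not from the original post; it assumes the usual torch-aligned tensor API) showing how the pieces above cooperate:

import oneflow as flow
from oneflow.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):          # Dataset: defines how one sample is read by index
    def __init__(self):
        self.data = flow.arange(0, 10, dtype=flow.float32).reshape(10, 1)
        self.labels = flow.arange(0, 10, dtype=flow.int64)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return 10

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)  # Sampler/Fetcher are built internally
for data, label in loader:          # each iter: the Sampler yields indices, the Fetcher collates a batch
    print(data.shape, label.shape)  # (4, 1) and (4,) for full batches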

Usage Examples

1. MNIST

Let's use a simple example from the official PyTorch examples, training a classification network on the MNIST dataset, to illustrate how the DataLoader is used:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

As you can see, dataset1 and dataset2 are the training set and test set, respectively. In PyTorch they are defined by torchvision.datasets.MNIST. MNIST inherits from VisionDataset, which in turn inherits from torch.utils.data.Dataset. MNIST implements the dataset's most important method, __getitem__, which reads the sample corresponding to a given index:

def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

In OneFlow, oneflow.utils.data corresponds to torch.utils.data and flowvision corresponds to torchvision, and the usage is almost identical. For example, the MNIST dataset can be used directly via flowvision.datasets.MNIST.

Once dataset1 and dataset2 are defined, they are passed to the dataloaders used for training and evaluation (train_loader and test_loader). In the train/test loop you can then iterate over the dataloader to get the data and labels for each iteration:

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        ...

2. ImageNet

Again, let's take the ImageNet training script from the official PyTorch examples:

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)

As you can see, the overall flow is much the same as the MNIST example above:

1. First, construct the Dataset, here via datasets.ImageFolder. ImageFolder reads and processes image datasets stored as folders:

class ImageFolder(DatasetFolder):
    r"""A generic data loader where the images are arranged in this way by default:

    .. code-block:: shell

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~vision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super(ImageFolder, self).__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

As you can see, it inherits from DatasetFolder; its main constructor parameters are root, transform, target_transform, loader, and is_valid_file, documented in the docstring above.

DatasetFolder implements the Dataset's most important method, __getitem__:

def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """
    Args:
        index (int): Index

    Returns:
        tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return sample, target

__getitem__ defines how the sample for a given index is read.

2. Second, for multi-machine distributed training the Sampler has to be the DistributedSampler class designed specifically for distributed training (otherwise no special setup is needed and the default is used). Another detail: the training set and validation set apply different transforms to the dataset. The training set uses RandomResizedCrop and RandomHorizontalFlip; the validation set uses Resize and CenterCrop. After these transforms, the data is finally converted to a Tensor by ToTensor.
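As a hedged sketch of this DistributedSampler pattern (the epoch loop and args fields are illustrative, mirroring the official example above), training code typically also calls set_epoch every epoch so that the per-rank shards are reshuffled differently each time:

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size,
    shuffle=(train_sampler is None),   # shuffle must be off when a sampler is given
    sampler=train_sampler, num_workers=args.workers, pin_memory=True)

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)     # reshuffle the distributed shards each epoch
    for images, target in train_loader:
        ...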

3. Construct the DataLoaders for training and validation (train_loader and val_loader). After that, usage is simple: just iterate over them in the train/eval loop:

for i, (images, target) in enumerate(train_loader):
    # measure data loading time
    data_time.update(time.time() - end)

    if args.gpu is not None:
        images = images.cuda(args.gpu, non_blocking=True)
    if torch.cuda.is_available():
        target = target.cuda(args.gpu, non_blocking=True)
    ...

3. DataLoader Workflow

Let's look at the main flow together with the code.

Dataset

Any custom dataset must inherit from the Dataset class and implement the __getitem__ method, which defines how data is fetched for a given index. A custom dataset may also optionally override the __len__ method, which reports the size of the dataset.

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~flow.utils.data.Sampler` implementations and the default options
    of :class:`~flow.utils.data.DataLoader`.

    .. note::
      :class:`~flow.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        return ConcatDataset([self, other])
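As a small aside, the __add__ hook shown above means two map-style datasets can be concatenated with +. A hedged toy sketch (the RangeDataset class is hypothetical, written only to illustrate the base class):

from oneflow.utils.data import Dataset

class RangeDataset(Dataset):
    def __init__(self, start, stop):
        self.items = list(range(start, stop))
    def __getitem__(self, index):
        return self.items[index]
    def __len__(self):
        return len(self.items)

combined = RangeDataset(0, 5) + RangeDataset(5, 8)   # a ConcatDataset under the hood
print(len(combined), combined[6])                    # 8 6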

DataLoader

The DataLoader is the core of the whole data-processing pipeline.

class DataLoader(Generic[T_co]):
    def __init__(
        self,
        dataset: Dataset[T_co],
        batch_size: Optional[int] = 1,
        shuffle: bool = False,
        sampler: Optional[Sampler[int]] = None,
        batch_sampler: Optional[Sampler[Sequence[int]]] = None,
        num_workers: int = 0,
        collate_fn: Optional[_collate_fn_t] = None,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Optional[_worker_init_fn_t] = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: int = 2,
        persistent_workers: bool = False
    ):
        ...

    ...

    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> "_BaseDataLoaderIter":
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    def _get_iterator(self) -> "_BaseDataLoaderIter":
        if self.num_workers == 0 or self.num_workers == 1:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

In every iteration, the most important thing the DataLoader does is fetch the data and labels through the __iter__ method above. Inside __iter__, the _get_iterator method returns the appropriate DataLoaderIter instance.
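In other words, the for-loop in the training code is just sugar over these two calls (a short sketch using the train_loader from the MNIST example above):

it = iter(train_loader)   # DataLoader.__iter__ -> a _SingleProcessDataLoaderIter when num_workers <= 1
data, target = next(it)   # the iterator's __next__ produces one batch of images and labels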

DataLoaderIter

The DataLoaderIter handles the concrete work the DataLoader does in each iteration.

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = False
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = flow.tensor([0], dtype=flow.int64).uniform_().numpy().item()
        # TODO: flow.empty()
        # self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
            self.__class__.__name__
        )

    def __iter__(self) -> "_BaseDataLoaderIter":
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        if self._sampler_iter is None:
            self._reset()
        data = self._next_data()
        self._num_yielded += 1
        if (
            self._dataset_kind == _DatasetKind.Iterable
            and self._IterableDataset_len_called is not None
            and self._num_yielded > self._IterableDataset_len_called
        ):
            warn_msg = (
                "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                "samples have been fetched. "
            ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
            if self._num_workers > 1:
                warn_msg = "Multiprocessing dataloader is not support yet!"
            warnings.warn(warn_msg)
        return data

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert 0 <= self._num_workers <= 1
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind,
            self._dataset,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
        )

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        if self._pin_memory:
            raise NotImplementedError("Dataloader pin memory is not support yet!")
        return self._dataset_fetcher.fetch(index)

On each iteration, _BaseDataLoaderIter's __next__ method is called, which in turn calls the _next_data method implemented by the subclass to get the data, as shown for _SingleProcessDataLoaderIter above.
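Written out by hand, the essence of one single-process step looks roughly like this (a sketch that pokes at the internal attributes shown above for illustration only; normal code should just call next()):

it = iter(train_loader)                     # _SingleProcessDataLoaderIter
indices = it._next_index()                  # Sampler: the indices for one batch
batch = it._dataset_fetcher.fetch(indices)  # Fetcher: read the samples and collate them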

Fetcher

The Fetcher acts as a data collector: given the batch of indices produced by the Sampler, it slices out, collects, and packs a complete, usable batch from the dataset and returns it to the DataLoader.

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last
        )

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

Like the DataLoaderIter (_BaseDataLoaderIter), the Fetcher also has a base-class implementation, _BaseDatasetFetcher; depending on the dataset kind, a different subclass is used. Here we take the commonly used _MapDatasetFetcher subclass, shown above, to look at the Fetcher's main work.

As you can see, its main work is:

1. Given the batch of indices, slice the corresponding samples out of the dataset, yielding a list containing the batch's samples;

2. Use the passed-in (or default) collate_fn to collect this batch of samples and pack them into Tensors that can be used directly for training/validation, as sketched below.
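A minimal custom collate_fn sketch (hedged: the built-in default collate already covers this case; dataset1 is assumed to be the MNIST dataset from the earlier example, built with flowvision so each img is already a Tensor after ToTensor):

import oneflow as flow
from oneflow.utils.data import DataLoader

def my_collate(batch):
    # `batch` is the list the Fetcher builds: [(img_0, label_0), (img_1, label_1), ...]
    images = flow.stack([sample[0] for sample in batch], dim=0)
    labels = flow.tensor([sample[1] for sample in batch], dtype=flow.int64)
    return images, labels

loader = DataLoader(dataset1, batch_size=64, collate_fn=my_collate)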

4. How the Multiprocessing DataLoader Works

An ordinary single-process DataLoader processes data iter by iter, synchronously, and because Python has no truly parallel multithreaded execution, a single-process DataLoader is usually slow. A multiprocessing DataLoader starts several Python worker processes via Python's multiprocessing module; with 4 workers, for example, up to 4 iterations' worth of data can in principle be processed per unit time, speeding up data processing and loading.

With a single-process DataLoader, data processing is iter-by-iter: the next iteration can only start after the current one finishes. The main difference with a multiprocessing DataLoader is that it uses Python's multiprocessing module to start multiple worker processes to accelerate this pipeline.
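From the user's point of view this is just a constructor argument; a sketch with illustrative values, using the parameters from the DataLoader signature shown in section 3:

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,             # 4 worker processes prepare batches in parallel
    prefetch_factor=2,         # each worker keeps up to 2 batches in flight
    persistent_workers=True,   # keep workers alive across epochs instead of respawning them
)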

Take a 4-worker DataLoader as an example:

After the DataLoader's main thread dispatches the current iteration's task to worker1, it dispatches the next iteration's task to worker2, and so on, until the 4th iteration's task has been dispatched to worker4. This step is implemented at L1024-L1026 of dataloader.py:

# prime the prefetch loop
for _ in range(self._prefetch_factor * self._num_workers):
    self._try_put_index()

Once the indices have been sent out, the 4 workers can work in parallel. As each worker finishes its iteration's task, it pushes the result into a Queue, and the DataLoader's main thread simply takes the data off that queue.
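The pattern can be boiled down to a deliberately simplified, runnable toy (not the library implementation): the main process sends indices through per-worker index queues, and each worker pushes its finished batch into a shared data queue.

import multiprocessing as mp

def worker_loop(index_queue, data_queue):
    while True:
        task = index_queue.get()
        if task is None:                        # final signal from the main process
            break
        send_idx, indices = task
        batch = [i * 10 for i in indices]       # stand-in for fetcher.fetch(indices)
        data_queue.put((send_idx, batch))

if __name__ == "__main__":
    data_queue = mp.Queue()
    index_queues = [mp.Queue() for _ in range(4)]
    workers = [mp.Process(target=worker_loop, args=(q, data_queue)) for q in index_queues]
    for w in workers:
        w.start()
    for send_idx in range(8):                   # dispatch 8 "iters" round-robin
        index_queues[send_idx % 4].put((send_idx, list(range(send_idx * 2, send_idx * 2 + 2))))
    for _ in range(8):
        print(data_queue.get())                 # results may arrive out of order here
    for q in index_queues:
        q.put(None)
    for w in workers:
        w.join()

Unlike this toy, the real implementation additionally tracks _send_idx/_rcvd_idx/_task_info (see the code below) so that batches are handed back to the user in order.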

The workflow inside each worker is essentially the same as that of the single-process DataLoader. Below we focus on the differences between the multiprocessing and single-process DataLoader, and on how the workers cooperate.

Workflow

_MultiProcessingDataLoaderIter

def _next_data(self):
    # The DataLoaderIter obtains each iteration's data through this method,
    # which mainly delegates to _get_data().
    ...

def _get_data(self):
    # _get_data() obtains the data mainly by calling _try_get_data().
    ...

def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
    # Take data from the main process's _data_queue.
    ...
    try:
        data = self._data_queue.get(timeout=timeout)
        return (True, data)
    except Exception as e:
        ...

def _process_data(self, data):
    # Main work:
    # 1. put the next iteration's indices into an active worker via _try_put_index();
    # 2. advance the _rcvd_idx marker by 1.
    self._rcvd_idx += 1
    self._try_put_index()
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data

def _try_put_index(self):
    # Main work: walk through all workers, find the first active one (worker_queue_idx),
    # and put (send_idx, index) into that worker's index_queue.
    # Each worker owns an independent index_queue and starts working once it receives a message.
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        # not found (i.e., didn't break)
        return

    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._tasks_outstanding += 1
    self._send_idx += 1

_next_data()

⬇️

_get_data() ➡️ _try_get_data()

⬇️

_process_data() ➡️ _try_put_index()

Each worker works independently; the main code is in the _worker_loop() method in oneflow/python/oneflow/utils/data/_utils/worker.py:

try:
    while watchdog.is_alive():
        try:
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        if isinstance(r, _ResumeIteration):
            # Acknowledge the main process
            data_queue.put((r, None))
            iteration_end = False
            # Recreate the fetcher for worker-reuse policy
            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last
            )
            continue
        elif r is None:
            # Received the final signal
            assert done_event.is_set() or iteration_end
            break
        elif done_event.is_set() or iteration_end:
            # `done_event` is set. But I haven't received the final signal
            # (None) yet. I will keep continuing until get it, and skip the
            # processing steps.
            continue
        idx, index = r
        data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
        if init_exception is not None:
            data = init_exception
            init_exception = None
        else:
            try:
                data = fetcher.fetch(index)
            except Exception as e:
                if (
                    isinstance(e, StopIteration)
                    and dataset_kind == _DatasetKind.Iterable
                ):
                    data = _IterableDatasetStopIteration(worker_id)
                    # Set `iteration_end`
                    #   (1) to save future `next(...)` calls, and
                    #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                    iteration_end = True
                else:
                    # It is important that we don't store exc_info in a variable.
                    # `ExceptionWrapper` does the correct thing.
                    # See NOTE [ Python Traceback Reference Cycle Problem ]
                    data = ExceptionWrapper(
                        where="in DataLoader worker process {}".format(worker_id)
                    )
        data_queue.put((idx, data))
        del data, idx, index, r  # save memory
except KeyboardInterrupt:
    # Main process will raise KeyboardInterrupt anyways.
    pass

Inside its worker loop, as soon as a worker receives indices from its index_queue via r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL), it starts working:

idx, index = r, followed by data = fetcher.fetch(index). This part is no different from the single-process DataLoader workflow described earlier.

Once the processed data is ready, the worker puts it into the DataLoader main thread's data_queue via data_queue.put((idx, data)), where it waits for the main thread to pick up the result from the queue.

That is the main workflow of the multiprocessing DataLoader.

5. Conclusion

This article has summarized the DataLoader/Dataset design; I hope it helps you understand how the DataLoader/Dataset works in OneFlow/PyTorch eager mode.

Aligning with PyTorch's DataLoader/Dataset is only the first step; efficiency bottlenecks remain. Even with a multiprocessing DataLoader, image decoding and the various transforms executed as C ops called from Python can still run into performance problems in some cases, leaving the GPU under-utilized while it waits for the CPU to finish preparing data. More efficient solutions (such as DALI) will need to be considered going forward.

You are welcome to download and try OneFlow, a new-generation open-source deep learning framework: https://github.com/Oneflow-Inc/oneflow
