• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch中DataLoader()过程中遇到的一些问题

    如下所示:

    RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Resize((224)) ###

    原因是

    transforms.Resize() 的参数设置问题,改为如下设置就可以了

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Resize((224,224)),

    同理,val_dataset中也调整为transforms.Resize((224,224))。

    补充:pytorch之dataloader深入剖析

    - dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;

    - 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

    - 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

    - 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

    ① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存 ​

    ② Queue的特点

    当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。

    当数据满了: queue.put() 会阻塞

    ③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展

    输入数据PipeLine

    pytorch 的数据加载到模型的操作顺序是这样的:

    ① 创建一个 Dataset 对象

    ② 创建一个 DataLoader 对象

    ③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

    dataset = MyDataset()
    dataloader = DataLoader(dataset)
    num_epoches = 100
    for epoch in range(num_epoches):
    for img, label in dataloader:
    ....

    所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

    首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

    官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。

    1.DataLoader

    先介绍一下DataLoader(object)的参数:

    dataset(Dataset): 传入的数据集

    batch_size(int, optional): 每个batch有多少个样本

    shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

    sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

    batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

    num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

    collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

    pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

    drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…

    如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

    timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

    worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each

    worker subprocess with the worker id (an int in [0, num_workers - 1]) as
    input, after seeding and before data loading. (default: None) 
    

    - 首先dataloader初始化时得到datasets的采样list

    class DataLoader(object):
        r"""
        Data loader. Combines a dataset and a sampler, and provides
        single- or multi-process iterators over the dataset.
        Arguments:
            dataset (Dataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: 1).
            shuffle (bool, optional): set to ``True`` to have the data reshuffled
                at every epoch (default: False).
            sampler (Sampler, optional): defines the strategy to draw samples from
                the dataset. If specified, ``shuffle`` must be False.
            batch_sampler (Sampler, optional): like sampler, but returns a batch of
                indices at a time. Mutually exclusive with batch_size, shuffle,
                sampler, and drop_last.
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means that the data will be loaded in the main process.
                (default: 0)
            collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            pin_memory (bool, optional): If ``True``, the data loader will copy tensors
                into CUDA pinned memory before returning them.
            drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If ``False`` and
                the size of dataset is not divisible by the batch size, then the last batch
                will be smaller. (default: False)
            timeout (numeric, optional): if positive, the timeout value for collecting a batch
                from workers. Should always be non-negative. (default: 0)
            worker_init_fn (callable, optional): If not None, this will be called on each
                worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
                input, after seeding and before data loading. (default: None)
        .. note:: By default, each worker will have its PyTorch seed set to
                  ``base_seed + worker_id``, where ``base_seed`` is a long generated
                  by main process using its RNG. However, seeds for other libraies
                  may be duplicated upon initializing workers (w.g., NumPy), causing
                  each worker to return identical random numbers. (See
                  :ref:`dataloader-workers-random-seed` section in FAQ.) You may
                  use ``torch.initial_seed()`` to access the PyTorch seed for each
                  worker in :attr:`worker_init_fn`, and use it to set other seeds
                  before data loading.
        .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                     unpicklable object, e.g., a lambda function.
        """
        __initialized = False
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                     num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                     timeout=0, worker_init_fn=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
            if timeout  0:
                raise ValueError('timeout option should be non-negative')
            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and '
                                     'drop_last')
                self.batch_size = None
                self.drop_last = None
            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')
            if self.num_workers  0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')
            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                        sampler = RandomSampler(dataset)  //将list打乱
                    else:
                        sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
            self.sampler = sampler
            self.batch_sampler = batch_sampler
            self.__initialized = True
        def __setattr__(self, attr, val):
            if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))
            super(DataLoader, self).__setattr__(attr, val)
        def __iter__(self):
            return _DataLoaderIter(self)
        def __len__(self):
            return len(self.batch_sampler)

    其中:RandomSampler,BatchSampler已经得到了采用batch数据的index索引;yield batch机制已经在!!!

    class RandomSampler(Sampler):
        r"""Samples elements randomly, without replacement.
        Arguments:
            data_source (Dataset): dataset to sample from
        """
        def __init__(self, data_source):
            self.data_source = data_source
        def __iter__(self):
            return iter(torch.randperm(len(self.data_source)).tolist())
        def __len__(self):
            return len(self.data_source)
    class BatchSampler(Sampler):
        r"""Wraps another sampler to yield a mini-batch of indices.
        Args:
            sampler (Sampler): Base sampler.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``
        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """
        def __init__(self, sampler, batch_size, drop_last):
            if not isinstance(sampler, Sampler):
                raise ValueError("sampler should be an instance of "
                                 "torch.utils.data.Sampler, but got sampler={}"
                                 .format(sampler))
            if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
    
                    batch_size = 0:
                raise ValueError("batch_size should be a positive integeral value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
        def __len__(self):
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size

    - 其中 _DataLoaderIter(self)输入为一个dataloader对象;如果num_workers=0很好理解,num_workers!=0引入多线程机制,加速数据加载过程;

    - 没有多线程时:batch = self.collate_fn([self.dataset[i] for i in indices])进行将index转化为data数据,返回(image,label);self.dataset[i]会调用datasets对象的

    __getitem__()方法

    - 多线程下,会为每个线程创建一个索引队列index_queues;共享一个worker_result_queue数据队列!在_worker_loop方法中加载数据;

    class _DataLoaderIter(object):
        r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout
            self.done_event = threading.Event()
            self.sample_iter = iter(self.batch_sampler)
            base_seed = torch.LongTensor(1).random_().item()
            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
                self.worker_queue_idx = 0
                self.worker_result_queue = multiprocessing.SimpleQueue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}
                self.workers = [
                    multiprocessing.Process(
                        target=_worker_loop,
                        args=(self.dataset, self.index_queues[i],
                              self.worker_result_queue, self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    for i in range(self.num_workers)]
                if self.pin_memory or self.timeout > 0:
                    self.data_queue = queue.Queue()
                    if self.pin_memory:
                        maybe_device_id = torch.cuda.current_device()
                    else:
                        # do not initialize cuda context if not necessary
                        maybe_device_id = None
                    self.worker_manager_thread = threading.Thread(
                        target=_worker_manager_loop,
                        args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                              maybe_device_id))
                    self.worker_manager_thread.daemon = True
                    self.worker_manager_thread.start()
                else:
                    self.data_queue = self.worker_result_queue
                for w in self.workers:
                    w.daemon = True  # ensure that the worker exits on process exit
                    w.start()
                _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _set_SIGCHLD_handler()
                self.worker_pids_set = True
                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()
        def __len__(self):
            return len(self.batch_sampler)
        def _get_batch(self):
            if self.timeout > 0:
                try:
                    return self.data_queue.get(timeout=self.timeout)
                except queue.Empty:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            else:
                return self.data_queue.get()
        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration
            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)
        next = __next__  # Python 2 compatibility
        def __iter__(self):
            return self
        def _put_indices(self):
            assert self.batches_outstanding  2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1
        def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch
    def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
        global _use_shared_memory
        _use_shared_memory = True
        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)
        if init_fn is not None:
            init_fn(worker_id)
        watchdog = ManagerWatchdog()
        while True:
            try:
                r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if watchdog.is_alive():
                    continue
                else:
                    break
            if r is None:
                break
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    - 需要对队列操作,缓存数据,使得加载提速!

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • pytorch锁死在dataloader(训练时卡死)
    • pytorch Dataset,DataLoader产生自定义的训练数据案例
    • 解决Pytorch dataloader时报错每个tensor维度不一样的问题
    • Pytorch dataloader在加载最后一个batch时卡死的解决
    • Pytorch 如何加速Dataloader提升数据读取速度
    • pytorch DataLoader的num_workers参数与设置大小详解
    • pytorch 实现多个Dataloader同时训练
    上一篇:python 爬取影视网站下载链接
    下一篇:Django分页器的用法详解
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch中DataLoader()过程中遇到的一些问题 pytorch,中,DataLoader,过程中,