• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    Pytorch DataLoader shuffle验证方式

    shuffle = False时,不打乱数据顺序

    shuffle = True,随机打乱

    import numpy as np
    import h5py
    import torch
    from torch.utils.data import DataLoader, Dataset  
    h5f = h5py.File('train.h5', 'w');
    data1 = np.array([[1,2,3],
                   [2,5,6],
                  [3,5,6],
                  [4,5,6]])
    data2 = np.array([[1,1,1],
                       [1,2,6],
                      [1,3,6],
                      [1,4,6]])
    h5f.create_dataset(str('data'), data=data1)
    h5f.create_dataset(str('label'), data=data2)
    class Dataset(Dataset):
        def __init__(self):
            h5f = h5py.File('train.h5', 'r')
            self.data = h5f['data']
            self.label = h5f['label']
        def __getitem__(self, index):
            data = torch.from_numpy(self.data[index])
            label = torch.from_numpy(self.label[index])
            return data, label
     
        def __len__(self):
            assert self.data.shape[0] == self.label.shape[0], "wrong data length"
            return self.data.shape[0] 
     
    dataset_train = Dataset()
    loader_train = DataLoader(dataset=dataset_train,
                               batch_size=2,
                               shuffle = True)
     
    for i, data in enumerate(loader_train):
        train_data, label = data
        print(train_data)
     

    pytorch DataLoader使用细节

    背景:

    我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

    数据变换共有以下内容

    composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                                   transforms.RandomCrop(300), # random crop
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                        std=[0.5, 0.5, 0.5])])

    简单的数据读取类, 进返回PIL格式的image:

    class MyDataset(data.Dataset):    
        def __init__(self, labels_file, root_dir, transform=None):
            with open(labels_file) as csvfile:
                self.labels_file = list(csv.reader(csvfile))
            self.root_dir = root_dir
            self.transform = transform
            
        def __len__(self):
            return len(self.labels_file)
        
        def __getitem__(self, idx):
            im_name = os.path.join(root_dir, self.labels_file[idx][0])
            im = Image.open(im_name)
            
            if self.transform:
                im = self.transform(im)
                
            return im

    下面是主程序

    labels_file = "F:/test_temp/labels.csv"
    root_dir = "F:/test_temp"
    dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
    dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
    """原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
    for eopch in range(2):
        plt.figure(figsize=(6, 6)) 
        for ind, i in enumerate(dataloader):
            a = i[0, :, :, :].numpy().transpose((1, 2, 0))
            plt.subplot(1, 3, ind+1)
            plt.imshow(a)
    

    从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 我对PyTorch dataloader里的shuffle=True的理解
    • Pytorch在dataloader类中设置shuffle的随机数种子方式
    • pytorch 实现多个Dataloader同时训练
    • 解决Pytorch dataloader时报错每个tensor维度不一样的问题
    • pytorch中DataLoader()过程中遇到的一些问题
    • Pytorch dataloader在加载最后一个batch时卡死的解决
    • pytorch锁死在dataloader(训练时卡死)
    上一篇:python 爬取吉首大学网站成绩单
    下一篇:Python爬虫实战之爬取携程评论
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    Pytorch DataLoader shuffle验证方式 Pytorch,DataLoader,shuffle,验证,