• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    解决pytorch读取自制数据集出现过的问题

    问题1

    问题描述:

    TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found class 'PIL.Image.Image'>

    解决方式

    数据格式不对, 把image转成tensor,参数transform进行如下设置就可以了:transform=transform.ToTensor()。注意检测一下transform

    问题2

    问题描述:

    TypeError: append() takes exactly one argument (2 given)

    出现问题的地方

    imgs.append(words[0], int(words[1]))

    解决方式

    加括号,如下

    imgs.append((words[0], int(words[1])))

    问题3

    问题描述

    RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

    解决方式

    数据和模型不在同一设备上,应该要么都在GPU运行,要么都在CPU

    问题4

    问题描述

    RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[1, 3, 512, 512] to have 1 channels, but got 3 channels instead

    解决方式

    图像竟然是RGB,但我的训练图像是一通道的灰度图,所以得想办法把 mode 转换成灰度图L

    补充:神经网络 pytorch 数据集读取(自动读取数据集,手动读取自己的数据)

    对于pytorch,我们有现成的包装好的数据集可以使用,也可以自己创建自己的数据集,大致来说有三种方法,这其中用到的两个包是datasets和DataLoader

    datasets:用于将数据和标签打包成数据集

    DataLoader:用于对数据集的高级处理,比如分组,打乱,处理等,在训练和测试中可以直接使用DataLoader进行处理

    第一种 现成的打包数据集

    这种比较简答,只需要现成的几行代码和一个路径就可以完成,但是一般都是常用比如cifar-10

    对于常用数据集,可以使用torchvision.datasets直接进行读取,这是对其常用的处理,该类也是继承于torch.utils.data.Dataset。

    #是第一次运行的话会下载数据集 现成的话可以使用root参数指定数据集位置
    # 存放的格式如下图
     
    # 根据接口读取默认的CIFAR10数据 进行训练和测试
    #预处理
    transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    #读取数据集
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    #打包成DataLoader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)
     
    #同上
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=1)
    classes = (1,2,3,4,5,6,7,8,9,10)  #类别定义
     
    #使用
     for epoch in range(3):
            running_loss = 0.0 #清空loss
            for i, data in enumerate(trainloader, 0):
                # get the inputs
                inputs, labels = data #trainloader返回:id,image,labels
     
                # 将inputs与labels装进Variable中
                inputs, labels = Variable(inputs), Variable(labels)
                
                #使用print代替输出
                print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())
     

    第二种 自己的图像分类

    这也是一个方便的做法,在pytorch中提供了torchvision.datasets.ImageFolder让我们训练自己的图像。

    要求:创建train和test文件夹,每个文件夹下按照类别名字存储图像就可以实现dataloader

    这里还是拿上个举例子吧,实际上也可以是我们的数据集

    每个下面的布局是这样的

    # 预处理
    transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
     
    #使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
    img_data = torchvision.datasets.ImageFolder('data/cifar-10/train/', transform=transform)
    data_loader = torch.utils.data.DataLoader(img_data, batch_size=4, shuffle=True, num_workers=1)
     
    testset = torchvision.datasets.ImageFolder('data/cifar-10/test/', transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
     
     for epoch in range(3):
            for i, data in enumerate(trainloader, 0):
                # get the inputs
                inputs, labels = data #trainloader返回:id,image,labels
                # 将inputs与labels装进Variable中
                inputs, labels = Variable(inputs), Variable(labels)
     
                #使用print代替输出
                print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

    第三种 一维向量数据集

    这个是比较尴尬的,首先我们

    假设将数存储到txt等文件中,先把他读取出来,读取的部分就不仔细说了,读到一个列表里就可以

    常用的可以是列表等,举例子

    trainlist = []  # 保存特征的列表
     
    targetpath = 'a/b/b'
    filelist = os.listdir(targetpath) #列出文件夹下所有的目录与文件
    filecount = len(filelist)
    # 根据根路径 读取所有文件名 循环读取文件内容 添加到list
    for i in range(filecount):
         filepath = os.path.join(targetpath, filelist[j])
         with open(filepath, 'r') as f:
             line = f.readline()
             # 例如存储格式为 1,2,3,4,5,6 数字之间以逗号隔开
             templist = list(map(int, line.split(',')))
             trainlist.append(templist)
     
    # 数据读取完毕 现在为维度为filecount的列表 我们需要转换格式和类型
    # 将数据转换为Tensor
     
    # 假如我们的两类数据分别存在list0 和 list1中
    split = len(list0) # 用于记录标签的分界
     
    #使用numpy.array 和 torch.from_numpy 连续将其转换为tensor  使用torch.cat拼接
    train0_numpy = numpy.array(list0)
    train1_numpy = numpy.array(list1)
    train_tensor = torch.cat([torch.from_numpy(train0_numpy), torch.from_numpytrain1_numpy)], 0)
    #现在的尺寸是【样本数,长度】 然而在使用神 经网络处理一维数据要求【样本数,维度,长度】
    # 这个维度指的像一个图像实际上是一个二维矩阵 但是有三个RGB通道 实际就为【3,行,列】 那么需要处理三个矩阵
    # 我们需要在我们的数据中加上这个维度信息
    # 注意类型要一样 可以转换
    shaper = train_tensor.shape  #获取维度 【样本数,长度】
    aa = torch.ones((shaper[0], 1, shaper[1])) # 生成目标矩阵
    for i in range(shaper[0]):  # 将所有样本复制到新矩阵
    ·    aa[i][0][:] = train_tensor[i][:]
    train_tensor = aa  # 完成了数据集的转换 【样本数,维度,长度】
     
    # 注 意 如果是读取的图像 我们需要的目标维度是【样本数,维度,size_w,size_h】
    # 卷积接受的输入是这样的四维度 最后的两个是图像的尺寸 维度表示是通道数量 
      
    # 下面是生成标签 标签注意类别之间的分界 split已经在上文计算出来
    # 训练标签的
    total = len(list0) + len(list1)
    train_label = numpy.zeros(total)
    train_label[split+1:total] = 1
    train_label_tensor = torch.from_numpy(train_label).int()
    # print(train_tensor.size(),train_label_tensor.size())
     
    # 搭建dataloader完毕
    train_dataset = TensorDataset(train_tensor, train_label_tensor)
    train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
     
    for epoch in range(3):
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
     
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

    第四种 保存路径和标签的方式创建数据集

    该方法需要略微的麻烦一些,首先你有一个txt,保存了文件名和对应的标签,大概是这个意思

    然后我们在程序中,根据给定的根目录找到文件,并将标签对应保存

    class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    def __getitem__(self, index):
    	raise NotImplementedError
    def __len__(self):
    	raise NotImplementedError
    def __add__(self, other):
    	return ConcatDataset([self, other])

    这是dataset的原本内容,getitem就是获取元素的部分,用于返回对应index的数据和标签。那么大概需要做的是我们将txt的内容读取进来,使用程序处理标签和数据

    # coding: utf-8
    from PIL import Image
    from torch.utils.data import Dataset
    class MyDataset(Dataset):
    # 初始化读取txt 可以设定变换
    def __init__(self, txt_path, transform = None, target_transform = None):
    	fh = open(txt_path, 'r')
    	imgs = []
    	for line in fh:
    		line = line.rstrip()
    		words = line.split()
             # 保存列表 其中有图像的数据 和标签
    		imgs.append((words[0], int(words[1])))
    		self.imgs = imgs 
    		self.transform = transform
    		self.target_transform = target_transform
    def __getitem__(self, index):
    	fn, label = self.imgs[index]
    	img = Image.open(fn).convert('RGB') 
    	if self.transform is not None:
    		img = self.transform(img) 
        # 返回图像和标签
        
    	return img, label
    def __len__(self):
    	return len(self.imgs)
     
    # 当然也可以创建myImageFloder 其txt格式在下图显示 
    import os
    import torch
    import torch.utils.data as data
    from PIL import Image 
    def default_loader(path):
        return Image.open(path).convert('RGB')
     
    class myImageFloder(data.Dataset):
        def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
            fh = open(label) #打开label文件
            c=0
            imgs=[]  # 保存图像的列表
            class_names=[]
            for line in  fh.readlines(): #读取每一行数据
                if c==0:
                    class_names=[n.strip() for n in line.rstrip().split('	')] 
                else:
                    cls = line.split() #分割为列表
                    fn = cls.pop(0)  #弹出最上的一个
                    if os.path.isfile(os.path.join(root, fn)):  # 组合路径名 读取图像
                        imgs.append((fn, tuple([float(v) for v in cls])))  #添加到列表
                c=c+1
     
            # 设置信息
            self.root = root
            self.imgs = imgs
            self.classes = class_names
            self.transform = transform
            self.target_transform = target_transform
            self.loader = loader
     
        def __getitem__(self, index):  # 获取图像 给定序号
            fn, label = self.imgs[index]  #读取图像的内容和对应的label
            img = self.loader(os.path.join(self.root, fn))
            if self.transform is not None:  # 是否变换
                img = self.transform(img)
            return img, torch.Tensor(label) # 返回图像和label
     
        def __len__(self):
            return len(self.imgs)
        
        def getName(self):
            return self.classes
    # 

    # 而后使用的时候就可以正常的使用
    trainset = MyDataset(txt_path=pathFile,transform = None, target_transform = None)
    # trainset = torch.utils.data.DataLoader(myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), batch_size= 2, shuffle= False, num_workers= 2)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)

    它的要点是,继承dataset,在初始化中处理txt文本数据,保存对应的数据,并实现对应的功能。

    这其中的原理就是如此,但是注意可能有些许略微不恰当的地方,可能就需要到时候现场调试了。

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • pytorch 如何把图像数据集进行划分成train,test和val
    • Pytorch中的数据集划分&正则化方法
    • pytorch学习教程之自定义数据集
    • pytorch加载语音类自定义数据集的方法教程
    • pytorch加载自己的图像数据集实例
    • pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
    上一篇:SecureCRTSecure7.0查看连接密码的步骤
    下一篇:Python趣味挑战之用pygame实现简单的金币旋转效果
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    解决pytorch读取自制数据集出现过的问题 解决,pytorch,读取,自制,数据,