• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch 带batch的tensor类型图像显示操作

    项目场景

    pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

    那么如何显示dataloader里面带batch的tensor类型的图像呢?

    显示图像

    绘图最常用的库就是matplotlib:

    pip install matplotlib

    显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

    数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

    用法示例如下:

    >>> x = torch.randn(2, 3, 5)
    >>> x.size()
    torch.Size([2, 3, 5])
    >>> x.permute(1, 2, 0).size()
    torch.Size([3, 5, 2])

    代码示例

    #%% 导入模块
    import torch
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    #%% 下载数据集
    train_file = datasets.MNIST(
        root='./dataset/',
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]),
        download=True
    )
    #%% 制作数据加载器
    train_loader = DataLoader(
        dataset=train_file,
        batch_size=9,
        shuffle=True
    )
    #%% 训练数据可视化
    images, labels = next(iter(train_loader))
    print(images.size())  # torch.Size([9, 1, 28, 28])
    plt.figure(figsize=(9, 9))
    for i in range(9):
        plt.subplot(3, 3, i+1)
        plt.title(labels[i].item())
        plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
        plt.axis('off')
    plt.show()
    

    这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

    所以,如果你想查看训练集的原始图像,还得反标准化。

    标准化:image = (image-mean)/std

    反标准化:image = image*std+mean

    我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

    最终效果

    补充:PIL,plt显示tensor类型的图像

    该方法针对显示Dataloader读取的图像

    PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

     # 方法1:Image.show()
     # transforms.ToPILImage()中有一句
     # npimg = np.transpose(pic.numpy(), (1, 2, 0))
     # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
     img = transforms.ToPILImage(image[0])
     img.show()
    
     # 方法2:plt.imshow(ndarray)
     img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
     img = img.numpy() # FloatTensor转为ndarray
     img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
     # 显示图片
     plt.imshow(img)
     plt.show()
     cnt += 1
    

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
    • pytorch Dataset,DataLoader产生自定义的训练数据案例
    • PyTorch实现重写/改写Dataset并载入Dataloader
    • 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
    • PyTorch 解决Dataset和Dataloader遇到的问题
    • PyTorch 如何自动计算梯度
    • pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
    • 我对PyTorch dataloader里的shuffle=True的理解
    • 解决pytorch下只打印tensor的数值不打印出device等信息的问题
    • Pytorch 如何查看、释放已关闭程序占用的GPU资源
    • pytorch中的squeeze函数、cat函数使用
    • Pytorch数据读取之Dataset和DataLoader知识总结
    上一篇:Python 京东云无线宝消息推送功能
    下一篇:学会用Python实现滑雪小游戏,再也不用去北海道啦
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch 带batch的tensor类型图像显示操作 pytorch,带,batch,的,tensor,类型,