• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    浅谈pytorch中stack和cat的及to_tensor的坑

    初入计算机视觉遇到的一些坑

    1.pytorch中转tensor

    x=np.random.randint(10,100,(10,10,10))
    x=TF.to_tensor(x)
    print(x)

    这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

    2.stack和cat之间的差别

    stack

    x=torch.randn((1,2,3))
    y=torch.randn((1,2,3))
    z=torch.stack((x,y))#默认dim=0
    print(z.shape)
    #torch.Size([2, 1, 2, 3])

    所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

    cat

    z=torch.cat((x,y))
    print(z.size())
    #torch.Size([2, 2, 3])

    cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

    其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

    送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

    显然我们要做的是:torch.cat((data1,data2))

    如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

    显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

    补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

    pytorch中提供了对tensor常用的变换操作。

    cat 连接

    对数据沿着某一维度进行拼接。cat后数据的总维数不变。

    比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

    代码如下:

    import torch
    torch.manual_seed(1)
    x = torch.randn(2,3)
    y = torch.randn(1,3)
    print(x,y)

    结果:

    0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661
    [torch.FloatTensor of size 2x3]

    -1.5228 0.3817 -1.0276
    [torch.FloatTensor of size 1x3]

    将两个tensor拼在一起:

    torch.cat((x,y),0)

    结果:

    0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661
    -1.5228 0.3817 -1.0276
    [torch.FloatTensor of size 3x3]

    更灵活的拼法:

    torch.manual_seed(1)
    x = torch.randn(2,3)
    print(x)
    print(torch.cat((x,x),0))
    print(torch.cat((x,x),1))

    结果

    // x
    0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661
    [torch.FloatTensor of size 2x3]

    // torch.cat((x,x),0)
    0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661
    0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661
    [torch.FloatTensor of size 4x3]

    // torch.cat((x,x),1)
    0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
    0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
    [torch.FloatTensor of size 2x6]

    stack,增加新的维度进行堆叠

    而stack则会增加新的维度。

    如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

    见代码:

    a = torch.ones([1,2])
    b = torch.ones([1,2])
    c= torch.stack([a,b],0) // 第0个维度stack

    输出:

    (0 ,.,.) =
    1 1

    (1 ,.,.) =
    1 1
    [torch.FloatTensor of size 2x1x2]

    c= torch.stack([a,b],1) // 第1个维度stack

    输出:


    (0 ,.,.) =

    1 1

    1 1

    [torch.FloatTensor of size 1x2x2]

    transpose ,两个维度互换

    代码如下:

    torch.manual_seed(1)
    x = torch.randn(2,3)
    print(x)

    原来x的结果:

    0.6614 0.2669 0.0617

    0.6213 -0.4519 -0.1661

    [torch.FloatTensor of size 2x3]

    将x的维度互换

    x.transpose(0,1)

    结果

    0.6614 0.6213

    0.2669 -0.4519

    0.0617 -0.1661

    [torch.FloatTensor of size 3x2]

    permute,多个维度互换,更灵活的transpose

    permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

    代码如下:

    x = torch.randn(2,3,4)
    print(x.size())
    x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
    print(x_p.size())

    结果:

    torch.Size([2, 3, 4])

    torch.Size([3, 2, 4])

    squeeze 和 unsqueeze

    常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

    squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

    unsqueeze(dim_n),增加dim_n维度,元素数量为1。

    上代码:

    # 定义张量
    import torch
    
    b = torch.Tensor(2,1)
    b.shape
    Out[28]: torch.Size([2, 1])
    
    # 不加参数,去掉所有为元素个数为1的维度
    b_ = b.squeeze()
    b_.shape
    Out[30]: torch.Size([2])
    
    # 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
    b_ = b.squeeze(0)
    b_.shape 
    Out[32]: torch.Size([2, 1])
    
    # 这样就可以了
    b_ = b.squeeze(1)
    b_.shape
    Out[34]: torch.Size([2])
    
    # 增加一个维度
    b_ = b.unsqueeze(2)
    b_.shape
    Out[36]: torch.Size([2, 1, 1])
    

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • pytorch dataloader 取batch_size时候出现bug的解决方式
    • pytorch的batch normalize使用详解
    • pytorch方法测试详解——归一化(BatchNorm2d)
    • 解决pytorch下只打印tensor的数值不打印出device等信息的问题
    • Pytorch中TensorBoard及torchsummary的使用详解
    • pytorch Variable与Tensor合并后 requires_grad()默认与修改方式
    • pytorch 带batch的tensor类型图像显示操作
    上一篇:pytorch实现手写数字图片识别
    下一篇:pytorch中[..., 0]的用法说明
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    浅谈pytorch中stack和cat的及to_tensor的坑 浅谈,pytorch,中,stack,和,cat,