• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch中的squeeze函数、cat函数使用

    1 squeeze(): 去除size为1的维度,包括行和列。

    至于维度大于等于2时,squeeze()不起作用。

    行、例:

    >>> torch.rand(4, 1, 3)
     
    (0 ,.,.) =
      0.5391  0.8523  0.9260
     
    (1 ,.,.) =
      0.2507  0.9512  0.6578
     
    (2 ,.,.) =
      0.7302  0.3531  0.9442
     
    (3 ,.,.) =
      0.2689  0.4367  0.6610
    [torch.FloatTensor of size 4x1x3]
    >>> torch.rand(4, 1, 3).squeeze()
     
     0.0801  0.4600  0.1799
     0.0236  0.7137  0.6128
     0.0242  0.3847  0.4546
     0.9004  0.5018  0.4021
    [torch.FloatTensor of size 4x3]

    列、例:

    >>> torch.rand(4, 3, 1)
     
    (0 ,.,.) =
      0.7013
      0.9818
      0.9723
     
    (1 ,.,.) =
      0.9902
      0.8354
      0.3864
     
    (2 ,.,.) =
      0.4620
      0.0844
      0.5707
     
    (3 ,.,.) =
      0.5722
      0.2494
      0.5815
    [torch.FloatTensor of size 4x3x1]
    
    >>> torch.rand(4, 3, 1).squeeze()
     
     0.8784  0.6203  0.8213
     0.7238  0.5447  0.8253
     0.1719  0.7830  0.1046
     0.0233  0.9771  0.2278
    [torch.FloatTensor of size 4x3]

    不变、例:

    >>> torch.rand(4, 3, 2)
     
    (0 ,.,.) =
      0.6618  0.1678
      0.3476  0.0329
      0.1865  0.4349
     
    (1 ,.,.) =
      0.7588  0.8972
      0.3339  0.8376
      0.6289  0.9456
     
    (2 ,.,.) =
      0.1392  0.0320
      0.0033  0.0187
      0.8229  0.0005
     
    (3 ,.,.) =
      0.2327  0.6264
      0.4810  0.6642
      0.8625  0.6334
    [torch.FloatTensor of size 4x3x2]
    
    >>> torch.rand(4, 3, 2).squeeze()
     
    (0 ,.,.) =
      0.0593  0.8910
      0.9779  0.1530
      0.9210  0.2248
     
    (1 ,.,.) =
      0.7938  0.9362
      0.1064  0.6630
      0.9321  0.0453
     
    (2 ,.,.) =
      0.0189  0.9187
      0.4458  0.9925
      0.9928  0.7895
     
    (3 ,.,.) =
      0.5116  0.7253
      0.0132  0.6673
      0.9410  0.8159
    [torch.FloatTensor of size 4x3x2]

    2 cat函数

    >>> t1=torch.FloatTensor(torch.randn(2,3))
    >>> t1
     
    -1.9405  1.2009  0.0018
     0.9463  0.4409 -1.9017
    [torch.FloatTensor of size 2x3]
    
    >>> t2=torch.FloatTensor(torch.randn(2,2))
    >>> t2
     
     0.0942  0.1581
     1.1621  1.2617
    [torch.FloatTensor of size 2x2]
    >>> torch.cat((t1, t2), 1)
     
    -1.9405  1.2009  0.0018  0.0942  0.1581
     0.9463  0.4409 -1.9017  1.1621  1.2617
    [torch.FloatTensor of size 2x5]

    补充:pytorch中 max()、view()、 squeeze()、 unsqueeze()

    查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。

    一、torch.max()

    import torch  
    a=torch.randn(3)
    print("a:\n",a)
    print('max(a):',torch.max(a))
     
    b=torch.randn(3,4)
    print("b:\n",b)
    print('max(b,0):',torch.max(b,0))
    print('max(b,1):',torch.max(b,1))

    输出:

    a:
    tensor([ 0.9558, 1.1242, 1.9503])
    max(a): tensor(1.9503)
    b:
    tensor([[ 0.2765, 0.0726, -0.7753, 1.5334],
    [ 0.0201, -0.0005, 0.2616, -1.1912],
    [-0.6225, 0.6477, 0.8259, 0.3526]])
    max(b,0): (tensor([ 0.2765, 0.6477, 0.8259, 1.5334]), tensor([ 0, 2, 2, 0]))
    max(b,1): (tensor([ 1.5334, 0.2616, 0.8259]), tensor([ 3, 2, 2]))

    max(a),用于一维数据,求出最大值。

    max(a,0),计算出数据中一列的最大值,并输出最大值所在的行号。

    max(a,1),计算出数据中一行的最大值,并输出最大值所在的列号。

    print('max(b,1):',torch.max(b,1)[1])

    输出:只输出行最大值所在的列号

    max(b,1): tensor([ 3,  2,  2])

    torch.max(b,1)[0], 只返回最大值的每个数

    二、view()

    a.view(i,j)表示将原矩阵转化为i行j列的形式

    i为-1表示不限制行数,输出1列

    a=torch.randn(3,4)
    print(a)

    输出:

    tensor([[-0.8146, -0.6592, 1.5100, 0.7615],
    [ 1.3021, 1.8362, -0.3590, 0.3028],
    [ 0.0848, 0.7700, 1.0572, 0.6383]])

    b=a.view(-1,1)
    print(b)

    输出:

    tensor([[-0.8146],
    [-0.6592],
    [ 1.5100],
    [ 0.7615],
    [ 1.3021],
    [ 1.8362],
    [-0.3590],
    [ 0.3028],
    [ 0.0848],
    [ 0.7700],
    [ 1.0572],
    [ 0.6383]])

    i为1,j为-1表示不限制列数,输出1行

    b=a.view(1,-1)
    print(b)
    

    输出:

    tensor([[-0.8146, -0.6592, 1.5100, 0.7615, 1.3021, 1.8362, -0.3590,
    0.3028, 0.0848, 0.7700, 1.0572, 0.6383]])

    i为-1,j为2表示不限制行数,输出2列

    b=a.view(-1,2)
    print(b)

    输出:

    tensor([[-0.8146, -0.6592],
    [ 1.5100, 0.7615],
    [ 1.3021, 1.8362],
    [-0.3590, 0.3028],
    [ 0.0848, 0.7700],
    [ 1.0572, 0.6383]])

    i为-1,j为3表示不限制行数,输出3列

    i为4,j为3表示输出4行3列

    b=a.view(-1,3)
    print(b)
    b=a.view(4,3)
    print(b)

    输出:

    tensor([[-0.8146, -0.6592, 1.5100],
    [ 0.7615, 1.3021, 1.8362],
    [-0.3590, 0.3028, 0.0848],
    [ 0.7700, 1.0572, 0.6383]])
    tensor([[-0.8146, -0.6592, 1.5100],
    [ 0.7615, 1.3021, 1.8362],
    [-0.3590, 0.3028, 0.0848],
    [ 0.7700, 1.0572, 0.6383]])

    三、

    1.torch.squeeze()

    压缩矩阵,我理解为降维

    a.squeeze(i) 压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩

    import torch  
    a=torch.randn(1,3,4)
    print(a)
    b=a.squeeze(0)
    print(b)
    c=a.squeeze(1)
    print(c

    输出:

    tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
    [-0.0080, 0.1794, 1.1898, -1.2525],
    [ 0.8281, -0.8166, 1.8846, 0.9008]]])

    一页三行4列的矩阵

    第0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵

    tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
    [-0.0080, 0.1794, 1.1898, -1.2525],
    [ 0.8281, -0.8166, 1.8846, 0.9008]])

    第1维不为1,则不可以压缩

    tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
    [-0.0080, 0.1794, 1.1898, -1.2525],
    [ 0.8281, -0.8166, 1.8846, 0.9008]]])

    2.torch.unsqueeze()

    unsqueeze(i) 表示将第i维设置为1

    对压缩为3行4列后的矩阵b进行操作,将第0维设置为1

    c=b.unsqueeze(0)
    print(c)
    

    输出一个一页三行四列的矩阵

    tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
    [ 1.2210, -0.1084, -0.1166, -0.2379],
    [-1.0012, -0.4363, 1.0057, -1.5180]]])

    将第一维设置为1

    c=b.unsqueeze(1)
    print(c)

    输出一个3页,一行,4列的矩阵

    tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
    [[-2.3976, 0.9857, -0.3462, -0.3648]],
    [[ 1.1012, -0.4659, -0.0858, 1.6631]]])

    另外,squeeze、unsqueeze操作不改变原矩阵

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
    • pytorch Dataset,DataLoader产生自定义的训练数据案例
    • PyTorch实现重写/改写Dataset并载入Dataloader
    • 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
    • PyTorch 解决Dataset和Dataloader遇到的问题
    • PyTorch 如何自动计算梯度
    • pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
    • 我对PyTorch dataloader里的shuffle=True的理解
    • pytorch 带batch的tensor类型图像显示操作
    • 解决pytorch下只打印tensor的数值不打印出device等信息的问题
    • Pytorch 如何查看、释放已关闭程序占用的GPU资源
    • Pytorch数据读取之Dataset和DataLoader知识总结
    上一篇:聊聊Pytorch torch.cat与torch.stack的区别
    下一篇:Python自动化之定位方法大杀器xpath
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch中的squeeze函数、cat函数使用 pytorch,中的,squeeze,函数,