• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    聊聊Pytorch torch.cat与torch.stack的区别

    torch.cat()函数可以将多个张量拼接成一个张量。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度。

    torch.cat()的示例如下图1所示

    图1 torch.cat()

    torch.stack()函数同样有张量列表和维度两个参数。stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

    torch.stack()的示例如下图2所示:

    图2 torch.stack()

    补充:torch.stack()的官方解释,详解以及例子

    可以直接看最下面的【3.例子】,再回头看前面的解释

    在pytorch中,常见的拼接函数主要是两个,分别是:

    1、stack()

    2、cat()

    实际使用中,这两个函数互相辅助:关于cat()参考torch.cat(),但是本文主要说stack()。

    函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。

    形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面(矩阵)按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。

    该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

    1. stack()

    官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

    浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

    outputs = torch.stack(inputs, dim=?) → Tensor

    参数

    inputs : 待连接的张量序列。

    注:python的序列数据只有list和tuple。

    dim : 新的维度, 必须在0到len(outputs)之间。

    注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

    2. 重点

    函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等

    ----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

    dim是选择生成的维度,必须满足0=dimlen(outputs);len(outputs)是输出后的tensor的维度大小

    不懂的看例子,再回过头看就懂了。

    3. 例子

    1.准备2个tensor数据,每个的shape都是[3,3]

    # 假设是时间步T1的输出
    T1 = torch.tensor([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
    # 假设是时间步T2的输出
    T2 = torch.tensor([[10, 20, 30],
              [40, 50, 60],
              [70, 80, 90]])

    2.测试stack函数

    print(torch.stack((T1,T2),dim=0).shape)
    print(torch.stack((T1,T2),dim=1).shape)
    print(torch.stack((T1,T2),dim=2).shape)
    print(torch.stack((T1,T2),dim=3).shape)
    # outputs:
    torch.Size([2, 3, 3])
    torch.Size([3, 2, 3])
    torch.Size([3, 3, 2])
    '选择的dim>len(outputs),所以报错'
    IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

    可以复制代码运行试试:拼接后的tensor形状,会根据不同的dim发生变化。

    dim shape
    0 [2, 3, 3]
    1 [3, 2, 3]
    2 [3, 3, 2]
    3 溢出报错

    4. 总结

    1、函数作用:

    函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。

    2、存在意义:

    在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。

    函数存在意义?》》》

    手写过RNN的同学,知道在循环神经网络中输出数据是:一个list,该列表插入了seq_len个形状是[batch_size, output_size]的tensor,不利于计算,需要使用stack进行拼接,保留–[1.seq_len这个时间步]和–[2.张量属性[batch_size, output_size]]。

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 浅谈pytorch中stack和cat的及to_tensor的坑
    • 对PyTorch torch.stack的实例讲解
    • PyTorch的torch.cat用法
    • PyTorch中torch.tensor与torch.Tensor的区别详解
    上一篇:基于telepath库实现Python和JavaScript之间交换数据
    下一篇:pytorch中的squeeze函数、cat函数使用
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    聊聊Pytorch torch.cat与torch.stack的区别 聊聊,Pytorch,torch.cat,与,torch.stack,