• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch中的matmul与mm,bmm区别说明

    pytorch中matmul和mm和bmm区别 matmulmmbmm结论

    先看下官网上对这三个函数的介绍。

    matmul

    mm

    bmm

    顾名思义, 就是两个batch矩阵乘法.

    结论

    从官方文档可以看出

    1、mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是( n × m ) (n\times m)(n×m)和( m × p ) (m\times p)(m×p)

    2、bmm是两个三维张量相乘, 两个输入tensor维度是( b × n × m ) (b\times n\times m)(b×n×m)和( b × m × p ) (b\times m\times p)(b×m×p), 第一维b代表batch size,输出为( b × n × p ) (b\times n \times p)(b×n×p)

    3、matmul可以进行张量乘法, 输入可以是高维.

    补充:torch中的几种乘法。torch.mm, torch.mul, torch.matmul

    一、点乘

    点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。

    >>> a = torch.ones(3,4)
    >>> a
    tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])
    >>> b = torch.Tensor([1,2,3]).reshape((3,1))
    >>> b
    tensor([[1.],
            [2.],
            [3.]])
    >>> torch.mul(a, b)
    tensor([[1., 1., 1., 1.],
            [2., 2., 2., 2.],
            [3., 3., 3., 3.]])

    当a, b维度不一致时,会自动填充到相同维度相点乘。

    二、矩阵乘

    矩阵相乘有torch.mm和torch.matmul两个函数。其中前一个是针对二维矩阵,后一个是高维。当torch.mm用于大于二维时将报错。

    >>> a = torch.ones(3,4)
    >>> b = torch.ones(4,2)
    >>> torch.mm(a, b)
    tensor([[4., 4.],
            [4., 4.],
            [4., 4.]])
    >>> a = torch.ones(3,4)
    >>> b = torch.ones(5,4,2)
    >>> torch.matmul(a, b).shape
    torch.Size([5, 3, 2])
    >>> a = torch.ones(5,4,2)
    >>> b = torch.ones(5,2,3)
    >>> torch.matmul(a, b).shape
    torch.Size([5, 4, 3])
    >>> a = torch.ones(5,4,2)
    >>> b = torch.ones(5,2,3)
    >>> torch.matmul(b, a).shape
    报错。
    

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

    您可能感兴趣的文章:
    • pytorch:torch.mm()和torch.matmul()的使用
    • 基于python及pytorch中乘法的使用详解
    • 关于tf.matmul() 和tf.multiply() 的区别说明
    上一篇:pytorch-autograde-计算图的特点说明
    下一篇:解决Numpy与Pytorch彼此转换时的坑
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch中的matmul与mm,bmm区别说明 pytorch,中的,matmul,与,bmm,