• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    解决Numpy与Pytorch彼此转换时的坑

    前言 ​  

    最近使用 Numpy包与Pytorch写神经网络时,经常需要两者彼此转换,故用此笔记记录码代码时踩(菜)过的坑,网上有人说:

    Pytorch 又被称为 GPU 版的 Numpy,二者的许多功能都有良好的一一对应。

    ​但在使用时还是得多多注意,一个不留神就陷入到了 一根烟一杯酒,一个Bug找一宿 的地步。

    1.1、numpy ——> torch ​  

    使用 torch.from_numpy() 转换,需要注意,两者共享内存。例子如下:

    import torch
    import numpy as np
    
    a = np.array([1,2,3])
    b = torch.from_numpy(a)
    np.add(a, 1, out=a)
    print('转换后a', a)
    print('转换后b', b)
    
    # 显示
    
    转换后a [2 3 4]
    转换后b tensor([2, 3, 4], dtype=torch.int32)
    

    1.2、torch——> numpy ​  

    使用 .numpy() 转换,同样,两者共享内存。例子如下:

    import torch
    import numpy as np
    
    a = torch.zeros((2, 3), dtype=torch.float)
    c = a.numpy()
    np.add(c, 1, out=c)
    print('a:', a)
    print('c:', c)
    
    # 结果
    
    a: tensor([[1., 1., 1.],
               [1., 1., 1.]])
    c: [[1. 1. 1.]
      [1. 1. 1.]]
    

    需要注意的是,如果将程序中的 np.add(c, 1, out=c) 改成 c = c + 1 会发现两者貌似不共享内存了,其实不然,原因是后者相当于改变了 c 的存储地址。可以使用 id(c) 发现c的内存位置变了。

    补充:pytorch中tensor数据和numpy数据转换中注意的一个问题

    在pytorch中,把numpy.array数据转换到张量tensor数据的常用函数是torch.from_numpy(array)或者torch.Tensor(array),第一种函数更常用。

    下面通过代码看一下区别:

    import numpy as np
    import torch
    
    a=np.arange(6,dtype=int).reshape(2,3)
    b=torch.from_numpy(a)
    c=torch.Tensor(a)
    
    a[0][0]=10
    print(a,'\n',b,'\n',c)
    [[10  1  2]
     [ 3  4  5]] 
     tensor([[10,  1,  2],
            [ 3,  4,  5]], dtype=torch.int32) 
     tensor([[0., 1., 2.],
            [3., 4., 5.]])
    
    c[0][0]=10
    print(a,'\n',b,'\n',c)
    [[10  1  2]
     [ 3  4  5]] 
     tensor([[10,  1,  2],
            [ 3,  4,  5]], dtype=torch.int32) 
     tensor([[10.,  1.,  2.],
            [ 3.,  4.,  5.]])
    
    print(b.type())
    torch.IntTensor
    print(c.type())
    torch.FloatTensor
    

    可以看出修改数组a的元素值,张量b的元素值也改变了,但是张量c却不变。修改张量c的元素值,数组a和张量b的元素值都不变。

    这说明torch.from_numpy(array)是做数组的浅拷贝,torch.Tensor(array)是做数组的深拷贝。

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • Pytorch之Tensor和Numpy之间的转换的实现方法
    • python、PyTorch图像读取与numpy转换实例
    • pytorch 实现tensor与numpy数组转换
    • pytorch numpy list类型之间的相互转换实例
    • 浅谈pytorch和Numpy的区别以及相互转换方法
    上一篇:pytorch中的matmul与mm,bmm区别说明
    下一篇:python流水线框架pypeln的安装使用教程
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    解决Numpy与Pytorch彼此转换时的坑 解决,Numpy,与,Pytorch,彼此,